Repository: cilinyan/VISA Branch: main Commit: ec158fde9cdf Files: 276 Total size: 54.3 MB Directory structure: gitextract_qarvn3ca/ ├── .gitignore ├── .gitmodules ├── README.md ├── XMem/ │ ├── dataset/ │ │ ├── __init__.py │ │ ├── range_transform.py │ │ ├── reseed.py │ │ ├── static_dataset.py │ │ ├── tps.py │ │ ├── util.py │ │ └── vos_dataset.py │ ├── eval.py │ ├── eval_batch.py │ ├── generate_xmem_data_single.py │ ├── inference/ │ │ ├── __init__.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── mask_mapper.py │ │ │ ├── test_datasets.py │ │ │ └── video_reader.py │ │ ├── inference_core.py │ │ ├── interact/ │ │ │ ├── __init__.py │ │ │ ├── fbrs/ │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ │ ├── controller.py │ │ │ │ ├── inference/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── clicker.py │ │ │ │ │ ├── evaluation.py │ │ │ │ │ ├── predictors/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── base.py │ │ │ │ │ │ ├── brs.py │ │ │ │ │ │ ├── brs_functors.py │ │ │ │ │ │ └── brs_losses.py │ │ │ │ │ ├── transforms/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── base.py │ │ │ │ │ │ ├── crops.py │ │ │ │ │ │ ├── flip.py │ │ │ │ │ │ ├── limit_longest_side.py │ │ │ │ │ │ └── zoom_in.py │ │ │ │ │ └── utils.py │ │ │ │ ├── model/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── initializer.py │ │ │ │ │ ├── is_deeplab_model.py │ │ │ │ │ ├── is_hrnet_model.py │ │ │ │ │ ├── losses.py │ │ │ │ │ ├── metrics.py │ │ │ │ │ ├── modeling/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── basic_blocks.py │ │ │ │ │ │ ├── deeplab_v3.py │ │ │ │ │ │ ├── hrnet_ocr.py │ │ │ │ │ │ ├── ocr.py │ │ │ │ │ │ ├── resnet.py │ │ │ │ │ │ └── resnetv1b.py │ │ │ │ │ ├── ops.py │ │ │ │ │ └── syncbn/ │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── README.md │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── modules/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── functional/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── _csrc.py │ │ │ │ │ │ ├── csrc/ │ │ │ │ │ │ │ ├── bn.h │ │ │ │ │ │ │ ├── cuda/ │ │ │ │ │ │ │ │ ├── bn_cuda.cu │ │ │ │ │ │ │ │ ├── common.h │ │ │ │ │ │ │ │ └── ext_lib.h │ │ │ │ │ │ │ └── ext_lib.cpp │ │ │ │ │ │ └── syncbn.py │ │ │ │ │ └── nn/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── syncbn.py │ │ │ │ └── utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── cython/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── _get_dist_maps.pyx │ │ │ │ │ ├── _get_dist_maps.pyxbld │ │ │ │ │ └── dist_maps.py │ │ │ │ ├── misc.py │ │ │ │ └── vis.py │ │ │ ├── fbrs_controller.py │ │ │ ├── gui.py │ │ │ ├── gui_utils.py │ │ │ ├── interaction.py │ │ │ ├── interactive_utils.py │ │ │ ├── resource_manager.py │ │ │ ├── s2m/ │ │ │ │ ├── __init__.py │ │ │ │ ├── _deeplab.py │ │ │ │ ├── s2m_network.py │ │ │ │ ├── s2m_resnet.py │ │ │ │ └── utils.py │ │ │ ├── s2m_controller.py │ │ │ └── timer.py │ │ ├── kv_memory_store.py │ │ └── memory_manager.py │ ├── interactive_demo.py │ ├── merge_multi_scale.py │ ├── merge_results.py │ ├── model/ │ │ ├── __init__.py │ │ ├── aggregate.py │ │ ├── cbam.py │ │ ├── group_modules.py │ │ ├── losses.py │ │ ├── memory_util.py │ │ ├── modules.py │ │ ├── network.py │ │ ├── resnet.py │ │ └── trainer.py │ ├── requirements.txt │ ├── scripts/ │ │ ├── __init__.py │ │ ├── download_bl30k.py │ │ ├── download_datasets.py │ │ ├── download_models.sh │ │ ├── download_models_demo.sh │ │ ├── expand_long_vid.py │ │ └── resize_youtube.py │ ├── tracking.py │ ├── train.py │ └── util/ │ ├── __init__.py │ ├── configuration.py │ ├── davis_subset.txt │ ├── image_saver.py │ ├── load_subset.py │ ├── log_integrator.py │ ├── logger.py │ ├── palette.py │ ├── tensor_util.py │ └── yv_subset.txt ├── 
merge_lora_weights_and_save_hf_model.py ├── model/ │ ├── VISA.py │ ├── llava/ │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── conversation.py │ │ ├── mm_utils.py │ │ ├── model/ │ │ │ ├── __init__.py │ │ │ ├── apply_delta.py │ │ │ ├── builder.py │ │ │ ├── consolidate.py │ │ │ ├── language_model/ │ │ │ │ ├── llava_llama.py │ │ │ │ ├── llava_mpt.py │ │ │ │ └── mpt/ │ │ │ │ ├── adapt_tokenizer.py │ │ │ │ ├── attention.py │ │ │ │ ├── blocks.py │ │ │ │ ├── configuration_mpt.py │ │ │ │ ├── custom_embedding.py │ │ │ │ ├── flash_attn_triton.py │ │ │ │ ├── hf_prefixlm_converter.py │ │ │ │ ├── meta_init_context.py │ │ │ │ ├── modeling_mpt.py │ │ │ │ ├── norm.py │ │ │ │ └── param_init_fns.py │ │ │ ├── llava_arch.py │ │ │ ├── make_delta.py │ │ │ ├── multimodal_encoder/ │ │ │ │ ├── builder.py │ │ │ │ └── clip_encoder.py │ │ │ └── utils.py │ │ ├── train/ │ │ │ ├── llama_flash_attn_monkey_patch.py │ │ │ ├── llava_trainer.py │ │ │ ├── train.py │ │ │ └── train_mem.py │ │ └── utils.py │ ├── segment_anything/ │ │ ├── __init__.py │ │ ├── automatic_mask_generator.py │ │ ├── build_sam.py │ │ ├── modeling/ │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── image_encoder.py │ │ │ ├── mask_decoder.py │ │ │ ├── prompt_encoder.py │ │ │ ├── sam.py │ │ │ └── transformer.py │ │ ├── predictor.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── amg.py │ │ ├── onnx.py │ │ └── transforms.py │ ├── tf/ │ │ └── modeling_outputs.py │ └── univi/ │ ├── __init__.py │ ├── config/ │ │ ├── __init__.py │ │ ├── dataset_config.py │ │ └── model_config.py │ ├── constants.py │ ├── conversation.py │ ├── demo.py │ ├── eval/ │ │ ├── evaluate/ │ │ │ ├── evaluate_benchmark_1_correctness.py │ │ │ ├── evaluate_benchmark_2_detailed_orientation.py │ │ │ ├── evaluate_benchmark_3_context.py │ │ │ ├── evaluate_benchmark_4_temporal.py │ │ │ ├── evaluate_benchmark_5_consistency.py │ │ │ ├── evaluate_gpt_review_visual.py │ │ │ ├── evaluate_science_qa.py │ │ │ ├── evaluate_video_qa.py │ │ │ └── summarize_gpt_review.py │ │ ├── model_coco_vqa.py │ │ ├── model_video_consistency.py │ │ ├── model_video_general.py │ │ ├── model_video_qa.py │ │ ├── model_vqa.py │ │ ├── model_vqa_scienceqa.py │ │ ├── questions/ │ │ │ ├── coco2014_val_qa_eval/ │ │ │ │ ├── qa90_gpt4_answer.jsonl │ │ │ │ └── qa90_questions.jsonl │ │ │ ├── coco_pope/ │ │ │ │ ├── coco_pope_adversarial.jsonl │ │ │ │ ├── coco_pope_popular.jsonl │ │ │ │ └── coco_pope_random.jsonl │ │ │ ├── scienceqa/ │ │ │ │ ├── pid_splits.json │ │ │ │ ├── problems.json │ │ │ │ └── test_QCM-LEA.json │ │ │ └── video_qa/ │ │ │ ├── activitynet_a_list.json │ │ │ ├── activitynet_qa.json │ │ │ ├── consistency_qa.json │ │ │ ├── generic_qa.json │ │ │ ├── msrvtt_a_list.json │ │ │ ├── msrvtt_qa.json │ │ │ ├── msvd_a_list.json │ │ │ ├── msvd_qa.json │ │ │ ├── temporal_qa.json │ │ │ ├── tgif_a_list.json │ │ │ └── tgif_qa.json │ │ └── table/ │ │ ├── caps_boxes_coco2014_val_80.jsonl │ │ ├── model.jsonl │ │ ├── question.jsonl │ │ ├── reviewer.jsonl │ │ └── rule.json │ ├── mm_utils.py │ ├── model/ │ │ ├── __init__.py │ │ ├── apply_delta.py │ │ ├── arch.py │ │ ├── builder.py │ │ ├── cluster.py │ │ ├── consolidate.py │ │ ├── dataloader.py │ │ ├── language_model/ │ │ │ └── llama.py │ │ ├── make_delta.py │ │ └── multimodal_encoder/ │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ ├── eva_encoder.py │ │ ├── eva_vit.py │ │ ├── processor.py │ │ └── utils.py │ ├── train/ │ │ ├── llama_flash_attn_monkey_patch.py │ │ ├── train.py │ │ ├── train_mem.py │ │ └── trainer.py │ └── utils.py ├── requirements.txt ├── scripts/ │ ├── train_13b.sh │ ├── 
train_7b.sh │ └── val_7b_video.sh ├── tools/ │ ├── eval_davis17.py │ ├── eval_mevis.py │ ├── eval_revos.py │ ├── generate_foreground_mask.py │ ├── metrics.py │ ├── zip_mp_mevis.py │ └── zip_mp_refytvos.py ├── train_ds.py ├── utils/ │ ├── ade20k_classes.json │ ├── chatunivi_dataset.py │ ├── cocostuff_classes.txt │ ├── conversation.py │ ├── d2_datasets/ │ │ ├── categories.py │ │ ├── mevis_utils.py │ │ ├── refytvos_utils.py │ │ ├── refytvos_val_videos.py │ │ └── ytvis_api/ │ │ ├── __init__.py │ │ ├── ytvos.py │ │ └── ytvoseval.py │ ├── data_processing.py │ ├── dataset.py │ ├── dataset_config.py │ ├── grefcoco.py │ ├── grefer.py │ ├── random_list.py │ ├── reason_seg_dataset.py │ ├── refer.py │ ├── refer_seg_dataset.py │ ├── rvos_dataset.py │ ├── rvos_eval_dataset.py │ ├── sem_seg_dataset.py │ ├── utils.py │ └── vqa_dataset.py └── utils_llamavid/ ├── llamavid_client.py └── llamavid_server.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ **/__pycache__ runs/ models/ datasets/ datasets .vscode/ core* vis_output/ test_vis/ openai .DS_Store XMem/weights/ ================================================ FILE: .gitmodules ================================================ [submodule "LLaMA-VID"] path = LLaMA-VID url = git@github.com:dvlab-research/LLaMA-VID.git ================================================ FILE: README.md ================================================ # VISA: Reasoning Video Object Segmentation via Large Language Model
[![GitHub stars](https://badgen.net/github/stars/cilinyan/VISA)](https://github.com/cilinyan/VISA) [![arXiv](https://img.shields.io/badge/arXiv-2407.11325-972B24)](http://arxiv.org/abs/2407.11325) [![Static Badge](https://img.shields.io/badge/dataset-ReVOS-3F7EBC)](https://github.com/cilinyan/ReVOS-api)
## 🚀 Performance
VISA demonstrates remarkable proficiency in handling complex segmentation tasks that require: (a) reasoning based on world knowledge; (b) inference of future events; and (c) a comprehensive understanding of video content.
## 🛠️ Installation ```shell pip install -r requirements.txt pip install flash-attn --no-build-isolation ``` ## 🦄 Training and Validation ### 1. Training Data Preparation Before training, please download the datasets and configure their paths in [dataset_config.py](utils/dataset_config.py) (see the configuration sketch after the dataset listings below).
LISA's Dataset Follow [LISA](https://github.com/dvlab-research/LISA/tree/main) to prepare LISA's datasets. The dataset folder should be stored in the `$LISA_ROOT` folder. ``` LISA_ROOT ├── ade20k ├── coco ├── cocostuff ├── llava_dataset ├── mapillary ├── reason_seg ├── refer_seg └── vlpart ```
Chat-UniVi's Dataset Follow [Chat-UniVi/Chat-UniVi-Instruct](https://huggingface.co/datasets/Chat-UniVi/Chat-UniVi-Instruct/tree/main) to prepare `Chat-UniVi-Instruct` datasets. The dataset folder should be stored in the `$ChatUniVi_ROOT` folder. ``` ChatUniVi_ROOT ├── Fine-tuning │ ├── MIMIC_imageonly │ └── VIDEO └── ScienceQA_tuning ```
RVOS's Dataset 1. Reasoning Video Segmentation Datasets: [ReVOS](https://github.com/cilinyan/ReVOS-api). 2. Referring Video Segmentation Datasets: [Ref-Youtube-VOS](https://github.com/wjn922/ReferFormer/blob/main/docs/data.md), [Ref-DAVIS17](https://github.com/wjn922/ReferFormer/blob/main/docs/data.md), [MeViS](https://github.com/henghuiding/MeViS). - Ref-Youtube-VOS: Download `mask_dict.pkl` from [OneDrive](https://mailsjlueducn-my.sharepoint.com/:f:/g/personal/yancl9918_mails_jlu_edu_cn/EqR9g3yWG5pPtVoil0EfsbgBJhCZ7YwaRG9w9GsYy1_N5g?e=JLaJfc) or [BaiduPan](https://pan.baidu.com/s/1mbJaDDy0UTlA7sysp0zypg?pwd=visa). - Ref-DAVIS17: Download `mask_dict.pkl` from [OneDrive](https://mailsjlueducn-my.sharepoint.com/:f:/g/personal/yancl9918_mails_jlu_edu_cn/Eq8bmGqNcYxGhQ1bioN65q4B_gPxIabpJUjGaV5uqcaq3w?e=2J6Ldp) or [BaiduPan](https://pan.baidu.com/s/1Gg5qPvxRZMKDp0JrVRJ75w?pwd=visa). 3. Open-Vocabulary Video Instance Segmentation Dataset: [LV-VIS](https://github.com/haochenheheda/LVVIS/tree/main). Download `mask_dict.json` and `meta_expressions.json` from [OneDrive](https://mailsjlueducn-my.sharepoint.com/:f:/g/personal/yancl9918_mails_jlu_edu_cn/EttXAjMV8yFJhHMQwX3mIw0BP7dymKV-cuw4uAotDaAwYw?e=j6Y44X) or [BaiduPan](https://pan.baidu.com/s/1LOWPnuxXF_LXGSL7osRptA?pwd=visa). Then, put the annotation files in the `$RVOS_ROOT/lvvis/train` directory as follows. ``` RVOS_ROOT ├── ReVOS │ ├── JPEGImages │ ├── mask_dict.json │ ├── mask_dict_foreground.json │ ├── meta_expressions_train_.json │ └── meta_expressions_valid_.json ├── lvvis │ └── train │ ├── JPEGImages │ ├── mask_dict.json │ └── meta_expressions.json ├── Ref-Youtube-VOS │ ├── meta_expressions │ │ ├── train/meta_expressions.json │ │ └── valid/meta_expressions.json │ ├── train │ │ ├── JPEGImages │ │ └── mask_dict.pkl │ └── valid │ └── JPEGImages ├── davis17 │ ├── meta_expressions │ │ ├── train/meta_expressions.json │ │ └── valid/meta_expressions.json │ ├── train │ │ ├── JPEGImages │ │ └── mask_dict.pkl │ └── valid │ ├── JPEGImages │ └── mask_dict.pkl └── mevis ```
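For reference, the dataset roots above come together in [dataset_config.py](utils/dataset_config.py). A minimal sketch of what the configuration amounts to (the variable names here are illustrative; check the file for the actual ones):

```python
# utils/dataset_config.py (sketch) -- point each root at your local copy
LISA_ROOT      = "/PATH/TO/LISA_ROOT"       # ade20k, coco, refer_seg, ...
ChatUniVi_ROOT = "/PATH/TO/ChatUniVi_ROOT"  # Fine-tuning, ScienceQA_tuning
RVOS_ROOT      = "/PATH/TO/RVOS_ROOT"       # ReVOS, lvvis, Ref-Youtube-VOS, davis17, mevis
```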
### 2. Pre-trained weights
Chat-UniVi To train VISA-7B or VISA-13B, you need to download the Chat-UniVi weights from [Chat-UniVi-7B](https://huggingface.co/Chat-UniVi/Chat-UniVi) or [Chat-UniVi-13B](https://huggingface.co/Chat-UniVi/Chat-UniVi-13B), respectively.
SAM Download SAM ViT-H pre-trained weights from the [link](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth).
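Both checkpoints can also be fetched from the command line. A sketch, assuming `huggingface_hub` (which provides `huggingface-cli`) and `wget` are installed; the target directory is a placeholder:

```shell
# Chat-UniVi base weights (use Chat-UniVi/Chat-UniVi-13B for the 13B model)
huggingface-cli download Chat-UniVi/Chat-UniVi --local-dir /PATH/TO/Chat-UniVi
# SAM ViT-H weights
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
```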
### 3. Training VISA ```shell # Train VISA-7B bash scripts/train_7b.sh # Extract fp32 consolidated weights from a ZeRO-1/2/3 DeepSpeed checkpoint cd /PATH/TO/VISA-7B/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin # Merge the LoRA weights CUDA_VISIBLE_DEVICES="" python merge_lora_weights_and_save_hf_model.py \ --version Chat-UniVi/Chat-UniVi \ --weight /PATH/TO/VISA-7B/pytorch_model.bin \ --save_path /PATH/TO/VISA-7B/hf_model ``` ### 4. Validation
1. Use `VISA` to generate the predicted masks for each video [demo] ```shell deepspeed --master_port=24999 train_ds.py \ --version="/PATH/TO/VISA-7B/hf_model" \ --vision_pretrained="/PATH/TO/sam_vit_h_4b8939.pth" \ --log_base_dir="/PATH/TO/LOG_BASE_DIR" \ --exp_name="val_7b" \ --balance_sample \ --dataset="reason_seg" \ --sample_rates="13" \ --val_dataset "revos_valid" \ --eval_only ```
2. Use LLaMA-VID to generate the target frame for each video > You can directly download the results of our run from [OneDrive](https://mailsjlueducn-my.sharepoint.com/:u:/g/personal/yancl9918_mails_jlu_edu_cn/ETmoJF2i8ZZBsgIwdELiL8gBfptZZoPWjx6Y0eH6Myr3sw?e=mTt6rO) or [BaiduPan](https://pan.baidu.com/s/1YWs6NLPvANfhgUBHKQwnBg?pwd=visa) - Run [llamavid_server.py](https://github.com/cilinyan/VISA/blob/main/utils_llamavid/llamavid_server.py) to start the API server for LLaMA-VID [`[demo]`](https://github.com/cilinyan/VISA/blob/c53d2cd31407eab583c5eb04f84fd95b4694f2ce/utils_llamavid/llamavid_server.py#L215-L220) ```shell python utils_llamavid/llamavid_server.py \ --vision_tower /PATH/TO/eva_vit_g.pth \ --image_processor /PATH/TO/openai/clip-vit-large-patch14 \ --model-path /PATH/TO/YanweiLi/llama-vid-13b-full-224-video-fps-1 ``` - Use the API for inference [`[demo]`](https://github.com/cilinyan/VISA/blob/c53d2cd31407eab583c5eb04f84fd95b4694f2ce/utils_llamavid/llamavid_client.py#L58-L63) ```shell python utils_llamavid/llamavid_client.py \ --video_root /PATH/TO/ReVOS/JPEGImages \ --data_json_file /PATH/TO/ReVOS/meta_expressions_valid_.json ```
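Downstream, `XMem/generate_xmem_data_single.py` reads a per-expression target-frame position `tp` (a relative position in `[0, 1]`) from the resulting `meta_expressions_valid__llamavid.json` and converts it into a concrete frame index. A minimal sketch of that lookup, with the file layout inferred from the script (the path is a placeholder):

```python
import json

# Meta file augmented with LLaMA-VID's target-frame predictions
meta = json.load(open("/PATH/TO/ReVOS/meta_expressions_valid__llamavid.json"))["videos"]

for video_name, video in meta.items():
    for obj_id, exp in video["expressions"].items():
        tp = exp["tp"]  # relative target-frame position in [0, 1]
        # generate_xmem_data_single.py maps tp onto the per-object frame list:
        #   frame_id = min(int(num_frames * tp), num_frames - 1)
        print(video_name, obj_id, tp)
```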
3. Use XMem for mask propagation [demo]
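The propagation step is driven by the helper scripts under `XMem/`. A sketch of the pipeline, assembled from the argument lists of `generate_xmem_data_single.py` and `eval_batch.py` in this repo (paths are placeholders; the exact invocation may differ):

```shell
cd XMem
# 1. Stage the LLaMA-VID-selected target frame of each object as the single reference mask
python generate_xmem_data_single.py \
    --video_root /PATH/TO/VISA_exp/revos_valid/Annotations \
    --temp_xmem_anno /PATH/TO/VISA_exp/revos_valid_XMem_temp/Annotations \
    --final_xmem_anno /PATH/TO/VISA_exp/revos_valid_XMem_final/Annotations \
    --llama_vid_meta /PATH/TO/ReVOS/meta_expressions_valid__llamavid.json
# 2. Propagate each reference mask through its video (forward and reversed) on all visible GPUs
python eval_batch.py \
    --meta_expression /PATH/TO/ReVOS/meta_expressions_valid__llamavid.json \
    --temp_xmem_anno /PATH/TO/VISA_exp/revos_valid_XMem_temp/Annotations \
    --final_xmem_anno /PATH/TO/VISA_exp/revos_valid_XMem_final/Annotations \
    --img_dir /PATH/TO/ReVOS/JPEGImages
```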
4. Evaluate performance on ReVOS [demo] ```shell cd tools python eval_revos.py /PATH/TO/FINAL_ANNOTATION [ARGS] ```
## 📑 Todo list - [x] Release code with `Text-guided Frame Sampler`'s Local Sampling - [ ] Release VISA model weights [issue #6](https://github.com/cilinyan/VISA/issues/6) - [ ] Release code with `Text-guided Frame Sampler`'s Global-Local Sampling ## ⭐ Cite If you find this project useful in your research, please consider citing: ``` @article{yan2024visa, title={VISA: Reasoning Video Object Segmentation via Large Language Models}, author={Yan, Cilin and Wang, Haochen and Yan, Shilin and Jiang, Xiaolong and Hu, Yao and Kang, Guoliang and Xie, Weidi and Gavves, Efstratios}, journal={arXiv preprint arXiv:2407.11325}, year={2024} } ``` ## 🎖️ Acknowledgement This work is built upon the [LLaVA](https://github.com/haotian-liu/LLaVA), [SAM](https://github.com/facebookresearch/segment-anything), [LISA](https://github.com/dvlab-research/LISA), [Chat-UniVi](https://github.com/PKU-YuanGroup/Chat-UniVi), [MeViS](https://github.com/henghuiding/MeViS), [LLaMA-VID](https://github.com/dvlab-research/LLaMA-VID) and [XMem](https://github.com/hkchengrex/XMem). ================================================ FILE: XMem/dataset/__init__.py ================================================ ================================================ FILE: XMem/dataset/range_transform.py ================================================ import torchvision.transforms as transforms im_mean = (124, 116, 104) im_normalization = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) inv_im_trans = transforms.Normalize( mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225]) ================================================ FILE: XMem/dataset/reseed.py ================================================ import torch import random def reseed(seed): random.seed(seed) torch.manual_seed(seed) ================================================ FILE: XMem/dataset/static_dataset.py ================================================ import os from os import path import torch from torch.utils.data.dataset import Dataset from torchvision import transforms from torchvision.transforms import InterpolationMode from PIL import Image import numpy as np from dataset.range_transform import im_normalization, im_mean from dataset.tps import random_tps_warp from dataset.reseed import reseed class StaticTransformDataset(Dataset): """ Generate pseudo VOS data by applying random transforms on static images. Single-object only. 
Method 0 - FSS style (class/1.jpg class/1.png) Method 1 - Others style (XXX.jpg XXX.png) """ def __init__(self, parameters, num_frames=3, max_num_obj=1): self.num_frames = num_frames self.max_num_obj = max_num_obj self.im_list = [] for parameter in parameters: root, method, multiplier = parameter if method == 0: # Get images classes = os.listdir(root) for c in classes: imgs = os.listdir(path.join(root, c)) jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()] joint_list = [path.join(root, c, im) for im in jpg_list] self.im_list.extend(joint_list * multiplier) elif method == 1: self.im_list.extend([path.join(root, im) for im in os.listdir(root) if '.jpg' in im] * multiplier) print(f'{len(self.im_list)} images found.') # These set of transform is the same for im/gt pairs, but different among the 3 sampled frames self.pair_im_lone_transform = transforms.Compose([ transforms.ColorJitter(0.1, 0.05, 0.05, 0), # No hue change here as that's not realistic ]) self.pair_im_dual_transform = transforms.Compose([ transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=im_mean), transforms.Resize(384, InterpolationMode.BICUBIC), transforms.RandomCrop((384, 384), pad_if_needed=True, fill=im_mean), ]) self.pair_gt_dual_transform = transforms.Compose([ transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=0), transforms.Resize(384, InterpolationMode.NEAREST), transforms.RandomCrop((384, 384), pad_if_needed=True, fill=0), ]) # These transform are the same for all pairs in the sampled sequence self.all_im_lone_transform = transforms.Compose([ transforms.ColorJitter(0.1, 0.05, 0.05, 0.05), transforms.RandomGrayscale(0.05), ]) self.all_im_dual_transform = transforms.Compose([ transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=im_mean), transforms.RandomHorizontalFlip(), ]) self.all_gt_dual_transform = transforms.Compose([ transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=0), transforms.RandomHorizontalFlip(), ]) # Final transform without randomness self.final_im_transform = transforms.Compose([ transforms.ToTensor(), im_normalization, ]) self.final_gt_transform = transforms.Compose([ transforms.ToTensor(), ]) def _get_sample(self, idx): im = Image.open(self.im_list[idx]).convert('RGB') gt = Image.open(self.im_list[idx][:-3]+'png').convert('L') sequence_seed = np.random.randint(2147483647) images = [] masks = [] for _ in range(self.num_frames): reseed(sequence_seed) this_im = self.all_im_dual_transform(im) this_im = self.all_im_lone_transform(this_im) reseed(sequence_seed) this_gt = self.all_gt_dual_transform(gt) pairwise_seed = np.random.randint(2147483647) reseed(pairwise_seed) this_im = self.pair_im_dual_transform(this_im) this_im = self.pair_im_lone_transform(this_im) reseed(pairwise_seed) this_gt = self.pair_gt_dual_transform(this_gt) # Use TPS only some of the times # Not because TPS is bad -- just that it is too slow and I need to speed up data loading if np.random.rand() < 0.33: this_im, this_gt = random_tps_warp(this_im, this_gt, scale=0.02) this_im = self.final_im_transform(this_im) this_gt = self.final_gt_transform(this_gt) images.append(this_im) masks.append(this_gt) images = torch.stack(images, 0) masks = torch.stack(masks, 0) return images, masks.numpy() def __getitem__(self, idx): additional_objects = np.random.randint(self.max_num_obj) indices = [idx, *np.random.randint(self.__len__(), size=additional_objects)] merged_images = None merged_masks = 
np.zeros((self.num_frames, 384, 384), dtype=np.int64) for i, list_id in enumerate(indices): images, masks = self._get_sample(list_id) if merged_images is None: merged_images = images else: merged_images = merged_images*(1-masks) + images*masks merged_masks[masks[:,0]>0.5] = (i+1) masks = merged_masks labels = np.unique(masks[0]) # Remove background labels = labels[labels!=0] target_objects = labels.tolist() # Generate one-hot ground-truth cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64) first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64) for i, l in enumerate(target_objects): this_mask = (masks==l) cls_gt[this_mask] = i+1 first_frame_gt[0,i] = (this_mask[0]) cls_gt = np.expand_dims(cls_gt, 1) info = {} info['name'] = self.im_list[idx] info['num_objects'] = max(1, len(target_objects)) # 1 if object exist, 0 otherwise selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)] selector = torch.FloatTensor(selector) data = { 'rgb': merged_images, 'first_frame_gt': first_frame_gt, 'cls_gt': cls_gt, 'selector': selector, 'info': info } return data def __len__(self): return len(self.im_list) ================================================ FILE: XMem/dataset/tps.py ================================================ import numpy as np from PIL import Image import cv2 import thinplate as tps cv2.setNumThreads(0) def pick_random_points(h, w, n_samples): y_idx = np.random.choice(np.arange(h), size=n_samples, replace=False) x_idx = np.random.choice(np.arange(w), size=n_samples, replace=False) return y_idx/h, x_idx/w def warp_dual_cv(img, mask, c_src, c_dst): dshape = img.shape theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True) grid = tps.tps_grid(theta, c_dst, dshape) mapx, mapy = tps.tps_grid_to_remap(grid, img.shape) return cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR), cv2.remap(mask, mapx, mapy, cv2.INTER_NEAREST) def random_tps_warp(img, mask, scale, n_ctrl_pts=12): """ Apply a random TPS warp of the input image and mask Uses randomness from numpy """ img = np.asarray(img) mask = np.asarray(mask) h, w = mask.shape points = pick_random_points(h, w, n_ctrl_pts) c_src = np.stack(points, 1) c_dst = c_src + np.random.normal(scale=scale, size=c_src.shape) warp_im, warp_gt = warp_dual_cv(img, mask, c_src, c_dst) return Image.fromarray(warp_im), Image.fromarray(warp_gt) ================================================ FILE: XMem/dataset/util.py ================================================ import numpy as np def all_to_onehot(masks, labels): if len(masks.shape) == 3: Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) else: Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8) for ni, l in enumerate(labels): Ms[ni] = (masks == l).astype(np.uint8) return Ms ================================================ FILE: XMem/dataset/vos_dataset.py ================================================ import os from os import path, replace import torch from torch.utils.data.dataset import Dataset from torchvision import transforms from torchvision.transforms import InterpolationMode from PIL import Image import numpy as np from dataset.range_transform import im_normalization, im_mean from dataset.reseed import reseed class VOSDataset(Dataset): """ Works for DAVIS/YouTubeVOS/BL30K training For each sequence: - Pick three frames - Pick two objects - Apply some random transforms that are the same for all frames - Apply random transform to each of the frame - The distance between frames is 
controlled """ def __init__(self, im_root, gt_root, max_jump, is_bl, subset=None, num_frames=3, max_num_obj=3, finetune=False): self.im_root = im_root self.gt_root = gt_root self.max_jump = max_jump self.is_bl = is_bl self.num_frames = num_frames self.max_num_obj = max_num_obj self.videos = [] self.frames = {} vid_list = sorted(os.listdir(self.im_root)) # Pre-filtering for vid in vid_list: if subset is not None: if vid not in subset: continue frames = sorted(os.listdir(os.path.join(self.im_root, vid))) if len(frames) < num_frames: continue self.frames[vid] = frames self.videos.append(vid) print('%d out of %d videos accepted in %s.' % (len(self.videos), len(vid_list), im_root)) # These set of transform is the same for im/gt pairs, but different among the 3 sampled frames self.pair_im_lone_transform = transforms.Compose([ transforms.ColorJitter(0.01, 0.01, 0.01, 0), ]) self.pair_im_dual_transform = transforms.Compose([ transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.BILINEAR, fill=im_mean), ]) self.pair_gt_dual_transform = transforms.Compose([ transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.NEAREST, fill=0), ]) # These transform are the same for all pairs in the sampled sequence self.all_im_lone_transform = transforms.Compose([ transforms.ColorJitter(0.1, 0.03, 0.03, 0), transforms.RandomGrayscale(0.05), ]) if self.is_bl: # Use a different cropping scheme for the blender dataset because the image size is different self.all_im_dual_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.BILINEAR) ]) self.all_gt_dual_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.NEAREST) ]) else: self.all_im_dual_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.BILINEAR) ]) self.all_gt_dual_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.NEAREST) ]) # Final transform without randomness self.final_im_transform = transforms.Compose([ transforms.ToTensor(), im_normalization, ]) def __getitem__(self, idx): video = self.videos[idx] info = {} info['name'] = video vid_im_path = path.join(self.im_root, video) vid_gt_path = path.join(self.gt_root, video) frames = self.frames[video] trials = 0 while trials < 5: info['frames'] = [] # Appended with actual frames num_frames = self.num_frames length = len(frames) this_max_jump = min(len(frames), self.max_jump) # iterative sampling frames_idx = [np.random.randint(length)] acceptable_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1))).difference(set(frames_idx)) while(len(frames_idx) < num_frames): idx = np.random.choice(list(acceptable_set)) frames_idx.append(idx) new_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1))) acceptable_set = acceptable_set.union(new_set).difference(set(frames_idx)) frames_idx = sorted(frames_idx) if np.random.rand() < 0.5: # Reverse time frames_idx = frames_idx[::-1] sequence_seed = np.random.randint(2147483647) images = [] masks 
= [] target_objects = [] for f_idx in frames_idx: jpg_name = frames[f_idx][:-4] + '.jpg' png_name = frames[f_idx][:-4] + '.png' info['frames'].append(jpg_name) reseed(sequence_seed) this_im = Image.open(path.join(vid_im_path, jpg_name)).convert('RGB') this_im = self.all_im_dual_transform(this_im) this_im = self.all_im_lone_transform(this_im) reseed(sequence_seed) this_gt = Image.open(path.join(vid_gt_path, png_name)).convert('P') this_gt = self.all_gt_dual_transform(this_gt) pairwise_seed = np.random.randint(2147483647) reseed(pairwise_seed) this_im = self.pair_im_dual_transform(this_im) this_im = self.pair_im_lone_transform(this_im) reseed(pairwise_seed) this_gt = self.pair_gt_dual_transform(this_gt) this_im = self.final_im_transform(this_im) this_gt = np.array(this_gt) images.append(this_im) masks.append(this_gt) images = torch.stack(images, 0) labels = np.unique(masks[0]) # Remove background labels = labels[labels!=0] if self.is_bl: # Find large enough labels good_lables = [] for l in labels: pixel_sum = (masks[0]==l).sum() if pixel_sum > 10*10: # OK if the object is always this small # Not OK if it is actually much bigger if pixel_sum > 30*30: good_lables.append(l) elif max((masks[1]==l).sum(), (masks[2]==l).sum()) < 20*20: good_lables.append(l) labels = np.array(good_lables, dtype=np.uint8) if len(labels) == 0: target_objects = [] trials += 1 else: target_objects = labels.tolist() break if len(target_objects) > self.max_num_obj: target_objects = np.random.choice(target_objects, size=self.max_num_obj, replace=False) info['num_objects'] = max(1, len(target_objects)) masks = np.stack(masks, 0) # Generate one-hot ground-truth cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64) first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64) for i, l in enumerate(target_objects): this_mask = (masks==l) cls_gt[this_mask] = i+1 first_frame_gt[0,i] = (this_mask[0]) cls_gt = np.expand_dims(cls_gt, 1) # 1 if object exist, 0 otherwise selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)] selector = torch.FloatTensor(selector) data = { 'rgb': images, 'first_frame_gt': first_frame_gt, 'cls_gt': cls_gt, 'selector': selector, 'info': info, } return data def __len__(self): return len(self.videos) ================================================ FILE: XMem/eval.py ================================================ import os from os import path from argparse import ArgumentParser import shutil import torch import torch.nn.functional as F from torch.utils.data import DataLoader import numpy as np from PIL import Image from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset from inference.data.mask_mapper import MaskMapper from model.network import XMem from inference.inference_core import InferenceCore from tqdm import tqdm try: import hickle as hkl except ImportError: print('Failed to import hickle. 
Fine if not using multi-scale testing.') parser = ArgumentParser() parser.add_argument('--model', default='./saves/XMem.pth') parser.add_argument('--meta_exp', type=str) # Data options parser.add_argument('--d16_path', default='../DAVIS/2016') parser.add_argument('--d17_path', default='../DAVIS/2017') parser.add_argument('--y18_path', default='../YouTube2018') parser.add_argument('--y19_path', default='../YouTube') parser.add_argument('--lv_path', default='../long_video_set') # For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations" parser.add_argument('--generic_path') parser.add_argument('--img_dir') parser.add_argument('--reversed', action='store_true') parser.add_argument('--split_part', type=int, default=0) parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='D17') parser.add_argument('--split', help='val/test', default='val') parser.add_argument('--output', default=None) parser.add_argument('--save_all', action='store_true', help='Save all frames. Useful only in YouTubeVOS/long-time video', ) parser.add_argument('--benchmark', action='store_true', help='enable to disable amp for FPS benchmarking') # Long-term memory options parser.add_argument('--disable_long_term', action='store_true') parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10) parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5) parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time', type=int, default=10000) parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128) parser.add_argument('--top_k', type=int, default=30) parser.add_argument('--mem_every', help='r in paper. Increase to improve running speed.', type=int, default=5) parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1) # Multi-scale options parser.add_argument('--save_scores', action='store_true') parser.add_argument('--flip', action='store_true') parser.add_argument('--size', default=480, type=int, help='Resize the shorter side to this size. -1 to use original resolution. ') args = parser.parse_args() config = vars(args) config['enable_long_term'] = not config['disable_long_term'] if args.output is None: args.output = f'../output/{args.dataset}_{args.split}' print(f'Output path not provided. 
Defaulting to {args.output}') """ Data preparation """ is_youtube = args.dataset.startswith('Y') is_davis = args.dataset.startswith('D') is_lv = args.dataset.startswith('LV') if is_youtube or args.save_scores: out_path = path.join(args.output, 'Annotations') else: out_path = args.output if is_youtube: if args.dataset == 'Y18': yv_path = args.y18_path elif args.dataset == 'Y19': yv_path = args.y19_path if args.split == 'val': args.split = 'valid' meta_dataset = YouTubeVOSTestDataset(data_root=yv_path, split='valid', size=args.size) elif args.split == 'test': meta_dataset = YouTubeVOSTestDataset(data_root=yv_path, split='test', size=args.size) else: raise NotImplementedError elif is_davis: if args.dataset == 'D16': if args.split == 'val': # Set up Dataset, a small hack to use the image set in the 2017 folder because the 2016 one is of a different format meta_dataset = DAVISTestDataset(args.d16_path, imset='../../2017/trainval/ImageSets/2016/val.txt', size=args.size) else: raise NotImplementedError palette = None elif args.dataset == 'D17': if args.split == 'val': meta_dataset = DAVISTestDataset(path.join(args.d17_path, 'trainval'), imset='2017/val.txt', size=args.size) elif args.split == 'test': meta_dataset = DAVISTestDataset(path.join(args.d17_path, 'test-dev'), imset='2017/test-dev.txt', size=args.size) else: raise NotImplementedError elif is_lv: if args.dataset == 'LV1': meta_dataset = LongTestDataset(args.meta_exp, path.join(args.lv_path, 'long_video')) elif args.dataset == 'LV3': meta_dataset = LongTestDataset(args.meta_exp, path.join(args.lv_path, 'long_video_x3')) else: raise NotImplementedError elif args.dataset == 'G': meta_dataset = LongTestDataset(args.meta_exp, path.join(args.generic_path), size=args.size, img_dir=args.img_dir, reversed_=args.reversed, split_part=args.split_part) if not args.save_all: args.save_all = True print('save_all is forced to be true in generic evaluation mode.') else: raise NotImplementedError torch.autograd.set_grad_enabled(False) # Set up loader meta_loader = meta_dataset.get_datasets() # Load our checkpoint network = XMem(config, args.model).cuda().eval() if args.model is not None: model_weights = torch.load(args.model) network.load_weights(model_weights, init_as_zero_if_needed=True) else: print('No model loaded.') total_process_time = 0 total_frames = 0 # Start eval for vid_reader in tqdm(meta_loader, total=len(meta_dataset)): loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2) vid_name = vid_reader.vid_name vid_length = len(loader) # no need to count usage for LT if the video is not that long anyway config['enable_long_term_count_usage'] = ( config['enable_long_term'] and (vid_length / (config['max_mid_term_frames']-config['min_mid_term_frames']) * config['num_prototypes']) >= config['max_long_term_elements'] ) mapper = MaskMapper() processor = InferenceCore(network, config=config) first_mask_loaded = False for ti, data in enumerate(loader): with torch.cuda.amp.autocast(enabled=not args.benchmark): rgb = data['rgb'].cuda()[0] msk = data.get('mask') info = data['info'] frame = info['frame'][0] shape = info['shape'] need_resize = info['need_resize'][0] """ For timing see https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964 Seems to be very similar in testing as my previous timing method with two cuda sync + time.time() in STCN though """ start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() if not first_mask_loaded: if msk is not None: first_mask_loaded = True 
else: # no point to do anything without a mask continue if args.flip: rgb = torch.flip(rgb, dims=[-1]) msk = torch.flip(msk, dims=[-1]) if msk is not None else None # Map possibly non-continuous labels to continuous ones if msk is not None: msk, labels = mapper.convert_mask(msk[0].numpy()) msk = torch.Tensor(msk).cuda() if need_resize: if msk.shape[0] == 0: print(vid_name) msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] processor.set_all_labels(list(mapper.remappings.values())) else: labels = None # Run the model on this frame prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1)) # Upsample to original size if needed if need_resize: prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0] end.record() torch.cuda.synchronize() total_process_time += (start.elapsed_time(end)/1000) total_frames += 1 if args.flip: prob = torch.flip(prob, dims=[-1]) # Probability mask -> index mask out_mask = torch.max(prob, dim=0).indices out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) if args.save_scores: prob = (prob.detach().cpu().numpy()*255).astype(np.uint8) # Save the mask if args.save_all or info['save'][0]: this_out_path = path.join(out_path, vid_name) os.makedirs(this_out_path, exist_ok=True) out_mask = mapper.remap_index_mask(out_mask) out_img = Image.fromarray(out_mask) if vid_reader.get_palette() is not None: out_img.putpalette(vid_reader.get_palette()) out_img.save(os.path.join(this_out_path, frame[:-4]+'.png')) if args.save_scores: np_path = path.join(args.output, 'Scores', vid_name) os.makedirs(np_path, exist_ok=True) if ti==len(loader)-1: hkl.dump(mapper.remappings, path.join(np_path, f'backward.hkl'), mode='w') if args.save_all or info['save'][0]: hkl.dump(prob, path.join(np_path, f'{frame[:-4]}.hkl'), mode='w', compression='lzf') print(f'Total processing time: {total_process_time}') print(f'Total processed frames: {total_frames}') print(f'FPS: {total_frames / total_process_time}') print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}') if not args.save_scores: if is_youtube: print('Making zip for YouTubeVOS...') shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations') elif is_davis and args.split == 'test': print('Making zip for DAVIS test-dev...') shutil.make_archive(args.output, 'zip', args.output) ================================================ FILE: XMem/eval_batch.py ================================================ import os import time import torch import argparse import multiprocessing as mp from termcolor import colored from datetime import datetime from importlib.util import find_spec if find_spec("GPUtil") is None: os.system("pip install gputil") import GPUtil _GPU_LIST = [_.id for _ in GPUtil.getGPUs()] _GPU_QUEUE = mp.Queue() for _ in _GPU_LIST: _GPU_QUEUE.put(_) def run_eval(meta_expression, temp_xmem_anno, final_xmem_anno, img_dir, split_part, cfgs=" --reversed "): gpu_id = _GPU_QUEUE.get() cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} python eval.py --meta_exp {meta_expression} --output {final_xmem_anno} --generic_path {temp_xmem_anno} --img_dir {img_dir} --split_part {split_part} --dataset G {cfgs}" print(f"Running: {cmd}") os.system(cmd) _GPU_QUEUE.put(gpu_id) def main(): parser = argparse.ArgumentParser() parser.add_argument("--meta_expression", type=str, help='/PATH/TO/ReVOS/meta_expressions_valid__llamavid.json') parser.add_argument("--temp_xmem_anno", type=str, help='/PATH/TO/VISA_exp/revos_valid_XMem_temp/Annotations') 
parser.add_argument("--final_xmem_anno", type=str, help='/PATH/TO/VISA_exp/revos_valid_XMem_final/Annotations') parser.add_argument("--img_dir", type=str, help='/PATH/TO/ReVOS/JPEGImages') args = parser.parse_args() p = mp.Pool(8) for split_part in [0, 1, 2, 3]: for cfgs in [" ", " --reversed "]: p.apply_async( run_eval, args=(args.meta_expression, args.temp_xmem_anno, args.final_xmem_anno, args.img_dir, split_part, cfgs), error_callback=lambda e: print(colored(e, "red")) ) p.close() p.join() if __name__ == "__main__": main() ================================================ FILE: XMem/generate_xmem_data_single.py ================================================ import sys import os import os.path as osp import glob import cv2 import multiprocessing import json import argparse from tqdm import tqdm from termcolor import colored """ python generate_xmem_data_single.py \ --video_root /PATH/TO/VISA_exp/revos_valid/Annotations \ --output_dir /PATH/TO/VISA_exp/revos_valid_XMem_temp/Annotations \ --final_xmem_anno /PATH/TO/VISA_exp/revos_valid_XMem_final/Annotations \ --llama_vid_meta /PATH/TO/ReVOS/meta_expressions_valid__llamavid.json """ def generate(obj, temp_xmem_anno, final_xmem_anno): obj_dir, video_name, obj_id, tp = obj img_list = glob.glob(obj_dir + '/*.png') # Mask img_list.sort() frame_id = int(len(img_list) * tp) if frame_id == len(img_list): frame_id -= 1 used_img = img_list[frame_id] img_output_path = osp.join(temp_xmem_anno, video_name, obj_id, osp.basename(used_img)) final_img_output_dir = osp.join(final_xmem_anno, video_name, obj_id) img_output_dir = osp.dirname(img_output_path) os.makedirs(img_output_dir, exist_ok=True) os.makedirs(final_img_output_dir, exist_ok=True) os.system('cp {} {}'.format(used_img, img_output_path)) img = cv2.imread(img_output_path) if img.sum() == 0: target_img_list = [i.split('/')[-1] for i in img_list] for img_ in target_img_list: print(os.path.join(final_img_output_dir, img_)) os.system('cp {} {}'.format(img_output_path, os.path.join(img_output_dir, img_))) os.system('cp {} {}'.format(img_output_path, os.path.join(final_img_output_dir, img_))) return 0 def main(): parser = argparse.ArgumentParser(description='rgvos') parser.add_argument('--video_root', type=str, help='/PATH/TO/VISA_exp/revos_valid/Annotations', ) parser.add_argument('--temp_xmem_anno', type=str, help='/PATH/TO/VISA_exp/revos_valid_XMem_temp/Annotations', ) # 保存单帧 Mask 的路径 parser.add_argument('--final_xmem_anno', type=str, help='/PATH/TO/VISA_exp/revos_valid_XMem_final/Annotations', ) # 保存 XMem 最后输出结果的路径 parser.add_argument("--llama_vid_meta", type=str, help='/PATH/TO/ReVOS/meta_expressions_valid__llamavid.json', ) args = parser.parse_args() video_root = args.video_root temp_xmem_anno = args.temp_xmem_anno final_xmem_anno = args.final_xmem_anno os.makedirs(temp_xmem_anno, exist_ok=True) data = json.load(open(args.llama_vid_meta, 'r'))['videos'] all_obj_list = [] for video_name in data.keys(): exps = data[video_name]['expressions'] for obj_id in exps.keys(): tp = exps[obj_id]['tp'] obj_dir = os.path.join(video_root, video_name, obj_id) all_obj_list.append([obj_dir, video_name, obj_id, tp]) print('start') cpu_num = multiprocessing.cpu_count()-1 print("cpu_num:", cpu_num) pool = multiprocessing.Pool(cpu_num) pbar = tqdm(total=len(all_obj_list)) for obj in all_obj_list: pool.apply_async( generate, args = (obj, temp_xmem_anno, final_xmem_anno ), callback = lambda *a: pbar.update(1), error_callback = lambda e: print(colored(e, "red")) ) pool.close() pool.join() pbar.close() if __name__ == 
'__main__': main() ================================================ FILE: XMem/inference/__init__.py ================================================ ================================================ FILE: XMem/inference/data/__init__.py ================================================ ================================================ FILE: XMem/inference/data/mask_mapper.py ================================================ import numpy as np import torch from dataset.util import all_to_onehot class MaskMapper: """ This class is used to convert a indexed-mask to a one-hot representation. It also takes care of remapping non-continuous indices It has two modes: 1. Default. Only masks with new indices are supposed to go into the remapper. This is also the case for YouTubeVOS. i.e., regions with index 0 are not "background", but "don't care". 2. Exhaustive. Regions with index 0 are considered "background". Every single pixel is considered to be "labeled". """ def __init__(self): self.labels = [] self.remappings = {} # if coherent, no mapping is required self.coherent = True def convert_mask(self, mask, exhaustive=False): # mask is in index representation, H*W numpy array labels = np.unique(mask).astype(np.uint8) labels = labels[labels!=0].tolist() new_labels = list(set(labels) - set(self.labels)) if not exhaustive: assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode' # add new remappings for i, l in enumerate(new_labels): self.remappings[l] = i+len(self.labels)+1 if self.coherent and i+len(self.labels)+1 != l: self.coherent = False if exhaustive: new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1) else: if self.coherent: new_mapped_labels = new_labels else: new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1) self.labels.extend(new_labels) mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float() # mask num_objects*H*W return mask, new_mapped_labels def remap_index_mask(self, mask): # mask is in index representation, H*W numpy array if self.coherent: return mask new_mask = np.zeros_like(mask) for l, i in self.remappings.items(): new_mask[mask==i] = l return new_mask ================================================ FILE: XMem/inference/data/test_datasets.py ================================================ import os from os import path import json import glob from inference.data.video_reader import VideoReader class LongTestDataset: def __init__(self, meta_expression, data_root, size=-1, img_dir = '', reversed_ = False, split_part = 0): self.image_dir = img_dir self.mask_dir = data_root self.size = size self.reversed = reversed_ self.split_part = split_part self.vid_list = [] videos_names = json.load(open(meta_expression, 'r'))['videos'] for video_name in videos_names: video_mask_dir = path.join(self.mask_dir, video_name) obj_ids = [d for d in os.listdir(video_mask_dir) if os.path.isdir(path.join(video_mask_dir, d))] for obj_id in obj_ids: obj_dir = path.join(video_mask_dir, obj_id) img_list = glob.glob(obj_dir + '/*') if len(img_list) == 1: self.vid_list.append(path.join(video_name, obj_id)) self.vid_list.sort() self.vid_list = [i for idx, i in enumerate(self.vid_list) if idx % 4 == self.split_part] def get_datasets(self): for video in self.vid_list: yield VideoReader(video, path.join(self.image_dir, '/'.join(video.split('/')[:-1])), path.join(self.mask_dir, video), to_save = [ name[:-4] for name in os.listdir(path.join(self.mask_dir, video)) # remove .png ], size=self.size, reversed=self.reversed, ) def __len__(self): 
return len(self.vid_list) class DAVISTestDataset: def __init__(self, data_root, imset='2017/val.txt', size=-1): if size != 480: self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution') self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution') if not path.exists(self.image_dir): print(f'{self.image_dir} not found. Look at other options.') self.image_dir = path.join(data_root, 'JPEGImages', '1080p') self.mask_dir = path.join(data_root, 'Annotations', '1080p') assert path.exists(self.image_dir), 'path not found' else: self.image_dir = path.join(data_root, 'JPEGImages', '480p') self.mask_dir = path.join(data_root, 'Annotations', '480p') self.size_dir = path.join(data_root, 'JPEGImages', '480p') self.size = size with open(path.join(data_root, 'ImageSets', imset)) as f: self.vid_list = sorted([line.strip() for line in f]) def get_datasets(self): for video in self.vid_list: yield VideoReader(video, path.join(self.image_dir, video), path.join(self.mask_dir, video), size=self.size, size_dir=path.join(self.size_dir, video), ) def __len__(self): return len(self.vid_list) class YouTubeVOSTestDataset: def __init__(self, data_root, split, size=480): self.image_dir = path.join(data_root, 'all_frames', split+'_all_frames', 'JPEGImages') self.mask_dir = path.join(data_root, split, 'Annotations') self.size = size self.vid_list = sorted(os.listdir(self.image_dir)) self.req_frame_list = {} with open(path.join(data_root, split, 'meta.json')) as f: # read meta.json to know which frame is required for evaluation meta = json.load(f)['videos'] for vid in self.vid_list: req_frames = [] objects = meta[vid]['objects'] for value in objects.values(): req_frames.extend(value['frames']) req_frames = list(set(req_frames)) self.req_frame_list[vid] = req_frames def get_datasets(self): for video in self.vid_list: yield VideoReader(video, path.join(self.image_dir, video), path.join(self.mask_dir, video), size=self.size, to_save=self.req_frame_list[video], use_all_mask=True ) def __len__(self): return len(self.vid_list) ================================================ FILE: XMem/inference/data/video_reader.py ================================================ import os from os import path from torch.utils.data.dataset import Dataset from torchvision import transforms from torchvision.transforms import InterpolationMode import torch.nn.functional as F from PIL import Image import numpy as np from dataset.range_transform import im_normalization class VideoReader(Dataset): """ This class is used to read a video, one frame at a time """ def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None, reversed = False): """ image_dir - points to a directory of jpg images mask_dir - points to a directory of png masks size - resize min. side to size. Does nothing if <0. to_save - optionally contains a list of file names without extensions where the segmentation mask is required use_all_mask - when true, read all available mask in mask_dir. Default false. Set to true for YouTubeVOS validation. 
""" self.vid_name = vid_name self.image_dir = image_dir self.mask_dir = mask_dir self.to_save = to_save self.use_all_mask = use_all_mask self.reversed = reversed if size_dir is None: self.size_dir = self.image_dir else: self.size_dir = size_dir self.frames = sorted(os.listdir(self.image_dir)) if self.reversed: self.frames = self.frames[::-1] self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette() self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0]) if size < 0: self.im_transform = transforms.Compose([ transforms.ToTensor(), im_normalization, ]) else: self.im_transform = transforms.Compose([ transforms.ToTensor(), im_normalization, transforms.Resize(size, interpolation=InterpolationMode.BILINEAR), ]) self.size = size def __getitem__(self, idx): frame = self.frames[idx] info = {} data = {} info['frame'] = frame info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save) im_path = path.join(self.image_dir, frame) img = Image.open(im_path).convert('RGB') if self.image_dir == self.size_dir: shape = np.array(img).shape[:2] else: size_path = path.join(self.size_dir, frame) size_im = Image.open(size_path).convert('RGB') shape = np.array(size_im).shape[:2] gt_path = path.join(self.mask_dir, frame[:-4]+'.png') img = self.im_transform(img) load_mask = self.use_all_mask or (gt_path == self.first_gt_path) if load_mask and path.exists(gt_path): mask = Image.open(gt_path).convert('P') mask = np.array(mask, dtype=np.uint8) data['mask'] = mask info['shape'] = shape info['need_resize'] = not (self.size < 0) data['rgb'] = img data['info'] = info return data def resize_mask(self, mask): # mask transform is applied AFTER mapper, so we need to post-process it in eval.py h, w = mask.shape[-2:] min_hw = min(h, w) return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)), mode='nearest') def get_palette(self): return self.palette def __len__(self): return len(self.frames) ================================================ FILE: XMem/inference/inference_core.py ================================================ from inference.memory_manager import MemoryManager from model.network import XMem from model.aggregate import aggregate from util.tensor_util import pad_divide_by, unpad class InferenceCore: def __init__(self, network:XMem, config): self.config = config self.network = network self.mem_every = config['mem_every'] self.deep_update_every = config['deep_update_every'] self.enable_long_term = config['enable_long_term'] # if deep_update_every < 0, synchronize deep update with memory frame self.deep_update_sync = (self.deep_update_every < 0) self.clear_memory() self.all_labels = None def clear_memory(self): self.curr_ti = -1 self.last_mem_ti = 0 if not self.deep_update_sync: self.last_deep_update_ti = -self.deep_update_every self.memory = MemoryManager(config=self.config) def update_config(self, config): self.mem_every = config['mem_every'] self.deep_update_every = config['deep_update_every'] self.enable_long_term = config['enable_long_term'] # if deep_update_every < 0, synchronize deep update with memory frame self.deep_update_sync = (self.deep_update_every < 0) self.memory.update_config(config) def set_all_labels(self, all_labels): # self.all_labels = [l.item() for l in all_labels] self.all_labels = all_labels def step(self, image, mask=None, valid_labels=None, end=False): # image: 3*H*W # mask: num_objects*H*W or None self.curr_ti += 1 image, self.pad = pad_divide_by(image, 16) image = image.unsqueeze(0) # add the batch 
dimension is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end) need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels))) is_deep_update = ( (self.deep_update_sync and is_mem_frame) or # synchronized (not self.deep_update_sync and self.curr_ti-self.last_deep_update_ti >= self.deep_update_every) # no-sync ) and (not end) is_normal_update = (not self.deep_update_sync or not is_deep_update) and (not end) key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image, need_ek=(self.enable_long_term or need_segment), need_sk=is_mem_frame) multi_scale_features = (f16, f8, f4) # segment the current frame is needed if need_segment: memory_readout = self.memory.match_memory(key, selection).unsqueeze(0) hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout, self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False) # remove batch dim pred_prob_with_bg = pred_prob_with_bg[0] pred_prob_no_bg = pred_prob_with_bg[1:] if is_normal_update: self.memory.set_hidden(hidden) else: pred_prob_no_bg = pred_prob_with_bg = None # use the input mask if any if mask is not None: mask, _ = pad_divide_by(mask, 16) if pred_prob_no_bg is not None: # if we have a predicted mask, we work on it # make pred_prob_no_bg consistent with the input mask mask_regions = (mask.sum(0) > 0.5) pred_prob_no_bg[:, mask_regions] = 0 # shift by 1 because mask/pred_prob_no_bg do not contain background mask = mask.type_as(pred_prob_no_bg) if valid_labels is not None: shift_by_one_non_labels = [i for i in range(pred_prob_no_bg.shape[0]) if (i+1) not in valid_labels] # non-labelled objects are copied from the predicted mask mask[shift_by_one_non_labels] = pred_prob_no_bg[shift_by_one_non_labels] pred_prob_with_bg = aggregate(mask, dim=0) # also create new hidden states self.memory.create_hidden_state(len(self.all_labels), key) # save as memory if needed if is_mem_frame: value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(), pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=is_deep_update) self.memory.add_memory(key, shrinkage, value, self.all_labels, selection=selection if self.enable_long_term else None) self.last_mem_ti = self.curr_ti if is_deep_update: self.memory.set_hidden(hidden) self.last_deep_update_ti = self.curr_ti return unpad(pred_prob_with_bg, self.pad) ================================================ FILE: XMem/inference/interact/__init__.py ================================================ ================================================ FILE: XMem/inference/interact/fbrs/LICENSE ================================================ Mozilla Public License Version 2.0 ================================== 1. Definitions -------------- 1.1. "Contributor" means each individual or legal entity that creates, contributes to the creation of, or owns Covered Software. 1.2. "Contributor Version" means the combination of the Contributions of others (if any) used by a Contributor and that particular Contributor's Contribution. 1.3. "Contribution" means Covered Software of a particular Contributor. 1.4. "Covered Software" means Source Code Form to which the initial Contributor has attached the notice in Exhibit A, the Executable Form of such Source Code Form, and Modifications of such Source Code Form, in each case including portions thereof. 1.5. 
"Incompatible With Secondary Licenses" means (a) that the initial Contributor has attached the notice described in Exhibit B to the Covered Software; or (b) that the Covered Software was made available under the terms of version 1.1 or earlier of the License, but not also under the terms of a Secondary License. 1.6. "Executable Form" means any form of the work other than Source Code Form. 1.7. "Larger Work" means a work that combines Covered Software with other material, in a separate file or files, that is not Covered Software. 1.8. "License" means this document. 1.9. "Licensable" means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently, any and all of the rights conveyed by this License. 1.10. "Modifications" means any of the following: (a) any file in Source Code Form that results from an addition to, deletion from, or modification of the contents of Covered Software; or (b) any new file in Source Code Form that contains any Covered Software. 1.11. "Patent Claims" of a Contributor means any patent claim(s), including without limitation, method, process, and apparatus claims, in any patent Licensable by such Contributor that would be infringed, but for the grant of the License, by the making, using, selling, offering for sale, having made, import, or transfer of either its Contributions or its Contributor Version. 1.12. "Secondary License" means either the GNU General Public License, Version 2.0, the GNU Lesser General Public License, Version 2.1, the GNU Affero General Public License, Version 3.0, or any later versions of those licenses. 1.13. "Source Code Form" means the form of the work preferred for making modifications. 1.14. "You" (or "Your") means an individual or a legal entity exercising rights under this License. For legal entities, "You" includes any entity that controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity. 2. License Grants and Conditions -------------------------------- 2.1. Grants Each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: (a) under intellectual property rights (other than patent or trademark) Licensable by such Contributor to use, reproduce, make available, modify, display, perform, distribute, and otherwise exploit its Contributions, either on an unmodified basis, with Modifications, or as part of a Larger Work; and (b) under Patent Claims of such Contributor to make, use, sell, offer for sale, have made, import, and otherwise transfer either its Contributions or its Contributor Version. 2.2. Effective Date The licenses granted in Section 2.1 with respect to any Contribution become effective for each Contribution on the date the Contributor first distributes such Contribution. 2.3. Limitations on Grant Scope The licenses granted in this Section 2 are the only rights granted under this License. No additional rights or licenses will be implied from the distribution or licensing of Covered Software under this License. 
Notwithstanding Section 2.1(b) above, no patent license is granted by a Contributor: (a) for any code that a Contributor has removed from Covered Software; or (b) for infringements caused by: (i) Your and any other third party's modifications of Covered Software, or (ii) the combination of its Contributions with other software (except as part of its Contributor Version); or (c) under Patent Claims infringed by Covered Software in the absence of its Contributions. This License does not grant any rights in the trademarks, service marks, or logos of any Contributor (except as may be necessary to comply with the notice requirements in Section 3.4). 2.4. Subsequent Licenses No Contributor makes additional grants as a result of Your choice to distribute the Covered Software under a subsequent version of this License (see Section 10.2) or under the terms of a Secondary License (if permitted under the terms of Section 3.3). 2.5. Representation Each Contributor represents that the Contributor believes its Contributions are its original creation(s) or it has sufficient rights to grant the rights to its Contributions conveyed by this License. 2.6. Fair Use This License is not intended to limit any rights You have under applicable copyright doctrines of fair use, fair dealing, or other equivalents. 2.7. Conditions Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in Section 2.1. 3. Responsibilities ------------------- 3.1. Distribution of Source Form All distribution of Covered Software in Source Code Form, including any Modifications that You create or to which You contribute, must be under the terms of this License. You must inform recipients that the Source Code Form of the Covered Software is governed by the terms of this License, and how they can obtain a copy of this License. You may not attempt to alter or restrict the recipients' rights in the Source Code Form. 3.2. Distribution of Executable Form If You distribute Covered Software in Executable Form then: (a) such Covered Software must also be made available in Source Code Form, as described in Section 3.1, and You must inform recipients of the Executable Form how they can obtain a copy of such Source Code Form by reasonable means in a timely manner, at a charge no more than the cost of distribution to the recipient; and (b) You may distribute such Executable Form under the terms of this License, or sublicense it under different terms, provided that the license for the Executable Form does not attempt to limit or alter the recipients' rights in the Source Code Form under this License. 3.3. Distribution of a Larger Work You may create and distribute a Larger Work under terms of Your choice, provided that You also comply with the requirements of this License for the Covered Software. If the Larger Work is a combination of Covered Software with a work governed by one or more Secondary Licenses, and the Covered Software is not Incompatible With Secondary Licenses, this License permits You to additionally distribute such Covered Software under the terms of such Secondary License(s), so that the recipient of the Larger Work may, at their option, further distribute the Covered Software under the terms of either this License or such Secondary License(s). 3.4. 
Notices You may not remove or alter the substance of any license notices (including copyright notices, patent notices, disclaimers of warranty, or limitations of liability) contained within the Source Code Form of the Covered Software, except that You may alter any license notices to the extent required to remedy known factual inaccuracies. 3.5. Application of Additional Terms You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Software. However, You may do so only on Your own behalf, and not on behalf of any Contributor. You must make it absolutely clear that any such warranty, support, indemnity, or liability obligation is offered by You alone, and You hereby agree to indemnify every Contributor for any liability incurred by such Contributor as a result of warranty, support, indemnity or liability terms You offer. You may include additional disclaimers of warranty and limitations of liability specific to any jurisdiction. 4. Inability to Comply Due to Statute or Regulation --------------------------------------------------- If it is impossible for You to comply with any of the terms of this License with respect to some or all of the Covered Software due to statute, judicial order, or regulation then You must: (a) comply with the terms of this License to the maximum extent possible; and (b) describe the limitations and the code they affect. Such description must be placed in a text file included with all distributions of the Covered Software under this License. Except to the extent prohibited by statute or regulation, such description must be sufficiently detailed for a recipient of ordinary skill to be able to understand it. 5. Termination -------------- 5.1. The rights granted under this License will terminate automatically if You fail to comply with any of its terms. However, if You become compliant, then the rights granted under this License from a particular Contributor are reinstated (a) provisionally, unless and until such Contributor explicitly and finally terminates Your grants, and (b) on an ongoing basis, if such Contributor fails to notify You of the non-compliance by some reasonable means prior to 60 days after You have come back into compliance. Moreover, Your grants from a particular Contributor are reinstated on an ongoing basis if such Contributor notifies You of the non-compliance by some reasonable means, this is the first time You have received notice of non-compliance with this License from such Contributor, and You become compliant prior to 30 days after Your receipt of the notice. 5.2. If You initiate litigation against any entity by asserting a patent infringement claim (excluding declaratory judgment actions, counter-claims, and cross-claims) alleging that a Contributor Version directly or indirectly infringes any patent, then the rights granted to You by any and all Contributors for the Covered Software under Section 2.1 of this License shall terminate. 5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user license agreements (excluding distributors and resellers) which have been validly granted by You or Your distributors under this License prior to termination shall survive termination. ************************************************************************ * * * 6. 
Disclaimer of Warranty * * ------------------------- * * * * Covered Software is provided under this License on an "as is" * * basis, without warranty of any kind, either expressed, implied, or * * statutory, including, without limitation, warranties that the * * Covered Software is free of defects, merchantable, fit for a * * particular purpose or non-infringing. The entire risk as to the * * quality and performance of the Covered Software is with You. * * Should any Covered Software prove defective in any respect, You * * (not any Contributor) assume the cost of any necessary servicing, * * repair, or correction. This disclaimer of warranty constitutes an * * essential part of this License. No use of any Covered Software is * * authorized under this License except under this disclaimer. * * * ************************************************************************ ************************************************************************ * * * 7. Limitation of Liability * * -------------------------- * * * * Under no circumstances and under no legal theory, whether tort * * (including negligence), contract, or otherwise, shall any * * Contributor, or anyone who distributes Covered Software as * * permitted above, be liable to You for any direct, indirect, * * special, incidental, or consequential damages of any character * * including, without limitation, damages for lost profits, loss of * * goodwill, work stoppage, computer failure or malfunction, or any * * and all other commercial damages or losses, even if such party * * shall have been informed of the possibility of such damages. This * * limitation of liability shall not apply to liability for death or * * personal injury resulting from such party's negligence to the * * extent applicable law prohibits such limitation. Some * * jurisdictions do not allow the exclusion or limitation of * * incidental or consequential damages, so this exclusion and * * limitation may not apply to You. * * * ************************************************************************ 8. Litigation ------------- Any litigation relating to this License may be brought only in the courts of a jurisdiction where the defendant maintains its principal place of business and such litigation shall be governed by laws of that jurisdiction, without reference to its conflict-of-law provisions. Nothing in this Section shall prevent a party's ability to bring cross-claims or counter-claims. 9. Miscellaneous ---------------- This License represents the complete agreement concerning the subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not be used to construe this License against a Contributor. 10. Versions of the License --------------------------- 10.1. New Versions Mozilla Foundation is the license steward. Except as provided in Section 10.3, no one other than the license steward has the right to modify or publish new versions of this License. Each version will be given a distinguishing version number. 10.2. Effect of New Versions You may distribute the Covered Software under the terms of the version of the License under which You originally received the Covered Software, or under the terms of any subsequent version published by the license steward. 10.3. 
Modified Versions If you create software not governed by this License, and you want to create a new license for such software, you may create and use a modified version of this License if you rename the license and remove any references to the name of the license steward (except to note that such modified license differs from this License). 10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses If You choose to distribute Source Code Form that is Incompatible With Secondary Licenses under the terms of this version of the License, the notice described in Exhibit B of this License must be attached. Exhibit A - Source Code Form License Notice ------------------------------------------- This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. If it is not possible or desirable to put the notice in a particular file, then You may include the notice in a location (such as a LICENSE file in a relevant directory) where a recipient would be likely to look for such a notice. You may add additional accurate notices of copyright ownership. Exhibit B - "Incompatible With Secondary Licenses" Notice --------------------------------------------------------- This Source Code Form is "Incompatible With Secondary Licenses", as defined by the Mozilla Public License, v. 2.0. ================================================ FILE: XMem/inference/interact/fbrs/__init__.py ================================================ ================================================ FILE: XMem/inference/interact/fbrs/controller.py ================================================ import torch try: from torch import mps except: pass from ..fbrs.inference import clicker from ..fbrs.inference.predictors import get_predictor class InteractiveController: def __init__(self, net, device, predictor_params, prob_thresh=0.5): self.net = net.to(device) self.prob_thresh = prob_thresh self.clicker = clicker.Clicker() self.states = [] self.probs_history = [] self.object_count = 0 self._result_mask = None self.image = None self.predictor = None self.device = device self.predictor_params = predictor_params self.reset_predictor() def set_image(self, image): self.image = image self._result_mask = torch.zeros(image.shape[-2:], dtype=torch.uint8) self.object_count = 0 self.reset_last_object() def add_click(self, x, y, is_positive): self.states.append({ 'clicker': self.clicker.get_state(), 'predictor': self.predictor.get_states() }) click = clicker.Click(is_positive=is_positive, coords=(y, x)) self.clicker.add_click(click) pred = self.predictor.get_prediction(self.clicker) if self.device.type == 'cuda': torch.cuda.empty_cache() elif self.device.type == 'mps': mps.empty_cache() if self.probs_history: self.probs_history.append((self.probs_history[-1][0], pred)) else: self.probs_history.append((torch.zeros_like(pred), pred)) def undo_click(self): if not self.states: return prev_state = self.states.pop() self.clicker.set_state(prev_state['clicker']) self.predictor.set_states(prev_state['predictor']) self.probs_history.pop() def partially_finish_object(self): object_prob = self.current_object_prob if object_prob is None: return self.probs_history.append((object_prob, torch.zeros_like(object_prob))) self.states.append(self.states[-1]) self.clicker.reset_clicks() self.reset_predictor() def finish_object(self): object_prob = self.current_object_prob if object_prob is None: return self.object_count 
+= 1
        object_mask = object_prob > self.prob_thresh
        self._result_mask[object_mask] = self.object_count
        self.reset_last_object()

    def reset_last_object(self):
        self.states = []
        self.probs_history = []
        self.clicker.reset_clicks()
        self.reset_predictor()

    def reset_predictor(self, predictor_params=None):
        if predictor_params is not None:
            self.predictor_params = predictor_params
        self.predictor = get_predictor(self.net, device=self.device,
                                       **self.predictor_params)
        if self.image is not None:
            self.predictor.set_input_image(self.image)

    @property
    def current_object_prob(self):
        if self.probs_history:
            current_prob_total, current_prob_additive = self.probs_history[-1]
            return torch.maximum(current_prob_total, current_prob_additive)
        else:
            return None

    @property
    def is_incomplete_mask(self):
        return len(self.probs_history) > 0

    @property
    def result_mask(self):
        return self._result_mask.clone()

================================================ FILE: XMem/inference/interact/fbrs/inference/__init__.py ================================================

================================================ FILE: XMem/inference/interact/fbrs/inference/clicker.py ================================================
from collections import namedtuple
import numpy as np
from copy import deepcopy
from scipy.ndimage import distance_transform_edt

Click = namedtuple('Click', ['is_positive', 'coords'])


class Clicker(object):
    def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1):
        if gt_mask is not None:
            self.gt_mask = gt_mask == 1
            self.not_ignore_mask = gt_mask != ignore_label
        else:
            self.gt_mask = None

        self.reset_clicks()

        if init_clicks is not None:
            for click in init_clicks:
                self.add_click(click)

    def make_next_click(self, pred_mask):
        assert self.gt_mask is not None
        click = self._get_click(pred_mask)
        self.add_click(click)

    def get_clicks(self, clicks_limit=None):
        return self.clicks_list[:clicks_limit]

    def _get_click(self, pred_mask, padding=True):
        fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
        fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)

        if padding:
            fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant')
            fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant')

        fn_mask_dt = distance_transform_edt(fn_mask)
        fp_mask_dt = distance_transform_edt(fp_mask)

        if padding:
            fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
            fp_mask_dt = fp_mask_dt[1:-1, 1:-1]

        fn_mask_dt = fn_mask_dt * self.not_clicked_map
        fp_mask_dt = fp_mask_dt * self.not_clicked_map

        fn_max_dist = np.max(fn_mask_dt)
        fp_max_dist = np.max(fp_mask_dt)

        is_positive = fn_max_dist > fp_max_dist
        if is_positive:
            coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist)  # coords is [y, x]
        else:
            coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist)  # coords is [y, x]

        return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))

    def add_click(self, click):
        coords = click.coords

        if click.is_positive:
            self.num_pos_clicks += 1
        else:
            self.num_neg_clicks += 1

        self.clicks_list.append(click)
        if self.gt_mask is not None:
            self.not_clicked_map[coords[0], coords[1]] = False

    def _remove_last_click(self):
        click = self.clicks_list.pop()
        coords = click.coords

        if click.is_positive:
            self.num_pos_clicks -= 1
        else:
            self.num_neg_clicks -= 1

        if self.gt_mask is not None:
            self.not_clicked_map[coords[0], coords[1]] = True

    def reset_clicks(self):
        if self.gt_mask is not None:
            self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool)

        self.num_pos_clicks = 0
        self.num_neg_clicks = 0
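# ----------------------------------------------------------------------
# Hedged aside, not part of the original clicker.py: the rule that
# _get_click implements above, reduced to a standalone sketch. Given a
# ground-truth mask and a prediction, the simulated user clicks at the
# point deepest inside the largest error region -- positive if the worst
# false-negative pocket is deeper than the worst false-positive one.
# `gt`, `pred`, and `next_click` are illustrative names.
import numpy as np
from scipy.ndimage import distance_transform_edt

def next_click(gt, pred):
    fn = np.logical_and(gt, np.logical_not(pred))   # missed object pixels
    fp = np.logical_and(np.logical_not(gt), pred)   # hallucinated pixels
    # Pad by one so the image border counts as a region boundary, then
    # measure each error pixel's distance to the nearest correct pixel.
    fn_dt = distance_transform_edt(np.pad(fn, 1))[1:-1, 1:-1]
    fp_dt = distance_transform_edt(np.pad(fp, 1))[1:-1, 1:-1]
    is_positive = fn_dt.max() > fp_dt.max()
    dt = fn_dt if is_positive else fp_dt
    y, x = np.unravel_index(np.argmax(dt), dt.shape)
    return bool(is_positive), (y, x)

# Toy check: with an empty prediction the first click should be positive,
# landing in the interior of the 4x4 object (here at (3, 3)).
gt = np.zeros((8, 8), dtype=bool); gt[2:6, 2:6] = True
print(next_click(gt, np.zeros_like(gt)))
# ----------------------------------------------------------------------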
self.clicks_list = [] def get_state(self): return deepcopy(self.clicks_list) def set_state(self, state): self.reset_clicks() for click in state: self.add_click(click) def __len__(self): return len(self.clicks_list) ================================================ FILE: XMem/inference/interact/fbrs/inference/evaluation.py ================================================ from time import time import numpy as np import torch from ..inference import utils from ..inference.clicker import Clicker try: get_ipython() from tqdm import tqdm_notebook as tqdm except NameError: from tqdm import tqdm def evaluate_dataset(dataset, predictor, oracle_eval=False, **kwargs): all_ious = [] start_time = time() for index in tqdm(range(len(dataset)), leave=False): sample = dataset.get_sample(index) item = dataset[index] if oracle_eval: gt_mask = torch.tensor(sample['instances_mask'], dtype=torch.float32) gt_mask = gt_mask.unsqueeze(0).unsqueeze(0) predictor.opt_functor.mask_loss.set_gt_mask(gt_mask) _, sample_ious, _ = evaluate_sample(item['images'], sample['instances_mask'], predictor, **kwargs) all_ious.append(sample_ious) end_time = time() elapsed_time = end_time - start_time return all_ious, elapsed_time def evaluate_sample(image_nd, instances_mask, predictor, max_iou_thr, pred_thr=0.49, max_clicks=20): clicker = Clicker(gt_mask=instances_mask) pred_mask = np.zeros_like(instances_mask) ious_list = [] with torch.no_grad(): predictor.set_input_image(image_nd) for click_number in range(max_clicks): clicker.make_next_click(pred_mask) pred_probs = predictor.get_prediction(clicker) pred_mask = pred_probs > pred_thr iou = utils.get_iou(instances_mask, pred_mask) ious_list.append(iou) if iou >= max_iou_thr: break return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs ================================================ FILE: XMem/inference/interact/fbrs/inference/predictors/__init__.py ================================================ from .base import BasePredictor from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor from .brs_functors import InputOptimizer, ScaleBiasOptimizer from ..transforms import ZoomIn from ...model.is_hrnet_model import DistMapsHRNetModel def get_predictor(net, brs_mode, device, prob_thresh=0.49, with_flip=True, zoom_in_params=dict(), predictor_params=None, brs_opt_func_params=None, lbfgs_params=None): lbfgs_params_ = { 'm': 20, 'factr': 0, 'pgtol': 1e-8, 'maxfun': 20, } predictor_params_ = { 'optimize_after_n_clicks': 1 } if zoom_in_params is not None: zoom_in = ZoomIn(**zoom_in_params) else: zoom_in = None if lbfgs_params is not None: lbfgs_params_.update(lbfgs_params) lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun'] if brs_opt_func_params is None: brs_opt_func_params = dict() if brs_mode == 'NoBRS': if predictor_params is not None: predictor_params_.update(predictor_params) predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_) elif brs_mode.startswith('f-BRS'): predictor_params_.update({ 'net_clicks_limit': 8, }) if predictor_params is not None: predictor_params_.update(predictor_params) insertion_mode = { 'f-BRS-A': 'after_c4', 'f-BRS-B': 'after_aspp', 'f-BRS-C': 'after_deeplab' }[brs_mode] opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh, with_flip=with_flip, optimizer_params=lbfgs_params_, **brs_opt_func_params) if isinstance(net, DistMapsHRNetModel): FeaturePredictor = HRNetFeatureBRSPredictor insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 
'C'}[insertion_mode] else: FeaturePredictor = FeatureBRSPredictor predictor = FeaturePredictor(net, device, opt_functor=opt_functor, with_flip=with_flip, insertion_mode=insertion_mode, zoom_in=zoom_in, **predictor_params_) elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS': use_dmaps = brs_mode == 'DistMap-BRS' predictor_params_.update({ 'net_clicks_limit': 5, }) if predictor_params is not None: predictor_params_.update(predictor_params) opt_functor = InputOptimizer(prob_thresh=prob_thresh, with_flip=with_flip, optimizer_params=lbfgs_params_, **brs_opt_func_params) predictor = InputBRSPredictor(net, device, optimize_target='dmaps' if use_dmaps else 'rgb', opt_functor=opt_functor, with_flip=with_flip, zoom_in=zoom_in, **predictor_params_) else: raise NotImplementedError return predictor ================================================ FILE: XMem/inference/interact/fbrs/inference/predictors/base.py ================================================ import torch import torch.nn.functional as F from ..transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide class BasePredictor(object): def __init__(self, net, device, net_clicks_limit=None, with_flip=False, zoom_in=None, max_size=None, **kwargs): self.net = net self.with_flip = with_flip self.net_clicks_limit = net_clicks_limit self.original_image = None self.device = device self.zoom_in = zoom_in self.transforms = [zoom_in] if zoom_in is not None else [] if max_size is not None: self.transforms.append(LimitLongestSide(max_size=max_size)) self.transforms.append(SigmoidForPred()) if with_flip: self.transforms.append(AddHorizontalFlip()) def set_input_image(self, image_nd): for transform in self.transforms: transform.reset() self.original_image = image_nd.to(self.device) if len(self.original_image.shape) == 3: self.original_image = self.original_image.unsqueeze(0) def get_prediction(self, clicker): clicks_list = clicker.get_clicks() image_nd, clicks_lists, is_image_changed = self.apply_transforms( self.original_image, [clicks_list] ) pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed) prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True, size=image_nd.size()[2:]) for t in reversed(self.transforms): prediction = t.inv_transform(prediction) if self.zoom_in is not None and self.zoom_in.check_possible_recalculation(): print('zooming') return self.get_prediction(clicker) # return prediction.cpu().numpy()[0, 0] return prediction def _get_prediction(self, image_nd, clicks_lists, is_image_changed): points_nd = self.get_points_nd(clicks_lists) return self.net(image_nd, points_nd)['instances'] def _get_transform_states(self): return [x.get_state() for x in self.transforms] def _set_transform_states(self, states): assert len(states) == len(self.transforms) for state, transform in zip(states, self.transforms): transform.set_state(state) def apply_transforms(self, image_nd, clicks_lists): is_image_changed = False for t in self.transforms: image_nd, clicks_lists = t.transform(image_nd, clicks_lists) is_image_changed |= t.image_changed return image_nd, clicks_lists, is_image_changed def get_points_nd(self, clicks_lists): total_clicks = [] num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] num_max_points = max(num_pos_clicks + num_neg_clicks) if self.net_clicks_limit is not None: num_max_points = min(self.net_clicks_limit, num_max_points) num_max_points 
= max(1, num_max_points) for clicks_list in clicks_lists: clicks_list = clicks_list[:self.net_clicks_limit] pos_clicks = [click.coords for click in clicks_list if click.is_positive] pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1)] neg_clicks = [click.coords for click in clicks_list if not click.is_positive] neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1)] total_clicks.append(pos_clicks + neg_clicks) return torch.tensor(total_clicks, device=self.device) def get_states(self): return {'transform_states': self._get_transform_states()} def set_states(self, states): self._set_transform_states(states['transform_states']) ================================================ FILE: XMem/inference/interact/fbrs/inference/predictors/brs.py ================================================ import torch import torch.nn.functional as F import numpy as np from scipy.optimize import fmin_l_bfgs_b from .base import BasePredictor from ...model.is_hrnet_model import DistMapsHRNetModel class BRSBasePredictor(BasePredictor): def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs): super().__init__(model, device, **kwargs) self.optimize_after_n_clicks = optimize_after_n_clicks self.opt_functor = opt_functor self.opt_data = None self.input_data = None def set_input_image(self, image_nd): super().set_input_image(image_nd) self.opt_data = None self.input_data = None def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1): pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32) neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32) for list_indx, clicks_list in enumerate(clicks_lists): for click in clicks_list: y, x = click.coords y, x = int(round(y)), int(round(x)) y1, x1 = y - radius, x - radius y2, x2 = y + radius + 1, x + radius + 1 if click.is_positive: pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = True else: neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = True with torch.no_grad(): pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device) neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device) return pos_clicks_map, neg_clicks_map def get_states(self): return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data} def set_states(self, states): self._set_transform_states(states['transform_states']) self.opt_data = states['opt_data'] class FeatureBRSPredictor(BRSBasePredictor): def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs): super().__init__(model, device, opt_functor=opt_functor, **kwargs) self.insertion_mode = insertion_mode self._c1_features = None if self.insertion_mode == 'after_deeplab': self.num_channels = model.feature_extractor.ch elif self.insertion_mode == 'after_c4': self.num_channels = model.feature_extractor.aspp_in_channels elif self.insertion_mode == 'after_aspp': self.num_channels = model.feature_extractor.ch + 32 else: raise NotImplementedError def _get_prediction(self, image_nd, clicks_lists, is_image_changed): points_nd = self.get_points_nd(clicks_lists) pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) num_clicks = len(clicks_lists[0]) bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs: self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32) if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None: 
self.input_data = self._get_head_input(image_nd, points_nd) def get_prediction_logits(scale, bias): scale = scale.view(bs, -1, 1, 1) bias = bias.view(bs, -1, 1, 1) if self.with_flip: scale = scale.repeat(2, 1, 1, 1) bias = bias.repeat(2, 1, 1, 1) scaled_backbone_features = self.input_data * scale scaled_backbone_features = scaled_backbone_features + bias if self.insertion_mode == 'after_c4': x = self.net.feature_extractor.aspp(scaled_backbone_features) x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:], align_corners=True) x = torch.cat((x, self._c1_features), dim=1) scaled_backbone_features = self.net.feature_extractor.head(x) elif self.insertion_mode == 'after_aspp': scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features) pred_logits = self.net.head(scaled_backbone_features) pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True) return pred_logits self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device) if num_clicks > self.optimize_after_n_clicks: opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data, **self.opt_functor.optimizer_params) self.opt_data = opt_result[0] with torch.no_grad(): if self.opt_functor.best_prediction is not None: opt_pred_logits = self.opt_functor.best_prediction else: opt_data_nd = torch.from_numpy(self.opt_data).to(self.device) opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd) opt_pred_logits = get_prediction_logits(*opt_vars) return opt_pred_logits def _get_head_input(self, image_nd, points): with torch.no_grad(): coord_features = self.net.dist_maps(image_nd, points) x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1)) if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp': c1, _, c3, c4 = self.net.feature_extractor.backbone(x) c1 = self.net.feature_extractor.skip_project(c1) if self.insertion_mode == 'after_aspp': x = self.net.feature_extractor.aspp(c4) x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True) x = torch.cat((x, c1), dim=1) backbone_features = x else: backbone_features = c4 self._c1_features = c1 else: backbone_features = self.net.feature_extractor(x)[0] return backbone_features class HRNetFeatureBRSPredictor(BRSBasePredictor): def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs): super().__init__(model, device, opt_functor=opt_functor, **kwargs) self.insertion_mode = insertion_mode self._c1_features = None if self.insertion_mode == 'A': self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8]) elif self.insertion_mode == 'C': self.num_channels = 2 * model.feature_extractor.ocr_width else: raise NotImplementedError def _get_prediction(self, image_nd, clicks_lists, is_image_changed): points_nd = self.get_points_nd(clicks_lists) pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) num_clicks = len(clicks_lists[0]) bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs: self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32) if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None: self.input_data = self._get_head_input(image_nd, points_nd) def get_prediction_logits(scale, bias): scale = scale.view(bs, -1, 1, 1) bias = bias.view(bs, -1, 1, 1) if self.with_flip: scale = scale.repeat(2, 1, 1, 1) bias = bias.repeat(2, 1, 1, 1) 
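# ----------------------------------------------------------------------
# Hedged sketch, not part of the original predictor: the core f-BRS move
# that the surrounding closure implements. Rather than optimising the
# input image, f-BRS optimises a per-channel scale and bias applied to
# cached backbone features, so each L-BFGS step only re-runs the light
# head. `feats` and `head` stand in for self.input_data and the network
# head; the (1 + scale) parameterisation mirrors ScaleBiasOptimizer,
# which starts the search at the identity transform.
import torch

def fbrs_head_forward(feats, head, scale, bias):
    # feats: (B, C, H, W) cached features; scale/bias: flat C-vectors.
    s = (1 + scale).view(1, -1, 1, 1)
    b = bias.view(1, -1, 1, 1)
    return head(feats * s + b)
# ----------------------------------------------------------------------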
scaled_backbone_features = self.input_data * scale scaled_backbone_features = scaled_backbone_features + bias if self.insertion_mode == 'A': out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features) feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features) context = self.net.feature_extractor.ocr_gather_head(feats, out_aux) feats = self.net.feature_extractor.ocr_distri_head(feats, context) pred_logits = self.net.feature_extractor.cls_head(feats) elif self.insertion_mode == 'C': pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features) else: raise NotImplementedError pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True) return pred_logits self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device) if num_clicks > self.optimize_after_n_clicks: opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data, **self.opt_functor.optimizer_params) self.opt_data = opt_result[0] with torch.no_grad(): if self.opt_functor.best_prediction is not None: opt_pred_logits = self.opt_functor.best_prediction else: opt_data_nd = torch.from_numpy(self.opt_data).to(self.device) opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd) opt_pred_logits = get_prediction_logits(*opt_vars) return opt_pred_logits def _get_head_input(self, image_nd, points): with torch.no_grad(): coord_features = self.net.dist_maps(image_nd, points) x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1)) feats = self.net.feature_extractor.compute_hrnet_feats(x) if self.insertion_mode == 'A': backbone_features = feats elif self.insertion_mode == 'C': out_aux = self.net.feature_extractor.aux_head(feats) feats = self.net.feature_extractor.conv3x3_ocr(feats) context = self.net.feature_extractor.ocr_gather_head(feats, out_aux) backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context) else: raise NotImplementedError return backbone_features class InputBRSPredictor(BRSBasePredictor): def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs): super().__init__(model, device, opt_functor=opt_functor, **kwargs) self.optimize_target = optimize_target def _get_prediction(self, image_nd, clicks_lists, is_image_changed): points_nd = self.get_points_nd(clicks_lists) pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) num_clicks = len(clicks_lists[0]) if self.opt_data is None or is_image_changed: opt_channels = 2 if self.optimize_target == 'dmaps' else 3 bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]), device=self.device, dtype=torch.float32) def get_prediction_logits(opt_bias): input_image = image_nd if self.optimize_target == 'rgb': input_image = input_image + opt_bias dmaps = self.net.dist_maps(input_image, points_nd) if self.optimize_target == 'dmaps': dmaps = dmaps + opt_bias x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1)) if self.optimize_target == 'all': x = x + opt_bias if isinstance(self.net, DistMapsHRNetModel): pred_logits = self.net.feature_extractor(x)[0] else: backbone_features = self.net.feature_extractor(x) pred_logits = self.net.head(backbone_features[0]) pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True) return pred_logits self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device, shape=self.opt_data.shape) if num_clicks > 
self.optimize_after_n_clicks: opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(), **self.opt_functor.optimizer_params) self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device) with torch.no_grad(): if self.opt_functor.best_prediction is not None: opt_pred_logits = self.opt_functor.best_prediction else: opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data) opt_pred_logits = get_prediction_logits(*opt_vars) return opt_pred_logits ================================================ FILE: XMem/inference/interact/fbrs/inference/predictors/brs_functors.py ================================================ import torch import numpy as np from ...model.metrics import _compute_iou from .brs_losses import BRSMaskLoss class BaseOptimizer: def __init__(self, optimizer_params, prob_thresh=0.49, reg_weight=1e-3, min_iou_diff=0.01, brs_loss=BRSMaskLoss(), with_flip=False, flip_average=False, **kwargs): self.brs_loss = brs_loss self.optimizer_params = optimizer_params self.prob_thresh = prob_thresh self.reg_weight = reg_weight self.min_iou_diff = min_iou_diff self.with_flip = with_flip self.flip_average = flip_average self.best_prediction = None self._get_prediction_logits = None self._opt_shape = None self._best_loss = None self._click_masks = None self._last_mask = None self.device = None def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None): self.best_prediction = None self._get_prediction_logits = get_prediction_logits self._click_masks = (pos_mask, neg_mask) self._opt_shape = shape self._last_mask = None self.device = device def __call__(self, x): opt_params = torch.from_numpy(x).float().to(self.device) opt_params.requires_grad_(True) with torch.enable_grad(): opt_vars, reg_loss = self.unpack_opt_params(opt_params) result_before_sigmoid = self._get_prediction_logits(*opt_vars) result = torch.sigmoid(result_before_sigmoid) pos_mask, neg_mask = self._click_masks if self.with_flip and self.flip_average: result, result_flipped = torch.chunk(result, 2, dim=0) result = 0.5 * (result + torch.flip(result_flipped, dims=[3])) pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]] loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask) loss = loss + reg_loss f_val = loss.detach().cpu().numpy() if self.best_prediction is None or f_val < self._best_loss: self.best_prediction = result_before_sigmoid.detach() self._best_loss = f_val if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh: return [f_val, np.zeros_like(x)] current_mask = result > self.prob_thresh if self._last_mask is not None and self.min_iou_diff > 0: diff_iou = _compute_iou(current_mask, self._last_mask) if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff: return [f_val, np.zeros_like(x)] self._last_mask = current_mask loss.backward() f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float32) return [f_val, f_grad] def unpack_opt_params(self, opt_params): raise NotImplementedError class InputOptimizer(BaseOptimizer): def unpack_opt_params(self, opt_params): opt_params = opt_params.view(self._opt_shape) if self.with_flip: opt_params_flipped = torch.flip(opt_params, dims=[3]) opt_params = torch.cat([opt_params, opt_params_flipped], dim=0) reg_loss = self.reg_weight * torch.sum(opt_params**2) return (opt_params,), reg_loss class ScaleBiasOptimizer(BaseOptimizer): def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs): super().__init__(*args, **kwargs) 
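# ----------------------------------------------------------------------
# Hedged illustration, not from the original file: BaseOptimizer.__call__
# above follows the usual "torch functor for scipy" pattern -- accept a
# flat float64 vector, rebuild a differentiable tensor, and return the
# (loss value, flat gradient) pair that fmin_l_bfgs_b expects. A toy
# instance with a quadratic standing in for the BRS click loss:
import numpy as np
import torch
from scipy.optimize import fmin_l_bfgs_b

def toy_functor(x):
    p = torch.from_numpy(x).float().requires_grad_(True)
    loss = torch.sum((p - 3.0) ** 2)          # stand-in for the click loss
    loss.backward()
    return loss.item(), p.grad.numpy().astype(np.float64)

x_opt, f_min, _ = fmin_l_bfgs_b(func=toy_functor, x0=np.zeros(4), m=20, maxfun=20)
# x_opt converges to roughly [3, 3, 3, 3], mirroring how the returned
# opt_data is unpacked and fed back into get_prediction_logits above.
# ----------------------------------------------------------------------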
self.scale_act = scale_act self.reg_bias_weight = reg_bias_weight def unpack_opt_params(self, opt_params): scale, bias = torch.chunk(opt_params, 2, dim=0) reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2)) if self.scale_act == 'tanh': scale = torch.tanh(scale) elif self.scale_act == 'sin': scale = torch.sin(scale) return (1 + scale, bias), reg_loss ================================================ FILE: XMem/inference/interact/fbrs/inference/predictors/brs_losses.py ================================================ import torch from ...model.losses import SigmoidBinaryCrossEntropyLoss class BRSMaskLoss(torch.nn.Module): def __init__(self, eps=1e-5): super().__init__() self._eps = eps def forward(self, result, pos_mask, neg_mask): pos_diff = (1 - result) * pos_mask pos_target = torch.sum(pos_diff ** 2) pos_target = pos_target / (torch.sum(pos_mask) + self._eps) neg_diff = result * neg_mask neg_target = torch.sum(neg_diff ** 2) neg_target = neg_target / (torch.sum(neg_mask) + self._eps) loss = pos_target + neg_target with torch.no_grad(): f_max_pos = torch.max(torch.abs(pos_diff)).item() f_max_neg = torch.max(torch.abs(neg_diff)).item() return loss, f_max_pos, f_max_neg class OracleMaskLoss(torch.nn.Module): def __init__(self): super().__init__() self.gt_mask = None self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True) self.predictor = None self.history = [] def set_gt_mask(self, gt_mask): self.gt_mask = gt_mask self.history = [] def forward(self, result, pos_mask, neg_mask): gt_mask = self.gt_mask.to(result.device) if self.predictor.object_roi is not None: r1, r2, c1, c2 = self.predictor.object_roi[:4] gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1] gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True) if result.shape[0] == 2: gt_mask_flipped = torch.flip(gt_mask, dims=[3]) gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0) loss = self.loss(result, gt_mask) self.history.append(loss.detach().cpu().numpy()[0]) if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5: return 0, 0, 0 return loss, 1.0, 1.0 ================================================ FILE: XMem/inference/interact/fbrs/inference/transforms/__init__.py ================================================ from .base import SigmoidForPred from .flip import AddHorizontalFlip from .zoom_in import ZoomIn from .limit_longest_side import LimitLongestSide from .crops import Crops ================================================ FILE: XMem/inference/interact/fbrs/inference/transforms/base.py ================================================ import torch class BaseTransform(object): def __init__(self): self.image_changed = False def transform(self, image_nd, clicks_lists): raise NotImplementedError def inv_transform(self, prob_map): raise NotImplementedError def reset(self): raise NotImplementedError def get_state(self): raise NotImplementedError def set_state(self, state): raise NotImplementedError class SigmoidForPred(BaseTransform): def transform(self, image_nd, clicks_lists): return image_nd, clicks_lists def inv_transform(self, prob_map): return torch.sigmoid(prob_map) def reset(self): pass def get_state(self): return None def set_state(self, state): pass ================================================ FILE: XMem/inference/interact/fbrs/inference/transforms/crops.py ================================================ import math import torch import numpy as np from ...inference.clicker import Click from .base import 
BaseTransform class Crops(BaseTransform): def __init__(self, crop_size=(320, 480), min_overlap=0.2): super().__init__() self.crop_height, self.crop_width = crop_size self.min_overlap = min_overlap self.x_offsets = None self.y_offsets = None self._counts = None def transform(self, image_nd, clicks_lists): assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 image_height, image_width = image_nd.shape[2:4] self._counts = None if image_height < self.crop_height or image_width < self.crop_width: return image_nd, clicks_lists self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap) self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap) self._counts = np.zeros((image_height, image_width)) image_crops = [] for dy in self.y_offsets: for dx in self.x_offsets: self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1 image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width] image_crops.append(image_crop) image_crops = torch.cat(image_crops, dim=0) self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32) clicks_list = clicks_lists[0] clicks_lists = [] for dy in self.y_offsets: for dx in self.x_offsets: crop_clicks = [Click(is_positive=x.is_positive, coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list] clicks_lists.append(crop_clicks) return image_crops, clicks_lists def inv_transform(self, prob_map): if self._counts is None: return prob_map new_prob_map = torch.zeros((1, 1, *self._counts.shape), dtype=prob_map.dtype, device=prob_map.device) crop_indx = 0 for dy in self.y_offsets: for dx in self.x_offsets: new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0] crop_indx += 1 new_prob_map = torch.div(new_prob_map, self._counts) return new_prob_map def get_state(self): return self.x_offsets, self.y_offsets, self._counts def set_state(self, state): self.x_offsets, self.y_offsets, self._counts = state def reset(self): self.x_offsets = None self.y_offsets = None self._counts = None def get_offsets(length, crop_size, min_overlap_ratio=0.2): if length == crop_size: return [0] N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio) N = math.ceil(N) overlap_ratio = (N - length / crop_size) / (N - 1) overlap_width = int(crop_size * overlap_ratio) offsets = [0] for i in range(1, N): new_offset = offsets[-1] + crop_size - overlap_width if new_offset + crop_size > length: new_offset = length - crop_size offsets.append(new_offset) return offsets ================================================ FILE: XMem/inference/interact/fbrs/inference/transforms/flip.py ================================================ import torch from ..clicker import Click from .base import BaseTransform class AddHorizontalFlip(BaseTransform): def transform(self, image_nd, clicks_lists): assert len(image_nd.shape) == 4 image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0) image_width = image_nd.shape[3] clicks_lists_flipped = [] for clicks_list in clicks_lists: clicks_list_flipped = [Click(is_positive=click.is_positive, coords=(click.coords[0], image_width - click.coords[1] - 1)) for click in clicks_list] clicks_lists_flipped.append(clicks_list_flipped) clicks_lists = clicks_lists + clicks_lists_flipped return image_nd, clicks_lists def inv_transform(self, prob_map): assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0 num_maps = prob_map.shape[0] // 2 prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:] return 0.5 * 
(prob_map + torch.flip(prob_map_flipped, dims=[3])) def get_state(self): return None def set_state(self, state): pass def reset(self): pass ================================================ FILE: XMem/inference/interact/fbrs/inference/transforms/limit_longest_side.py ================================================ from .zoom_in import ZoomIn, get_roi_image_nd class LimitLongestSide(ZoomIn): def __init__(self, max_size=800): super().__init__(target_size=max_size, skip_clicks=0) def transform(self, image_nd, clicks_lists): assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 image_max_size = max(image_nd.shape[2:4]) self.image_changed = False if image_max_size <= self.target_size: return image_nd, clicks_lists self._input_image = image_nd self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1) self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) self.image_changed = True tclicks_lists = [self._transform_clicks(clicks_lists[0])] return self._roi_image, tclicks_lists ================================================ FILE: XMem/inference/interact/fbrs/inference/transforms/zoom_in.py ================================================ import torch from ..clicker import Click from ...utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox from .base import BaseTransform class ZoomIn(BaseTransform): def __init__(self, target_size=400, skip_clicks=1, expansion_ratio=1.4, min_crop_size=200, recompute_thresh_iou=0.5, prob_thresh=0.50): super().__init__() self.target_size = target_size self.min_crop_size = min_crop_size self.skip_clicks = skip_clicks self.expansion_ratio = expansion_ratio self.recompute_thresh_iou = recompute_thresh_iou self.prob_thresh = prob_thresh self._input_image_shape = None self._prev_probs = None self._object_roi = None self._roi_image = None def transform(self, image_nd, clicks_lists): assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 self.image_changed = False clicks_list = clicks_lists[0] if len(clicks_list) <= self.skip_clicks: return image_nd, clicks_lists self._input_image_shape = image_nd.shape current_object_roi = None if self._prev_probs is not None: current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] if current_pred_mask.sum() > 0: current_object_roi = get_object_roi(current_pred_mask, clicks_list, self.expansion_ratio, self.min_crop_size) if current_object_roi is None: return image_nd, clicks_lists update_object_roi = False if self._object_roi is None: update_object_roi = True elif not check_object_roi(self._object_roi, clicks_list): update_object_roi = True elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou: update_object_roi = True if update_object_roi: self._object_roi = current_object_roi self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) self.image_changed = True tclicks_lists = [self._transform_clicks(clicks_list)] return self._roi_image.to(image_nd.device), tclicks_lists def inv_transform(self, prob_map): if self._object_roi is None: self._prev_probs = prob_map.cpu().numpy() return prob_map assert prob_map.shape[0] == 1 rmin, rmax, cmin, cmax = self._object_roi prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1), mode='bilinear', align_corners=True) if self._prev_probs is not None: new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype) new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map else: new_prob_map = prob_map 
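# ----------------------------------------------------------------------
# Hedged sketch, not part of the original transform: the paste-back step
# that inv_transform performs at this point. The ROI prediction is
# resized to the ROI's pixel extent and written into an otherwise-empty
# full-frame map; `paste_back` and all shapes below are illustrative.
import torch
import torch.nn.functional as F

def paste_back(roi_probs, object_roi, full_shape):
    rmin, rmax, cmin, cmax = object_roi
    probs = F.interpolate(roi_probs, size=(rmax - rmin + 1, cmax - cmin + 1),
                          mode='bilinear', align_corners=True)
    canvas = torch.zeros(full_shape, dtype=probs.dtype, device=probs.device)
    canvas[:, :, rmin:rmax + 1, cmin:cmax + 1] = probs
    return canvas

# e.g. a 1x1x100x100 crop prediction pasted into a 480x854 frame:
full = paste_back(torch.rand(1, 1, 100, 100), (20, 119, 30, 129), (1, 1, 480, 854))
# ----------------------------------------------------------------------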
self._prev_probs = new_prob_map.cpu().numpy() return new_prob_map def check_possible_recalculation(self): if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0: return False pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] if pred_mask.sum() > 0: possible_object_roi = get_object_roi(pred_mask, [], self.expansion_ratio, self.min_crop_size) image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1) if get_bbox_iou(possible_object_roi, image_roi) < 0.50: return True return False def get_state(self): roi_image = self._roi_image.cpu() if self._roi_image is not None else None return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed def set_state(self, state): self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state def reset(self): self._input_image_shape = None self._object_roi = None self._prev_probs = None self._roi_image = None self.image_changed = False def _transform_clicks(self, clicks_list): if self._object_roi is None: return clicks_list rmin, rmax, cmin, cmax = self._object_roi crop_height, crop_width = self._roi_image.shape[2:] transformed_clicks = [] for click in clicks_list: new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1) new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1) transformed_clicks.append(Click(is_positive=click.is_positive, coords=(new_r, new_c))) return transformed_clicks def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size): pred_mask = pred_mask.copy() for click in clicks_list: if click.is_positive: pred_mask[int(click.coords[0]), int(click.coords[1])] = 1 bbox = get_bbox_from_mask(pred_mask) bbox = expand_bbox(bbox, expansion_ratio, min_crop_size) h, w = pred_mask.shape[0], pred_mask.shape[1] bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1) return bbox def get_roi_image_nd(image_nd, object_roi, target_size): rmin, rmax, cmin, cmax = object_roi height = rmax - rmin + 1 width = cmax - cmin + 1 if isinstance(target_size, tuple): new_height, new_width = target_size else: scale = target_size / max(height, width) new_height = int(round(height * scale)) new_width = int(round(width * scale)) with torch.no_grad(): roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1] roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width), mode='bilinear', align_corners=True) return roi_image_nd def check_object_roi(object_roi, clicks_list): for click in clicks_list: if click.is_positive: if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]: return False if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]: return False return True ================================================ FILE: XMem/inference/interact/fbrs/inference/utils.py ================================================ from datetime import timedelta from pathlib import Path import torch import numpy as np from ..model.is_deeplab_model import get_deeplab_model from ..model.is_hrnet_model import get_hrnet_model def get_time_metrics(all_ious, elapsed_time): n_images = len(all_ious) n_clicks = sum(map(len, all_ious)) mean_spc = elapsed_time / n_clicks mean_spi = elapsed_time / n_images return mean_spc, mean_spi def load_is_model(checkpoint, device, backbone='auto', **kwargs): if isinstance(checkpoint, (str, Path)): state_dict = torch.load(checkpoint, map_location='cpu') else: state_dict = checkpoint if backbone == 'auto': for k in state_dict.keys(): if 
'feature_extractor.stage2.0.branches' in k: return load_hrnet_is_model(state_dict, device, backbone, **kwargs) return load_deeplab_is_model(state_dict, device, backbone, **kwargs) elif 'resnet' in backbone: return load_deeplab_is_model(state_dict, device, backbone, **kwargs) elif 'hrnet' in backbone: return load_hrnet_is_model(state_dict, device, backbone, **kwargs) else: raise NotImplementedError('Unknown backbone') def load_hrnet_is_model(state_dict, device, backbone='auto', width=48, ocr_width=256, small=False, cpu_dist_maps=False, norm_radius=260): if backbone == 'auto': num_fe_weights = len([x for x in state_dict.keys() if 'feature_extractor.' in x]) small = num_fe_weights < 1800 ocr_f_down = [v for k, v in state_dict.items() if 'object_context_block.f_down.1.0.bias' in k] assert len(ocr_f_down) == 1 ocr_width = ocr_f_down[0].shape[0] s2_conv1_w = [v for k, v in state_dict.items() if 'stage2.0.branches.0.0.conv1.weight' in k] assert len(s2_conv1_w) == 1 width = s2_conv1_w[0].shape[0] model = get_hrnet_model(width=width, ocr_width=ocr_width, small=small, with_aux_output=False, cpu_dist_maps=cpu_dist_maps, norm_radius=norm_radius) model.load_state_dict(state_dict, strict=False) for param in model.parameters(): param.requires_grad = False model.to(device) model.eval() return model def load_deeplab_is_model(state_dict, device, backbone='auto', deeplab_ch=128, aspp_dropout=0.2, cpu_dist_maps=False, norm_radius=260): if backbone == 'auto': num_backbone_params = len([x for x in state_dict.keys() if 'feature_extractor.backbone' in x and not('num_batches_tracked' in x)]) if num_backbone_params <= 181: backbone = 'resnet34' elif num_backbone_params <= 276: backbone = 'resnet50' elif num_backbone_params <= 531: backbone = 'resnet101' else: raise NotImplementedError('Unknown backbone') if 'aspp_dropout' in state_dict: aspp_dropout = float(state_dict['aspp_dropout'].cpu().numpy()) else: aspp_project_weight = [v for k, v in state_dict.items() if 'aspp.project.0.weight' in k][0] deeplab_ch = aspp_project_weight.size(0) if deeplab_ch == 256: aspp_dropout = 0.5 model = get_deeplab_model(backbone=backbone, deeplab_ch=deeplab_ch, aspp_dropout=aspp_dropout, cpu_dist_maps=cpu_dist_maps, norm_radius=norm_radius) model.load_state_dict(state_dict, strict=False) for param in model.parameters(): param.requires_grad = False model.to(device) model.eval() return model def get_iou(gt_mask, pred_mask, ignore_label=-1): ignore_gt_mask_inv = gt_mask != ignore_label obj_gt_mask = gt_mask == 1 intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() return intersection / union def compute_noc_metric(all_ious, iou_thrs, max_clicks=20): def _get_noc(iou_arr, iou_thr): vals = iou_arr >= iou_thr return np.argmax(vals) + 1 if np.any(vals) else max_clicks noc_list = [] over_max_list = [] for iou_thr in iou_thrs: scores_arr = np.array([_get_noc(iou_arr, iou_thr) for iou_arr in all_ious], dtype=np.int32) score = scores_arr.mean() over_max = (scores_arr == max_clicks).sum() noc_list.append(score) over_max_list.append(over_max) return noc_list, over_max_list def find_checkpoint(weights_folder, checkpoint_name): weights_folder = Path(weights_folder) if ':' in checkpoint_name: model_name, checkpoint_name = checkpoint_name.split(':') models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()] assert len(models_candidates) == 1 model_folder = models_candidates[0] else: model_folder = 
weights_folder if checkpoint_name.endswith('.pth'): if Path(checkpoint_name).exists(): checkpoint_path = checkpoint_name else: checkpoint_path = weights_folder / checkpoint_name else: model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth')) assert len(model_checkpoints) == 1 checkpoint_path = model_checkpoints[0] return str(checkpoint_path) def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time, n_clicks=20, model_name=None): table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|' f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|' f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|' f'{"SPC,s":^7}|{"Time":^9}|') row_width = len(table_header) header = f'Eval results for model: {model_name}\n' if model_name is not None else '' header += '-' * row_width + '\n' header += table_header + '\n' + '-' * row_width eval_time = str(timedelta(seconds=int(elapsed_time))) table_row = f'|{brs_type:^13}|{dataset_name:^11}|' table_row += f'{noc_list[0]:^9.2f}|' table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|' table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|' table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|' table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|' table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|' return header, table_row ================================================ FILE: XMem/inference/interact/fbrs/model/__init__.py ================================================ ================================================ FILE: XMem/inference/interact/fbrs/model/initializer.py ================================================ import torch import torch.nn as nn import numpy as np class Initializer(object): def __init__(self, local_init=True, gamma=None): self.local_init = local_init self.gamma = gamma def __call__(self, m): if getattr(m, '__initialized', False): return if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__: if m.weight is not None: self._init_gamma(m.weight.data) if m.bias is not None: self._init_beta(m.bias.data) else: if getattr(m, 'weight', None) is not None: self._init_weight(m.weight.data) if getattr(m, 'bias', None) is not None: self._init_bias(m.bias.data) if self.local_init: object.__setattr__(m, '__initialized', True) def _init_weight(self, data): nn.init.uniform_(data, -0.07, 0.07) def _init_bias(self, data): nn.init.constant_(data, 0) def _init_gamma(self, data): if self.gamma is None: nn.init.constant_(data, 1.0) else: nn.init.normal_(data, 1.0, self.gamma) def _init_beta(self, data): nn.init.constant_(data, 0) class Bilinear(Initializer): def __init__(self, scale, groups, in_channels, **kwargs): super().__init__(**kwargs) self.scale = scale self.groups = groups self.in_channels = in_channels def _init_weight(self, data): """Reset the weight and bias.""" bilinear_kernel = self.get_bilinear_kernel(self.scale) weight = torch.zeros_like(data) for i in range(self.in_channels): if self.groups == 1: j = i else: j = 0 weight[i, j] = bilinear_kernel data[:] = weight @staticmethod def get_bilinear_kernel(scale): """Generate a bilinear upsampling kernel.""" kernel_size = 2 * scale - scale % 2 scale = (kernel_size + 1) // 2 center = scale - 0.5 * (1 + kernel_size % 2) og = np.ogrid[:kernel_size, :kernel_size] kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - 
center) / scale) return torch.tensor(kernel, dtype=torch.float32) class XavierGluon(Initializer): def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs): super().__init__(**kwargs) self.rnd_type = rnd_type self.factor_type = factor_type self.magnitude = float(magnitude) def _init_weight(self, arr): fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr) if self.factor_type == 'avg': factor = (fan_in + fan_out) / 2.0 elif self.factor_type == 'in': factor = fan_in elif self.factor_type == 'out': factor = fan_out else: raise ValueError('Incorrect factor type') scale = np.sqrt(self.magnitude / factor) if self.rnd_type == 'uniform': nn.init.uniform_(arr, -scale, scale) elif self.rnd_type == 'gaussian': nn.init.normal_(arr, 0, scale) else: raise ValueError('Unknown random type') ================================================ FILE: XMem/inference/interact/fbrs/model/is_deeplab_model.py ================================================ import torch import torch.nn as nn from .ops import DistMaps from .modeling.deeplab_v3 import DeepLabV3Plus from .modeling.basic_blocks import SepConvHead def get_deeplab_model(backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5, norm_layer=nn.BatchNorm2d, backbone_norm_layer=None, use_rgb_conv=True, cpu_dist_maps=False, norm_radius=260): model = DistMapsModel( feature_extractor=DeepLabV3Plus(backbone=backbone, ch=deeplab_ch, project_dropout=aspp_dropout, norm_layer=norm_layer, backbone_norm_layer=backbone_norm_layer), head=SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2, num_layers=2, norm_layer=norm_layer), use_rgb_conv=use_rgb_conv, norm_layer=norm_layer, norm_radius=norm_radius, cpu_dist_maps=cpu_dist_maps ) return model class DistMapsModel(nn.Module): def __init__(self, feature_extractor, head, norm_layer=nn.BatchNorm2d, use_rgb_conv=True, cpu_dist_maps=False, norm_radius=260): super(DistMapsModel, self).__init__() if use_rgb_conv: self.rgb_conv = nn.Sequential( nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1), nn.LeakyReLU(negative_slope=0.2), norm_layer(8), nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1), ) else: self.rgb_conv = None self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, cpu_mode=cpu_dist_maps) self.feature_extractor = feature_extractor self.head = head def forward(self, image, points): coord_features = self.dist_maps(image, points) if self.rgb_conv is not None: x = self.rgb_conv(torch.cat((image, coord_features), dim=1)) else: c1, c2 = torch.chunk(coord_features, 2, dim=1) c3 = torch.ones_like(c1) coord_features = torch.cat((c1, c2, c3), dim=1) x = 0.8 * image * coord_features + 0.2 * image backbone_features = self.feature_extractor(x) instance_out = self.head(backbone_features[0]) instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:], mode='bilinear', align_corners=True) return {'instances': instance_out} def load_weights(self, path_to_weights): current_state_dict = self.state_dict() new_state_dict = torch.load(path_to_weights, map_location='cpu') current_state_dict.update(new_state_dict) self.load_state_dict(current_state_dict) def get_trainable_params(self): backbone_params = nn.ParameterList() other_params = nn.ParameterList() for name, param in self.named_parameters(): if param.requires_grad: if 'backbone' in name: backbone_params.append(param) else: other_params.append(param) return backbone_params, other_params ================================================ FILE: XMem/inference/interact/fbrs/model/is_hrnet_model.py 
================================================ import torch import torch.nn as nn from .ops import DistMaps from .modeling.hrnet_ocr import HighResolutionNet def get_hrnet_model(width=48, ocr_width=256, small=False, norm_radius=260, use_rgb_conv=True, with_aux_output=False, cpu_dist_maps=False, norm_layer=nn.BatchNorm2d): model = DistMapsHRNetModel( feature_extractor=HighResolutionNet(width=width, ocr_width=ocr_width, small=small, num_classes=1, norm_layer=norm_layer), use_rgb_conv=use_rgb_conv, with_aux_output=with_aux_output, norm_layer=norm_layer, norm_radius=norm_radius, cpu_dist_maps=cpu_dist_maps ) return model class DistMapsHRNetModel(nn.Module): def __init__(self, feature_extractor, use_rgb_conv=True, with_aux_output=False, norm_layer=nn.BatchNorm2d, norm_radius=260, cpu_dist_maps=False): super(DistMapsHRNetModel, self).__init__() self.with_aux_output = with_aux_output if use_rgb_conv: self.rgb_conv = nn.Sequential( nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1), nn.LeakyReLU(negative_slope=0.2), norm_layer(8), nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1), ) else: self.rgb_conv = None self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, cpu_mode=cpu_dist_maps) self.feature_extractor = feature_extractor def forward(self, image, points): coord_features = self.dist_maps(image, points) if self.rgb_conv is not None: x = self.rgb_conv(torch.cat((image, coord_features), dim=1)) else: c1, c2 = torch.chunk(coord_features, 2, dim=1) c3 = torch.ones_like(c1) coord_features = torch.cat((c1, c2, c3), dim=1) x = 0.8 * image * coord_features + 0.2 * image feature_extractor_out = self.feature_extractor(x) instance_out = feature_extractor_out[0] instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:], mode='bilinear', align_corners=True) outputs = {'instances': instance_out} if self.with_aux_output: instance_aux_out = feature_extractor_out[1] instance_aux_out = nn.functional.interpolate(instance_aux_out, size=image.size()[2:], mode='bilinear', align_corners=True) outputs['instances_aux'] = instance_aux_out return outputs def load_weights(self, path_to_weights): current_state_dict = self.state_dict() new_state_dict = torch.load(path_to_weights) current_state_dict.update(new_state_dict) self.load_state_dict(current_state_dict) def get_trainable_params(self): backbone_params = nn.ParameterList() other_params = nn.ParameterList() other_params_keys = [] nonbackbone_keywords = ['rgb_conv', 'aux_head', 'cls_head', 'conv3x3_ocr', 'ocr_distri_head'] for name, param in self.named_parameters(): if param.requires_grad: if any(x in name for x in nonbackbone_keywords): other_params.append(param) other_params_keys.append(name) else: backbone_params.append(param) print('Nonbackbone params:', sorted(other_params_keys)) return backbone_params, other_params ================================================ FILE: XMem/inference/interact/fbrs/model/losses.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from ..utils import misc class NormalizedFocalLossSigmoid(nn.Module): def __init__(self, axis=-1, alpha=0.25, gamma=2, from_logits=False, batch_axis=0, weight=None, size_average=True, detach_delimeter=True, eps=1e-12, scale=1.0, ignore_label=-1): super(NormalizedFocalLossSigmoid, self).__init__() self._axis = axis self._alpha = alpha self._gamma = gamma self._ignore_label = ignore_label self._weight = weight if weight is not None else 1.0 self._batch_axis = batch_axis 
self._scale = scale self._from_logits = from_logits self._eps = eps self._size_average = size_average self._detach_delimeter = detach_delimeter self._k_sum = 0 def forward(self, pred, label, sample_weight=None): one_hot = label > 0 sample_weight = label != self._ignore_label if not self._from_logits: pred = torch.sigmoid(pred) alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) pt = torch.where(one_hot, pred, 1 - pred) pt = torch.where(sample_weight, pt, torch.ones_like(pt)) beta = (1 - pt) ** self._gamma sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True) beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) mult = sw_sum / (beta_sum + self._eps) if self._detach_delimeter: mult = mult.detach() beta = beta * mult ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy() sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy() if np.any(ignore_area == 0): self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean() loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) loss = self._weight * (loss * sample_weight) if self._size_average: bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis)) loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps) else: loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) return self._scale * loss def log_states(self, sw, name, global_step): sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step) class FocalLoss(nn.Module): def __init__(self, axis=-1, alpha=0.25, gamma=2, from_logits=False, batch_axis=0, weight=None, num_class=None, eps=1e-9, size_average=True, scale=1.0): super(FocalLoss, self).__init__() self._axis = axis self._alpha = alpha self._gamma = gamma self._weight = weight if weight is not None else 1.0 self._batch_axis = batch_axis self._scale = scale self._num_class = num_class self._from_logits = from_logits self._eps = eps self._size_average = size_average def forward(self, pred, label, sample_weight=None): if not self._from_logits: pred = F.sigmoid(pred) one_hot = label > 0 pt = torch.where(one_hot, pred, 1 - pred) t = label != -1 alpha = torch.where(one_hot, self._alpha * t, (1 - self._alpha) * t) beta = (1 - pt) ** self._gamma loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) sample_weight = label != -1 loss = self._weight * (loss * sample_weight) if self._size_average: tsum = torch.sum(label == 1, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis)) loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps) else: loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) return self._scale * loss class SigmoidBinaryCrossEntropyLoss(nn.Module): def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1): super(SigmoidBinaryCrossEntropyLoss, self).__init__() self._from_sigmoid = from_sigmoid self._ignore_label = ignore_label self._weight = weight if weight is not None else 1.0 self._batch_axis = batch_axis def forward(self, pred, label): label = label.view(pred.size()) sample_weight = label != self._ignore_label label = torch.where(sample_weight, label, torch.zeros_like(label)) if not self._from_sigmoid: loss = torch.relu(pred) - pred * label 
+ F.softplus(-torch.abs(pred)) else: eps = 1e-12 loss = -(torch.log(pred + eps) * label + torch.log(1. - pred + eps) * (1. - label)) loss = self._weight * (loss * sample_weight) return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) ================================================ FILE: XMem/inference/interact/fbrs/model/metrics.py ================================================ import torch import numpy as np from ..utils import misc class TrainMetric(object): def __init__(self, pred_outputs, gt_outputs): self.pred_outputs = pred_outputs self.gt_outputs = gt_outputs def update(self, *args, **kwargs): raise NotImplementedError def get_epoch_value(self): raise NotImplementedError def reset_epoch_stats(self): raise NotImplementedError def log_states(self, sw, tag_prefix, global_step): pass @property def name(self): return type(self).__name__ class AdaptiveIoU(TrainMetric): def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9, ignore_label=-1, from_logits=True, pred_output='instances', gt_output='instances'): super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,)) self._ignore_label = ignore_label self._from_logits = from_logits self._iou_thresh = init_thresh self._thresh_step = thresh_step self._thresh_beta = thresh_beta self._iou_beta = iou_beta self._ema_iou = 0.0 self._epoch_iou_sum = 0.0 self._epoch_batch_count = 0 def update(self, pred, gt): gt_mask = gt > 0 if self._from_logits: pred = torch.sigmoid(pred) gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy() if np.all(gt_mask_area == 0): return ignore_mask = gt == self._ignore_label max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean() best_thresh = self._iou_thresh for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]: temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean() if temp_iou > max_iou: max_iou = temp_iou best_thresh = t self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou self._epoch_iou_sum += max_iou self._epoch_batch_count += 1 def get_epoch_value(self): if self._epoch_batch_count > 0: return self._epoch_iou_sum / self._epoch_batch_count else: return 0.0 def reset_epoch_stats(self): self._epoch_iou_sum = 0.0 self._epoch_batch_count = 0 def log_states(self, sw, tag_prefix, global_step): sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step) sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step) @property def iou_thresh(self): return self._iou_thresh def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False): if ignore_mask is not None: pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask) reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0) union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() nonzero = union > 0 iou = intersection[nonzero] / union[nonzero] if not keep_ignore: return iou else: result = np.full_like(intersection, -1) result[nonzero] = iou return result ================================================ FILE: XMem/inference/interact/fbrs/model/modeling/__init__.py ================================================ ================================================ FILE: 
XMem/inference/interact/fbrs/model/modeling/basic_blocks.py ================================================ import torch.nn as nn from ...model import ops class ConvHead(nn.Module): def __init__(self, out_channels, in_channels=32, num_layers=1, kernel_size=3, padding=1, norm_layer=nn.BatchNorm2d): super(ConvHead, self).__init__() convhead = [] for i in range(num_layers): convhead.extend([ nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding), nn.ReLU(), norm_layer(in_channels) if norm_layer is not None else nn.Identity() ]) convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0)) self.convhead = nn.Sequential(*convhead) def forward(self, *inputs): return self.convhead(inputs[0]) class SepConvHead(nn.Module): def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1, kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0, norm_layer=nn.BatchNorm2d): super(SepConvHead, self).__init__() sepconvhead = [] for i in range(num_layers): sepconvhead.append( SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels, out_channels=mid_channels, dw_kernel=kernel_size, dw_padding=padding, norm_layer=norm_layer, activation='relu') ) if dropout_ratio > 0 and dropout_indx == i: sepconvhead.append(nn.Dropout(dropout_ratio)) sepconvhead.append( nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0) ) self.layers = nn.Sequential(*sepconvhead) def forward(self, *inputs): x = inputs[0] return self.layers(x) class SeparableConv2d(nn.Module): def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1, activation=None, use_bias=False, norm_layer=None): super(SeparableConv2d, self).__init__() _activation = ops.select_activation_function(activation) self.body = nn.Sequential( nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride, padding=dw_padding, bias=use_bias, groups=in_channels), nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias), norm_layer(out_channels) if norm_layer is not None else nn.Identity(), _activation() ) def forward(self, x): return self.body(x) ================================================ FILE: XMem/inference/interact/fbrs/model/modeling/deeplab_v3.py ================================================ from contextlib import ExitStack import torch from torch import nn import torch.nn.functional as F from .basic_blocks import SeparableConv2d from .resnet import ResNetBackbone from ...model import ops class DeepLabV3Plus(nn.Module): def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d, backbone_norm_layer=None, ch=256, project_dropout=0.5, inference_mode=False, **kwargs): super(DeepLabV3Plus, self).__init__() if backbone_norm_layer is None: backbone_norm_layer = norm_layer self.backbone_name = backbone self.norm_layer = norm_layer self.backbone_norm_layer = backbone_norm_layer self.inference_mode = False self.ch = ch self.aspp_in_channels = 2048 self.skip_project_in_channels = 256 # layer 1 out_channels self._kwargs = kwargs if backbone == 'resnet34': self.aspp_in_channels = 512 self.skip_project_in_channels = 64 self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False, norm_layer=self.backbone_norm_layer, **kwargs) self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch, norm_layer=self.norm_layer) self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer) self.aspp = _ASPP(in_channels=self.aspp_in_channels, atrous_rates=[12, 24, 36], 
out_channels=ch, project_dropout=project_dropout, norm_layer=self.norm_layer) if inference_mode: self.set_prediction_mode() def load_pretrained_weights(self): pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True, norm_layer=self.backbone_norm_layer, **self._kwargs) backbone_state_dict = self.backbone.state_dict() pretrained_state_dict = pretrained.state_dict() backbone_state_dict.update(pretrained_state_dict) self.backbone.load_state_dict(backbone_state_dict) if self.inference_mode: for param in self.backbone.parameters(): param.requires_grad = False def set_prediction_mode(self): self.inference_mode = True self.eval() def forward(self, x): with ExitStack() as stack: if self.inference_mode: stack.enter_context(torch.no_grad()) c1, _, c3, c4 = self.backbone(x) c1 = self.skip_project(c1) x = self.aspp(c4) x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True) x = torch.cat((x, c1), dim=1) x = self.head(x) return x, class _SkipProject(nn.Module): def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d): super(_SkipProject, self).__init__() _activation = ops.select_activation_function("relu") self.skip_project = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), norm_layer(out_channels), _activation() ) def forward(self, x): return self.skip_project(x) class _DeepLabHead(nn.Module): def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d): super(_DeepLabHead, self).__init__() self.block = nn.Sequential( SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3, dw_padding=1, activation='relu', norm_layer=norm_layer), SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3, dw_padding=1, activation='relu', norm_layer=norm_layer), nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1) ) def forward(self, x): return self.block(x) class _ASPP(nn.Module): def __init__(self, in_channels, atrous_rates, out_channels=256, project_dropout=0.5, norm_layer=nn.BatchNorm2d): super(_ASPP, self).__init__() b0 = nn.Sequential( nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False), norm_layer(out_channels), nn.ReLU() ) rate1, rate2, rate3 = tuple(atrous_rates) b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer) b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer) b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer) b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer) self.concurent = nn.ModuleList([b0, b1, b2, b3, b4]) project = [ nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels, kernel_size=1, bias=False), norm_layer(out_channels), nn.ReLU() ] if project_dropout > 0: project.append(nn.Dropout(project_dropout)) self.project = nn.Sequential(*project) def forward(self, x): x = torch.cat([block(x) for block in self.concurent], dim=1) return self.project(x) class _AsppPooling(nn.Module): def __init__(self, in_channels, out_channels, norm_layer): super(_AsppPooling, self).__init__() self.gap = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False), norm_layer(out_channels), nn.ReLU() ) def forward(self, x): pool = self.gap(x) return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True) def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer): block = nn.Sequential( nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 
kernel_size=3, padding=atrous_rate, dilation=atrous_rate, bias=False), norm_layer(out_channels), nn.ReLU() ) return block ================================================ FILE: XMem/inference/interact/fbrs/model/modeling/hrnet_ocr.py ================================================ import os import numpy as np import torch import torch.nn as nn import torch._utils import torch.nn.functional as F from .ocr import SpatialOCR_Module, SpatialGather_Module from .resnetv1b import BasicBlockV1b, BottleneckV1b relu_inplace = True class HighResolutionModule(nn.Module): def __init__(self, num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method,multi_scale_output=True, norm_layer=nn.BatchNorm2d, align_corners=True): super(HighResolutionModule, self).__init__() self._check_branches(num_branches, num_blocks, num_inchannels, num_channels) self.num_inchannels = num_inchannels self.fuse_method = fuse_method self.num_branches = num_branches self.norm_layer = norm_layer self.align_corners = align_corners self.multi_scale_output = multi_scale_output self.branches = self._make_branches( num_branches, blocks, num_blocks, num_channels) self.fuse_layers = self._make_fuse_layers() self.relu = nn.ReLU(inplace=relu_inplace) def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels): if num_branches != len(num_blocks): error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( num_branches, len(num_blocks)) raise ValueError(error_msg) if num_branches != len(num_channels): error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( num_branches, len(num_channels)) raise ValueError(error_msg) if num_branches != len(num_inchannels): error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( num_branches, len(num_inchannels)) raise ValueError(error_msg) def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): downsample = None if stride != 1 or \ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion, kernel_size=1, stride=stride, bias=False), self.norm_layer(num_channels[branch_index] * block.expansion), ) layers = [] layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample=downsample, norm_layer=self.norm_layer)) self.num_inchannels[branch_index] = \ num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], norm_layer=self.norm_layer)) return nn.Sequential(*layers) def _make_branches(self, num_branches, block, num_blocks, num_channels): branches = [] for i in range(num_branches): branches.append( self._make_one_branch(i, block, num_blocks, num_channels)) return nn.ModuleList(branches) def _make_fuse_layers(self): if self.num_branches == 1: return None num_branches = self.num_branches num_inchannels = self.num_inchannels fuse_layers = [] for i in range(num_branches if self.multi_scale_output else 1): fuse_layer = [] for j in range(num_branches): if j > i: fuse_layer.append(nn.Sequential( nn.Conv2d(in_channels=num_inchannels[j], out_channels=num_inchannels[i], kernel_size=1, bias=False), self.norm_layer(num_inchannels[i]))) elif j == i: fuse_layer.append(None) else: conv3x3s = [] for k in range(i - j): if k == i - j - 1: num_outchannels_conv3x3 = num_inchannels[i] conv3x3s.append(nn.Sequential( nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 
kernel_size=3, stride=2, padding=1, bias=False), self.norm_layer(num_outchannels_conv3x3))) else: num_outchannels_conv3x3 = num_inchannels[j] conv3x3s.append(nn.Sequential( nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, kernel_size=3, stride=2, padding=1, bias=False), self.norm_layer(num_outchannels_conv3x3), nn.ReLU(inplace=relu_inplace))) fuse_layer.append(nn.Sequential(*conv3x3s)) fuse_layers.append(nn.ModuleList(fuse_layer)) return nn.ModuleList(fuse_layers) def get_num_inchannels(self): return self.num_inchannels def forward(self, x): if self.num_branches == 1: return [self.branches[0](x[0])] for i in range(self.num_branches): x[i] = self.branches[i](x[i]) x_fuse = [] for i in range(len(self.fuse_layers)): y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) for j in range(1, self.num_branches): if i == j: y = y + x[j] elif j > i: width_output = x[i].shape[-1] height_output = x[i].shape[-2] y = y + F.interpolate( self.fuse_layers[i][j](x[j]), size=[height_output, width_output], mode='bilinear', align_corners=self.align_corners) else: y = y + self.fuse_layers[i][j](x[j]) x_fuse.append(self.relu(y)) return x_fuse class HighResolutionNet(nn.Module): def __init__(self, width, num_classes, ocr_width=256, small=False, norm_layer=nn.BatchNorm2d, align_corners=True): super(HighResolutionNet, self).__init__() self.norm_layer = norm_layer self.width = width self.ocr_width = ocr_width self.align_corners = align_corners self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = norm_layer(64) self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) self.bn2 = norm_layer(64) self.relu = nn.ReLU(inplace=relu_inplace) num_blocks = 2 if small else 4 stage1_num_channels = 64 self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks) stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels self.stage2_num_branches = 2 num_channels = [width, 2 * width] num_inchannels = [ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] self.transition1 = self._make_transition_layer( [stage1_out_channel], num_inchannels) self.stage2, pre_stage_channels = self._make_stage( BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches, num_blocks=2 * [num_blocks], num_channels=num_channels) self.stage3_num_branches = 3 num_channels = [width, 2 * width, 4 * width] num_inchannels = [ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] self.transition2 = self._make_transition_layer( pre_stage_channels, num_inchannels) self.stage3, pre_stage_channels = self._make_stage( BasicBlockV1b, num_inchannels=num_inchannels, num_modules=3 if small else 4, num_branches=self.stage3_num_branches, num_blocks=3 * [num_blocks], num_channels=num_channels) self.stage4_num_branches = 4 num_channels = [width, 2 * width, 4 * width, 8 * width] num_inchannels = [ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] self.transition3 = self._make_transition_layer( pre_stage_channels, num_inchannels) self.stage4, pre_stage_channels = self._make_stage( BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3, num_branches=self.stage4_num_branches, num_blocks=4 * [num_blocks], num_channels=num_channels) last_inp_channels = np.int32(np.sum(pre_stage_channels)) ocr_mid_channels = 2 * ocr_width ocr_key_channels = ocr_width self.conv3x3_ocr = nn.Sequential( nn.Conv2d(last_inp_channels, ocr_mid_channels, kernel_size=3, stride=1, 
padding=1), norm_layer(ocr_mid_channels), nn.ReLU(inplace=relu_inplace), ) self.ocr_gather_head = SpatialGather_Module(num_classes) self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels, key_channels=ocr_key_channels, out_channels=ocr_mid_channels, scale=1, dropout=0.05, norm_layer=norm_layer, align_corners=align_corners) self.cls_head = nn.Conv2d( ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True) self.aux_head = nn.Sequential( nn.Conv2d(last_inp_channels, last_inp_channels, kernel_size=1, stride=1, padding=0), norm_layer(last_inp_channels), nn.ReLU(inplace=relu_inplace), nn.Conv2d(last_inp_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True) ) def _make_transition_layer( self, num_channels_pre_layer, num_channels_cur_layer): num_branches_cur = len(num_channels_cur_layer) num_branches_pre = len(num_channels_pre_layer) transition_layers = [] for i in range(num_branches_cur): if i < num_branches_pre: if num_channels_cur_layer[i] != num_channels_pre_layer[i]: transition_layers.append(nn.Sequential( nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], kernel_size=3, stride=1, padding=1, bias=False), self.norm_layer(num_channels_cur_layer[i]), nn.ReLU(inplace=relu_inplace))) else: transition_layers.append(None) else: conv3x3s = [] for j in range(i + 1 - num_branches_pre): inchannels = num_channels_pre_layer[-1] outchannels = num_channels_cur_layer[i] \ if j == i - num_branches_pre else inchannels conv3x3s.append(nn.Sequential( nn.Conv2d(inchannels, outchannels, kernel_size=3, stride=2, padding=1, bias=False), self.norm_layer(outchannels), nn.ReLU(inplace=relu_inplace))) transition_layers.append(nn.Sequential(*conv3x3s)) return nn.ModuleList(transition_layers) def _make_layer(self, block, inplanes, planes, blocks, stride=1): downsample = None if stride != 1 or inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), self.norm_layer(planes * block.expansion), ) layers = [] layers.append(block(inplanes, planes, stride, downsample=downsample, norm_layer=self.norm_layer)) inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(inplanes, planes, norm_layer=self.norm_layer)) return nn.Sequential(*layers) def _make_stage(self, block, num_inchannels, num_modules, num_branches, num_blocks, num_channels, fuse_method='SUM', multi_scale_output=True): modules = [] for i in range(num_modules): # multi_scale_output is only used last module if not multi_scale_output and i == num_modules - 1: reset_multi_scale_output = False else: reset_multi_scale_output = True modules.append( HighResolutionModule(num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output, norm_layer=self.norm_layer, align_corners=self.align_corners) ) num_inchannels = modules[-1].get_num_inchannels() return nn.Sequential(*modules), num_inchannels def forward(self, x): feats = self.compute_hrnet_feats(x) out_aux = self.aux_head(feats) feats = self.conv3x3_ocr(feats) context = self.ocr_gather_head(feats, out_aux) feats = self.ocr_distri_head(feats, context) out = self.cls_head(feats) return [out, out_aux] def compute_hrnet_feats(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) x = self.layer1(x) x_list = [] for i in range(self.stage2_num_branches): if self.transition1[i] is not None: x_list.append(self.transition1[i](x)) else: x_list.append(x) 
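# At this point transition1 has adapted the single stage-1 output into the
# two parallel branches (full and half resolution) that stage 2 fuses;
# transition2 and transition3 below grow the branch count to 3 and then 4
# in the same fashion, building each new, coarser branch from the last one.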
y_list = self.stage2(x_list) x_list = [] for i in range(self.stage3_num_branches): if self.transition2[i] is not None: if i < self.stage2_num_branches: x_list.append(self.transition2[i](y_list[i])) else: x_list.append(self.transition2[i](y_list[-1])) else: x_list.append(y_list[i]) y_list = self.stage3(x_list) x_list = [] for i in range(self.stage4_num_branches): if self.transition3[i] is not None: if i < self.stage3_num_branches: x_list.append(self.transition3[i](y_list[i])) else: x_list.append(self.transition3[i](y_list[-1])) else: x_list.append(y_list[i]) x = self.stage4(x_list) # Upsampling x0_h, x0_w = x[0].size(2), x[0].size(3) x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=self.align_corners) x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=self.align_corners) x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=self.align_corners) return torch.cat([x[0], x1, x2, x3], 1) def load_pretrained_weights(self, pretrained_path=''): model_dict = self.state_dict() if not os.path.exists(pretrained_path): print(f'\nFile "{pretrained_path}" does not exist.') print('You need to specify the correct path to the pre-trained weights.\n' 'You can download the weights for HRNet from the repository:\n' 'https://github.com/HRNet/HRNet-Image-Classification') exit(1) pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'}) pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in pretrained_dict.items()} print('model_dict-pretrained_dict:', sorted(list(set(model_dict) - set(pretrained_dict)))) print('pretrained_dict-model_dict:', sorted(list(set(pretrained_dict) - set(model_dict)))) pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} model_dict.update(pretrained_dict) self.load_state_dict(model_dict) ================================================ FILE: XMem/inference/interact/fbrs/model/modeling/ocr.py ================================================ import torch import torch.nn as nn import torch._utils import torch.nn.functional as F class SpatialGather_Module(nn.Module): """ Aggregate the context features according to the initial predicted probability distribution. Employ the soft-weighted method to aggregate the context. """ def __init__(self, cls_num=0, scale=1): super(SpatialGather_Module, self).__init__() self.cls_num = cls_num self.scale = scale def forward(self, feats, probs): batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) probs = probs.view(batch_size, c, -1) feats = feats.view(batch_size, feats.size(1), -1) feats = feats.permute(0, 2, 1) # batch x hw x c probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw ocr_context = torch.matmul(probs, feats) \ .permute(0, 2, 1).unsqueeze(3) # batch x k x c return ocr_context class SpatialOCR_Module(nn.Module): """ Implementation of the OCR module: We aggregate the global object representation to update the representation for each pixel. 
""" def __init__(self, in_channels, key_channels, out_channels, scale=1, dropout=0.1, norm_layer=nn.BatchNorm2d, align_corners=True): super(SpatialOCR_Module, self).__init__() self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale, norm_layer, align_corners) _in_channels = 2 * in_channels self.conv_bn_dropout = nn.Sequential( nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False), nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)), nn.Dropout2d(dropout) ) def forward(self, feats, proxy_feats): context = self.object_context_block(feats, proxy_feats) output = self.conv_bn_dropout(torch.cat([context, feats], 1)) return output class ObjectAttentionBlock2D(nn.Module): ''' The basic implementation for object context block Input: N X C X H X W Parameters: in_channels : the dimension of the input feature map key_channels : the dimension after the key/query transform scale : choose the scale to downsample the input feature maps (save memory cost) bn_type : specify the bn type Return: N X C X H X W ''' def __init__(self, in_channels, key_channels, scale=1, norm_layer=nn.BatchNorm2d, align_corners=True): super(ObjectAttentionBlock2D, self).__init__() self.scale = scale self.in_channels = in_channels self.key_channels = key_channels self.align_corners = align_corners self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) self.f_pixel = nn.Sequential( nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, kernel_size=1, stride=1, padding=0, bias=False), nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, kernel_size=1, stride=1, padding=0, bias=False), nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) ) self.f_object = nn.Sequential( nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, kernel_size=1, stride=1, padding=0, bias=False), nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, kernel_size=1, stride=1, padding=0, bias=False), nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) ) self.f_down = nn.Sequential( nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, kernel_size=1, stride=1, padding=0, bias=False), nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) ) self.f_up = nn.Sequential( nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0, bias=False), nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True)) ) def forward(self, x, proxy): batch_size, h, w = x.size(0), x.size(2), x.size(3) if self.scale > 1: x = self.pool(x) query = self.f_pixel(x).view(batch_size, self.key_channels, -1) query = query.permute(0, 2, 1) key = self.f_object(proxy).view(batch_size, self.key_channels, -1) value = self.f_down(proxy).view(batch_size, self.key_channels, -1) value = value.permute(0, 2, 1) sim_map = torch.matmul(query, key) sim_map = (self.key_channels ** -.5) * sim_map sim_map = F.softmax(sim_map, dim=-1) # add bg context ... 
context = torch.matmul(sim_map, value) context = context.permute(0, 2, 1).contiguous() context = context.view(batch_size, self.key_channels, *x.size()[2:]) context = self.f_up(context) if self.scale > 1: context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=self.align_corners) return context ================================================ FILE: XMem/inference/interact/fbrs/model/modeling/resnet.py ================================================ import torch from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s class ResNetBackbone(torch.nn.Module): def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs): super(ResNetBackbone, self).__init__() if backbone == 'resnet34': pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) elif backbone == 'resnet50': pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) elif backbone == 'resnet101': pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) elif backbone == 'resnet152': pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) else: raise RuntimeError(f'unknown backbone: {backbone}') self.conv1 = pretrained.conv1 self.bn1 = pretrained.bn1 self.relu = pretrained.relu self.maxpool = pretrained.maxpool self.layer1 = pretrained.layer1 self.layer2 = pretrained.layer2 self.layer3 = pretrained.layer3 self.layer4 = pretrained.layer4 def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) c1 = self.layer1(x) c2 = self.layer2(c1) c3 = self.layer3(c2) c4 = self.layer4(c3) return c1, c2, c3, c4 ================================================ FILE: XMem/inference/interact/fbrs/model/modeling/resnetv1b.py ================================================ import torch import torch.nn as nn GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet' class BasicBlockV1b(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, previous_dilation=1, norm_layer=nn.BatchNorm2d): super(BasicBlockV1b, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False) self.bn1 = norm_layer(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=previous_dilation, dilation=previous_dilation, bias=False) self.bn2 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) out = out + residual out = self.relu(out) return out class BottleneckV1b(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, previous_dilation=1, norm_layer=nn.BatchNorm2d): super(BottleneckV1b, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = norm_layer(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False) self.bn2 = norm_layer(planes) self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = 
self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out = out + residual out = self.relu(out) return out class ResNetV1b(nn.Module): """ Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5. Parameters ---------- block : Block Class for the residual block. Options are BasicBlockV1, BottleneckV1. layers : list of int Numbers of layers in each block classes : int, default 1000 Number of classification classes. dilated : bool, default False Applying dilation strategy to pretrained ResNet yielding a stride-8 model, typically used in Semantic Segmentation. norm_layer : object Normalization layer used (default: :class:`nn.BatchNorm2d`) deep_stem : bool, default False Whether to replace the 7x7 conv1 with 3 3x3 convolution layers. avg_down : bool, default False Whether to use average pooling for projection skip connection between stages/downsample. final_drop : float, default 0.0 Dropout ratio before the final classification layer. Reference: - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." """ def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32, avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d): self.inplanes = stem_width*2 if deep_stem else 64 super(ResNetV1b, self).__init__() if not deep_stem: self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) else: self.conv1 = nn.Sequential( nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False), norm_layer(stem_width), nn.ReLU(True), nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False), norm_layer(stem_width), nn.ReLU(True), nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False) ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(True) self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down, norm_layer=norm_layer) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down, norm_layer=norm_layer) if dilated: self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, avg_down=avg_down, norm_layer=norm_layer) self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, avg_down=avg_down, norm_layer=norm_layer) else: self.layer3 = self._make_layer(block, 256, layers[2], stride=2, avg_down=avg_down, norm_layer=norm_layer) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, avg_down=avg_down, norm_layer=norm_layer) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.drop = None if final_drop > 0.0: self.drop = nn.Dropout(final_drop) self.fc = nn.Linear(512 * block.expansion, classes) def _make_layer(self, block, planes, blocks, stride=1, dilation=1, avg_down=False, norm_layer=nn.BatchNorm2d): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = [] if avg_down: if dilation == 1: downsample.append( nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False) ) else: downsample.append( nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False) ) downsample.extend([ nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, kernel_size=1, 
stride=1, bias=False), norm_layer(planes * block.expansion) ]) downsample = nn.Sequential(*downsample) else: downsample = nn.Sequential( nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, kernel_size=1, stride=stride, bias=False), norm_layer(planes * block.expansion) ) layers = [] if dilation in (1, 2): layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) elif dilation == 4: layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) else: raise RuntimeError("=> unknown dilation size: {}".format(dilation)) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append(block(self.inplanes, planes, dilation=dilation, previous_dilation=dilation, norm_layer=norm_layer)) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = x.view(x.size(0), -1) if self.drop is not None: x = self.drop(x) x = self.fc(x) return x def _safe_state_dict_filtering(orig_dict, model_dict_keys): filtered_orig_dict = {} for k, v in orig_dict.items(): if k in model_dict_keys: filtered_orig_dict[k] = v else: print(f"[ERROR] Failed to load <{k}> in backbone") return filtered_orig_dict def resnet34_v1b(pretrained=False, **kwargs): model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs) if pretrained: model_dict = model.state_dict() filtered_orig_dict = _safe_state_dict_filtering( torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(), model_dict.keys() ) model_dict.update(filtered_orig_dict) model.load_state_dict(model_dict) return model def resnet50_v1s(pretrained=False, **kwargs): model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs) if pretrained: model_dict = model.state_dict() filtered_orig_dict = _safe_state_dict_filtering( torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(), model_dict.keys() ) model_dict.update(filtered_orig_dict) model.load_state_dict(model_dict) return model def resnet101_v1s(pretrained=False, **kwargs): model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs) if pretrained: model_dict = model.state_dict() filtered_orig_dict = _safe_state_dict_filtering( torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(), model_dict.keys() ) model_dict.update(filtered_orig_dict) model.load_state_dict(model_dict) return model def resnet152_v1s(pretrained=False, **kwargs): model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs) if pretrained: model_dict = model.state_dict() filtered_orig_dict = _safe_state_dict_filtering( torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(), model_dict.keys() ) model_dict.update(filtered_orig_dict) model.load_state_dict(model_dict) return model ================================================ FILE: XMem/inference/interact/fbrs/model/ops.py ================================================ import torch from torch import nn as nn import numpy as np from . 
import initializer as initializer from ..utils.cython import get_dist_maps def select_activation_function(activation): if isinstance(activation, str): if activation.lower() == 'relu': return nn.ReLU elif activation.lower() == 'softplus': return nn.Softplus else: raise ValueError(f"Unknown activation type {activation}") elif isinstance(activation, nn.Module): return activation else: raise ValueError(f"Unknown activation type {activation}") class BilinearConvTranspose2d(nn.ConvTranspose2d): def __init__(self, in_channels, out_channels, scale, groups=1): kernel_size = 2 * scale - scale % 2 self.scale = scale super().__init__( in_channels, out_channels, kernel_size=kernel_size, stride=scale, padding=1, groups=groups, bias=False) self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)) class DistMaps(nn.Module): def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False): super(DistMaps, self).__init__() self.spatial_scale = spatial_scale self.norm_radius = norm_radius self.cpu_mode = cpu_mode def get_coord_features(self, points, batchsize, rows, cols): if self.cpu_mode: coords = [] for i in range(batchsize): norm_delimeter = self.spatial_scale * self.norm_radius coords.append(get_dist_maps(points[i].cpu().float().numpy(), rows, cols, norm_delimeter)) coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() else: num_points = points.shape[1] // 2 points = points.view(-1, 2) invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) coord_rows, coord_cols = torch.meshgrid(row_array, col_array) coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1) coords.add_(-add_xy) coords.div_(self.norm_radius * self.spatial_scale) coords.mul_(coords) coords[:, 0] += coords[:, 1] coords = coords[:, :1] coords[invalid_points, :, :, :] = 1e6 coords = coords.view(-1, num_points, 1, rows, cols) coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w coords = coords.view(-1, 2, rows, cols) coords.sqrt_().mul_(2).tanh_() return coords def forward(self, x, coords): return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3]) ================================================ FILE: XMem/inference/interact/fbrs/model/syncbn/LICENSE ================================================ MIT License Copyright (c) 2018 Tamaki Kojima Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: XMem/inference/interact/fbrs/model/syncbn/README.md
================================================
# pytorch-syncbn

Tamaki Kojima (tamakoji@gmail.com)

## Announcement

**Pytorch 1.0 support**

## Overview

This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization" which computes global statistics across GPUs instead of locally computed ones. SyncBN matters when the input images are large and multiple GPUs must be used to reach a reasonable minibatch size for training.

The code was inspired by [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) and [Inplace-ABN](https://github.com/mapillary/inplace_abn).

## Remarks

- Unlike [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), you don't need a custom `nn.DataParallel`.
- Unlike [Inplace-ABN](https://github.com/mapillary/inplace_abn), you can simply replace your `nn.BatchNorm2d` with this module, since it does not mark tensors for inplace operation.
- You can plug it into any module written in PyTorch to enable Synchronized BatchNorm.
- Backward computation is rewritten and tested against the behavior of `nn.BatchNorm2d`.

## Requirements

For PyTorch, please refer to https://pytorch.org/

NOTE: The code is tested only with PyTorch v1.0.0, CUDA 10 / cuDNN 7.4.2 on Ubuntu 18.04.

It utilizes the PyTorch JIT mechanism to compile seamlessly, using ninja. Please install ninja-build before use.

```
sudo apt-get install ninja-build
```

Also install all Python dependencies. For pip, run:

```
pip install -U -r requirements.txt
```

## Build

There is no need to build; just run, and the JIT will take care of it. JIT and cpp extensions are supported after PyTorch 0.4, however it is highly recommended to use PyTorch > 1.0 due to huge design changes.

## Usage

Please refer to [`test.py`](./test.py) for testing the difference between `nn.BatchNorm2d` and `modules.nn.BatchNorm2d`.

```
import torch
import torch.nn as nn
from modules import nn as NN

num_gpu = torch.cuda.device_count()
model = nn.Sequential(
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
    nn.ReLU(inplace=True),
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
).cuda()
model = nn.DataParallel(model, device_ids=range(num_gpu))
x = torch.rand(num_gpu, 3, 2, 2).cuda()
z = model(x)
```

## Math

### Forward

1. Compute $\sum_i x_i$ and $\sum_i x_i^2$ in each GPU.
2. Gather both sums from all workers to the master and compute the global statistics $\mu = \frac{1}{N}\sum_i x_i$ and $\sigma^2 = \frac{1}{N}\sum_i x_i^2 - \mu^2$, where $N$ is the total number of elements per channel; the global statistics are then shared back to all GPUs, and `running_mean` and `running_var` are updated by a moving average using these global statistics.
3. Forward batchnorm using the global statistics: $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$ and then $y_i = \gamma \hat{x}_i + \beta$, where $\gamma$ is the weight parameter and $\beta$ is the bias parameter.
4. Save $\hat{x}_i$ for backward.

### Backward

1. Restore the saved $\hat{x}_i$.
2. Compute the sums $\sum_i \frac{\partial L}{\partial y_i}$ and $\sum_i \frac{\partial L}{\partial y_i}\hat{x}_i$ on each GPU, where $\frac{\partial L}{\partial y_i}$ is the incoming gradient; then gather them at the master node to sum up globally and normalize by $N$, where $N$ is the total number of elements per channel. The global sums are then shared among all GPUs.
3. Compute the gradients using the global statistics: $\frac{\partial L}{\partial \gamma} = \sum_i \frac{\partial L}{\partial y_i}\hat{x}_i$ and $\frac{\partial L}{\partial \beta} = \sum_i \frac{\partial L}{\partial y_i}$, and finally

   $\frac{\partial L}{\partial x_i} = \frac{\gamma}{\sqrt{\sigma^2+\epsilon}}\left(\frac{\partial L}{\partial y_i} - \frac{1}{N}\sum_j \frac{\partial L}{\partial y_j} - \frac{\hat{x}_i}{N}\sum_j \frac{\partial L}{\partial y_j}\,\hat{x}_j\right).$

   Note that in the implementation the normalization by $N$ is performed at step (2), so the equation above and the implementation are not written identically, but they are mathematically the same.
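To make the forward math concrete, here is a minimal single-process sketch in plain PyTorch (an illustration only: the function name is ours, the input is assumed NCHW, the real module computes the two sums per GPU with the CUDA kernels below and all-reduces them before step 2, and the running-stats update is simplified to the biased variance):

```
import torch

def syncbn_forward_reference(x, weight, bias, running_mean, running_var,
                             momentum=0.1, eps=1e-5):
    n = x.numel() // x.size(1)           # N: total elements per channel
    # Step 1: per-channel sum(x) and sum(x^2)
    xsum = x.sum(dim=(0, 2, 3))
    xsqsum = (x * x).sum(dim=(0, 2, 3))
    # Step 2: global mean/var from the (gathered) sums; update running stats
    mean = xsum / n
    var = xsqsum / n - mean * mean
    running_mean.mul_(1 - momentum).add_(momentum * mean)
    running_var.mul_(1 - momentum).add_(momentum * var)
    # Step 3: normalize with the global stats, then apply the affine transform
    xhat = (x - mean[None, :, None, None]) / torch.sqrt(var + eps)[None, :, None, None]
    return weight[None, :, None, None] * xhat + bias[None, :, None, None]
```

In the synchronized version only step 1 and the two gradient sums of the backward pass need cross-GPU communication, which is why the extension below exposes them as separate operations.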
================================================
FILE: XMem/inference/interact/fbrs/model/syncbn/__init__.py
================================================


================================================
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/__init__.py
================================================


================================================
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/functional/__init__.py
================================================
from .syncbn import batchnorm2d_sync


================================================
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py
================================================
"""
/*****************************************************************************/

Extension module loader

code referenced from : https://github.com/facebookresearch/maskrcnn-benchmark

/*****************************************************************************/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import os.path

import torch

try:
    from torch.utils.cpp_extension import load
    from torch.utils.cpp_extension import CUDA_HOME
except ImportError:
    raise ImportError("The cpp layer extensions requires PyTorch 0.4 or higher")


def _load_C_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    this_dir = os.path.join(this_dir, "csrc")

    main_file = glob.glob(os.path.join(this_dir, "*.cpp"))
    sources_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp"))
    sources_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu"))

    sources = main_file + sources_cpu
    extra_cflags = []
    extra_cuda_cflags = []
    if torch.cuda.is_available() and CUDA_HOME is not None:
        sources.extend(sources_cuda)
        extra_cflags = ["-O3", "-DWITH_CUDA"]
        extra_cuda_cflags = ["--expt-extended-lambda"]
    sources = [os.path.join(this_dir, s) for s in sources]
    extra_include_paths = [this_dir]
    return load(
        name="ext_lib",
        sources=sources,
        extra_cflags=extra_cflags,
        extra_include_paths=extra_include_paths,
        extra_cuda_cflags=extra_cuda_cflags,
    )

_backend = _load_C_extensions()


================================================
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/bn.h
================================================
/*****************************************************************************
SyncBN
*****************************************************************************/
#pragma once

#ifdef WITH_CUDA
#include "cuda/ext_lib.h"
#endif

/// SyncBN

std::vector<at::Tensor> syncbn_sum_sqsum(const at::Tensor& x) {
  if (x.is_cuda()) {
#ifdef WITH_CUDA
    return syncbn_sum_sqsum_cuda(x);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  } else {
    AT_ERROR("CPU implementation not supported");
  }
}

at::Tensor syncbn_forward(const at::Tensor& x, const at::Tensor& weight,
                          const at::Tensor& bias, const at::Tensor& mean,
                          const at::Tensor& var, bool affine, float eps) {
  if (x.is_cuda()) {
#ifdef WITH_CUDA
    return syncbn_forward_cuda(x, weight, bias, mean, var, affine, eps);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  } else {
    AT_ERROR("CPU implementation not supported");
  }
}

std::vector<at::Tensor> syncbn_backward_xhat(const at::Tensor& dz,
                                             const at::Tensor& x,
                                             const at::Tensor& mean,
                                             const at::Tensor& var, float eps) {
  if (dz.is_cuda()) {
#ifdef WITH_CUDA
    return syncbn_backward_xhat_cuda(dz, x, mean, var, eps);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  } else {
    AT_ERROR("CPU implementation not supported");
  }
}

std::vector<at::Tensor> syncbn_backward(
    const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight,
    const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var,
    const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine,
    float eps) {
  if (dz.is_cuda()) {
#ifdef WITH_CUDA
    return syncbn_backward_cuda(dz, x, weight, bias, mean, var, sum_dz,
                                sum_dz_xhat, affine, eps);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  } else {
    AT_ERROR("CPU implementation not supported");
  }
}


================================================
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu
================================================
/*****************************************************************************
CUDA SyncBN code

code referenced from : https://github.com/mapillary/inplace_abn
*****************************************************************************/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

#include "cuda/common.h"

// Utilities
void get_dims(at::Tensor x, int64_t &num, int64_t &chn, int64_t &sp) {
  num = x.size(0);
  chn = x.size(1);
  sp = 1;
  for (int64_t i = 2; i < x.ndimension(); ++i) sp *= x.size(i);
}

/// SyncBN

template <typename T>
struct SqSumOp {
  __device__ SqSumOp(const T *t, int c, int s) : tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
    T x = tensor[(batch * chn + plane) * sp + n];
    return Pair<T>(x, x * x);  // x, x^2
  }
  const T *tensor;
  const int chn;
  const int sp;
};

template <typename T>
__global__ void syncbn_sum_sqsum_kernel(const T *x, T *sum, T *sqsum,
                                        int num, int chn, int sp) {
  int plane = blockIdx.x;
  Pair<T> res = reduce<Pair<T>, SqSumOp<T>>(SqSumOp<T>(x, chn, sp), plane,
                                            num, chn, sp);
  __syncthreads();
  if (threadIdx.x == 0) {
    sum[plane] = res.v1;
    sqsum[plane] = res.v2;
  }
}

std::vector<at::Tensor> syncbn_sum_sqsum_cuda(const at::Tensor &x) {
  CHECK_INPUT(x);
  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);
  // Prepare output tensors
  auto sum = at::empty({chn}, x.options());
  auto sqsum = at::empty({chn}, x.options());
  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  AT_DISPATCH_FLOATING_TYPES(
      x.type(), "syncbn_sum_sqsum_cuda", ([&] {
        syncbn_sum_sqsum_kernel<scalar_t><<<blocks, threads>>>(
            x.data<scalar_t>(), sum.data<scalar_t>(), sqsum.data<scalar_t>(),
            num, chn, sp);
      }));
  return {sum, sqsum};
}

template <typename T>
__global__ void syncbn_forward_kernel(T *z, const T *x, const T *weight,
                                      const T *bias, const T *mean,
                                      const T *var, bool affine, float eps,
                                      int num, int chn, int sp) {
  int plane = blockIdx.x;
  T _mean = mean[plane];
  T _var = var[plane];
  T _weight = affine ? weight[plane] : T(1);
  T _bias = affine ? bias[plane] : T(0);
  float _invstd = T(0);
  if (_var || eps) {
    _invstd = rsqrt(_var + eps);
  }
  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      T _x = x[(batch * chn + plane) * sp + n];
      T _xhat = (_x - _mean) * _invstd;
      T _z = _xhat * _weight + _bias;
      z[(batch * chn + plane) * sp + n] = _z;
    }
  }
}

at::Tensor syncbn_forward_cuda(const at::Tensor &x, const at::Tensor &weight,
                               const at::Tensor &bias, const at::Tensor &mean,
                               const at::Tensor &var, bool affine, float eps) {
  CHECK_INPUT(x);
  CHECK_INPUT(weight);
  CHECK_INPUT(bias);
  CHECK_INPUT(mean);
  CHECK_INPUT(var);
  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);
  auto z = at::zeros_like(x);
  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  AT_DISPATCH_FLOATING_TYPES(
      x.type(), "syncbn_forward_cuda", ([&] {
        syncbn_forward_kernel<scalar_t><<<blocks, threads>>>(
            z.data<scalar_t>(), x.data<scalar_t>(), weight.data<scalar_t>(),
            bias.data<scalar_t>(), mean.data<scalar_t>(), var.data<scalar_t>(),
            affine, eps, num, chn, sp);
      }));
  return z;
}

template <typename T>
struct XHatOp {
  __device__ XHatOp(T _weight, T _bias, const T *_dz, const T *_x, int c, int s)
      : weight(_weight), bias(_bias), x(_x), dz(_dz), chn(c), sp(s) {}
  __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
    // xhat = (x - bias) * weight
    T _xhat = (x[(batch * chn + plane) * sp + n] - bias) * weight;
    // dxhat * x_hat
    T _dz = dz[(batch * chn + plane) * sp + n];
    return Pair<T>(_dz, _dz * _xhat);
  }
  const T weight;
  const T bias;
  const T *dz;
  const T *x;
  const int chn;
  const int sp;
};

template <typename T>
__global__ void syncbn_backward_xhat_kernel(const T *dz, const T *x,
                                            const T *mean, const T *var,
                                            T *sum_dz, T *sum_dz_xhat,
                                            float eps, int num, int chn,
                                            int sp) {
  int plane = blockIdx.x;
  T _mean = mean[plane];
  T _var = var[plane];
  T _invstd = T(0);
  if (_var || eps) {
    _invstd = rsqrt(_var + eps);
  }
  Pair<T> res = reduce<Pair<T>, XHatOp<T>>(
      XHatOp<T>(_invstd, _mean, dz, x, chn, sp), plane, num, chn, sp);
  __syncthreads();
  if (threadIdx.x == 0) {
    // \sum(\frac{dJ}{dy_i})
    sum_dz[plane] = res.v1;
    // \sum(\frac{dJ}{dy_i}*\hat{x_i})
    sum_dz_xhat[plane] = res.v2;
  }
}

std::vector<at::Tensor> syncbn_backward_xhat_cuda(const at::Tensor &dz,
                                                  const at::Tensor &x,
                                                  const at::Tensor &mean,
                                                  const at::Tensor &var,
                                                  float eps) {
  CHECK_INPUT(dz);
  CHECK_INPUT(x);
  CHECK_INPUT(mean);
  CHECK_INPUT(var);
  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);
  // Prepare output tensors
  auto sum_dz = at::empty({chn}, x.options());
  auto sum_dz_xhat = at::empty({chn}, x.options());
  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  AT_DISPATCH_FLOATING_TYPES(
      x.type(), "syncbn_backward_xhat_cuda", ([&] {
        syncbn_backward_xhat_kernel<scalar_t><<<blocks, threads>>>(
            dz.data<scalar_t>(), x.data<scalar_t>(), mean.data<scalar_t>(),
            var.data<scalar_t>(), sum_dz.data<scalar_t>(),
            sum_dz_xhat.data<scalar_t>(), eps, num, chn, sp);
      }));
  return {sum_dz, sum_dz_xhat};
}

template <typename T>
__global__ void syncbn_backward_kernel(const T *dz, const T *x,
                                       const T *weight, const T *bias,
                                       const T *mean, const T *var,
                                       const T *sum_dz, const T *sum_dz_xhat,
                                       T *dx, T *dweight, T *dbias,
                                       bool affine, float eps, int num,
                                       int chn, int sp) {
  int plane = blockIdx.x;
  T _mean = mean[plane];
  T _var = var[plane];
  T _weight = affine ? weight[plane] : T(1);
  T _sum_dz = sum_dz[plane];
  T _sum_dz_xhat = sum_dz_xhat[plane];
  T _invstd = T(0);
  if (_var || eps) {
    _invstd = rsqrt(_var + eps);
  }
  /*
    \frac{dJ}{dx_i} = \frac{1}{N\sqrt{(\sigma^2+\epsilon)}} (
      N\frac{dJ}{d\hat{x_i}} -
      \sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}) -
      \hat{x_i}\sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}\hat{x_j})
    )
    Note : N is omitted here since it will be accumulated and
    _sum_dz and _sum_dz_xhat expected to be already normalized
    before the call.
  */
  if (dx) {
    T _mul = _weight * _invstd;
    for (int batch = 0; batch < num; ++batch) {
      for (int n = threadIdx.x; n < sp; n += blockDim.x) {
        T _dz = dz[(batch * chn + plane) * sp + n];
        T _xhat = (x[(batch * chn + plane) * sp + n] - _mean) * _invstd;
        T _dx = (_dz - _sum_dz - _xhat * _sum_dz_xhat) * _mul;
        dx[(batch * chn + plane) * sp + n] = _dx;
      }
    }
  }
  __syncthreads();
  if (threadIdx.x == 0) {
    if (affine) {
      T _norm = num * sp;
      dweight[plane] += _sum_dz_xhat * _norm;
      dbias[plane] += _sum_dz * _norm;
    }
  }
}

std::vector<at::Tensor> syncbn_backward_cuda(
    const at::Tensor &dz, const at::Tensor &x, const at::Tensor &weight,
    const at::Tensor &bias, const at::Tensor &mean, const at::Tensor &var,
    const at::Tensor &sum_dz, const at::Tensor &sum_dz_xhat, bool affine,
    float eps) {
  CHECK_INPUT(dz);
  CHECK_INPUT(x);
  CHECK_INPUT(weight);
  CHECK_INPUT(bias);
  CHECK_INPUT(mean);
  CHECK_INPUT(var);
  CHECK_INPUT(sum_dz);
  CHECK_INPUT(sum_dz_xhat);
  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);
  // Prepare output tensors
  auto dx = at::zeros_like(dz);
  auto dweight = at::zeros_like(weight);
  auto dbias = at::zeros_like(bias);
  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  AT_DISPATCH_FLOATING_TYPES(
      x.type(), "syncbn_backward_cuda", ([&] {
        syncbn_backward_kernel<scalar_t><<<blocks, threads>>>(
            dz.data<scalar_t>(), x.data<scalar_t>(), weight.data<scalar_t>(),
            bias.data<scalar_t>(), mean.data<scalar_t>(), var.data<scalar_t>(),
            sum_dz.data<scalar_t>(), sum_dz_xhat.data<scalar_t>(),
            dx.data<scalar_t>(), dweight.data<scalar_t>(),
            dbias.data<scalar_t>(), affine, eps, num, chn, sp);
      }));
  return {dx, dweight, dbias};
}


================================================
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h
================================================
/*****************************************************************************
CUDA utility funcs

code referenced from : https://github.com/mapillary/inplace_abn
*****************************************************************************/
#pragma once

#include <ATen/ATen.h>

// Checks
#ifndef AT_CHECK
#define AT_CHECK AT_ASSERT
#endif
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

/*
 * General settings
 */
const int WARP_SIZE = 32;
const int MAX_BLOCK_SIZE = 512;

template <typename T>
struct Pair {
  T v1, v2;
  __device__ Pair() {}
  __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
  __device__ Pair(T v) : v1(v), v2(v) {}
  __device__ Pair(int v) : v1(v), v2(v) {}
  __device__ Pair &operator+=(const Pair &a) {
    v1 += a.v1;
    v2 += a.v2;
    return *this;
  }
};

/*
 * Utility functions
 */
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask,
                                           int width = warpSize,
                                           unsigned int mask = 0xffffffff) {
#if CUDART_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}

__device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }

static int getNumThreads(int nElem) {
  int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}

template <typename T>
static __device__ __forceinline__ T warpSum(T val) {
#if __CUDA_ARCH__ >= 300
  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
    val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
  }
#else
  __shared__ T values[MAX_BLOCK_SIZE];
  values[threadIdx.x] = val;
  __threadfence_block();
  const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
  for (int i = 1; i < WARP_SIZE; i++) {
    val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
  }
#endif
  return val;
}

template <typename T>
static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
  value.v1 = warpSum(value.v1);
  value.v2 = warpSum(value.v2);
  return value;
}

template <typename T, typename Op>
__device__ T reduce(Op op, int plane, int N, int C, int S) {
  T sum = (T)0;
  for (int batch = 0; batch < N; ++batch) {
    for (int x = threadIdx.x; x < S; x += blockDim.x) {
      sum += op(batch, plane, x);
    }
  }
  // sum over NumThreads within a warp
  sum = warpSum(sum);
  // 'transpose', and reduce within warp again
  __shared__ T shared[32];
  __syncthreads();
  if (threadIdx.x % WARP_SIZE == 0) {
    shared[threadIdx.x / WARP_SIZE] = sum;
  }
  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
    // zero out the other entries in shared
    shared[threadIdx.x] = (T)0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();
  // Everyone picks it up, should be broadcast into the whole gradInput
  return shared[0];
}


================================================
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h
================================================
/*****************************************************************************
CUDA SyncBN code
*****************************************************************************/
#pragma once

#include <ATen/ATen.h>
#include <vector>

/// Sync-BN
std::vector<at::Tensor> syncbn_sum_sqsum_cuda(const at::Tensor& x);
at::Tensor syncbn_forward_cuda(const at::Tensor& x, const at::Tensor& weight,
                               const at::Tensor& bias, const at::Tensor& mean,
                               const at::Tensor& var, bool affine, float eps);
std::vector<at::Tensor> syncbn_backward_xhat_cuda(const at::Tensor& dz,
                                                  const at::Tensor& x,
                                                  const at::Tensor& mean,
                                                  const at::Tensor& var,
                                                  float eps);
std::vector<at::Tensor> syncbn_backward_cuda(
    const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight,
    const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var,
    const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine,
    float eps);


================================================
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp
================================================
#include "bn.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("syncbn_sum_sqsum", &syncbn_sum_sqsum, "Sum and Sum^2 computation");
  m.def("syncbn_forward", &syncbn_forward, "SyncBN forward computation");
  m.def("syncbn_backward_xhat", &syncbn_backward_xhat,
        "First part of SyncBN backward computation");
  m.def("syncbn_backward", &syncbn_backward,
        "Second part of SyncBN backward computation");
}


================================================
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/functional/syncbn.py
================================================
"""
/*****************************************************************************/

BatchNorm2dSync with multi-gpu

code referenced from : https://github.com/mapillary/inplace_abn

/*****************************************************************************/
"""
from __future__ import
absolute_import from __future__ import division from __future__ import print_function import torch.cuda.comm as comm from torch.autograd import Function from torch.autograd.function import once_differentiable from ._csrc import _backend def _count_samples(x): count = 1 for i, s in enumerate(x.size()): if i != 1: count *= s return count class BatchNorm2dSyncFunc(Function): @staticmethod def forward(ctx, x, weight, bias, running_mean, running_var, extra, compute_stats=True, momentum=0.1, eps=1e-05): def _parse_extra(ctx, extra): ctx.is_master = extra["is_master"] if ctx.is_master: ctx.master_queue = extra["master_queue"] ctx.worker_queues = extra["worker_queues"] ctx.worker_ids = extra["worker_ids"] else: ctx.master_queue = extra["master_queue"] ctx.worker_queue = extra["worker_queue"] # Save context if extra is not None: _parse_extra(ctx, extra) ctx.compute_stats = compute_stats ctx.momentum = momentum ctx.eps = eps ctx.affine = weight is not None and bias is not None if ctx.compute_stats: N = _count_samples(x) * (ctx.master_queue.maxsize + 1) assert N > 1 # 1. compute sum(x) and sum(x^2) xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach()) if ctx.is_master: xsums, xsqsums = [xsum], [xsqsum] # master : gatther all sum(x) and sum(x^2) from slaves for _ in range(ctx.master_queue.maxsize): xsum_w, xsqsum_w = ctx.master_queue.get() ctx.master_queue.task_done() xsums.append(xsum_w) xsqsums.append(xsqsum_w) xsum = comm.reduce_add(xsums) xsqsum = comm.reduce_add(xsqsums) mean = xsum / N sumvar = xsqsum - xsum * mean var = sumvar / N uvar = sumvar / (N - 1) # master : broadcast global mean, variance to all slaves tensors = comm.broadcast_coalesced( (mean, uvar, var), [mean.get_device()] + ctx.worker_ids) for ts, queue in zip(tensors[1:], ctx.worker_queues): queue.put(ts) else: # slave : send sum(x) and sum(x^2) to master ctx.master_queue.put((xsum, xsqsum)) # slave : get global mean and variance mean, uvar, var = ctx.worker_queue.get() ctx.worker_queue.task_done() # Update running stats running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar) ctx.N = N ctx.save_for_backward(x, weight, bias, mean, var) else: mean, var = running_mean, running_var # do batch norm forward z = _backend.syncbn_forward(x, weight, bias, mean, var, ctx.affine, ctx.eps) return z @staticmethod @once_differentiable def backward(ctx, dz): x, weight, bias, mean, var = ctx.saved_tensors dz = dz.contiguous() # 1. 
compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i}) sum_dz, sum_dz_xhat = _backend.syncbn_backward_xhat( dz, x, mean, var, ctx.eps) if ctx.is_master: sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat] # master : gatther from slaves for _ in range(ctx.master_queue.maxsize): sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get() ctx.master_queue.task_done() sum_dzs.append(sum_dz_w) sum_dz_xhats.append(sum_dz_xhat_w) # master : compute global stats sum_dz = comm.reduce_add(sum_dzs) sum_dz_xhat = comm.reduce_add(sum_dz_xhats) sum_dz /= ctx.N sum_dz_xhat /= ctx.N # master : broadcast global stats tensors = comm.broadcast_coalesced( (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids) for ts, queue in zip(tensors[1:], ctx.worker_queues): queue.put(ts) else: # slave : send to master ctx.master_queue.put((sum_dz, sum_dz_xhat)) # slave : get global stats sum_dz, sum_dz_xhat = ctx.worker_queue.get() ctx.worker_queue.task_done() # do batch norm backward dx, dweight, dbias = _backend.syncbn_backward( dz, x, weight, bias, mean, var, sum_dz, sum_dz_xhat, ctx.affine, ctx.eps) return dx, dweight, dbias, \ None, None, None, None, None, None batchnorm2d_sync = BatchNorm2dSyncFunc.apply __all__ = ["batchnorm2d_sync"] ================================================ FILE: XMem/inference/interact/fbrs/model/syncbn/modules/nn/__init__.py ================================================ from .syncbn import * ================================================ FILE: XMem/inference/interact/fbrs/model/syncbn/modules/nn/syncbn.py ================================================ """ /*****************************************************************************/ BatchNorm2dSync with multi-gpu /*****************************************************************************/ """ from __future__ import absolute_import from __future__ import division from __future__ import print_function try: # python 3 from queue import Queue except ImportError: # python 2 from Queue import Queue import torch import torch.nn as nn from torch.nn import functional as F from torch.nn.parameter import Parameter from isegm.model.syncbn.modules.functional import batchnorm2d_sync class _BatchNorm(nn.Module): """ Customized BatchNorm from nn.BatchNorm >> added freeze attribute to enable bn freeze. 
""" def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): super(_BatchNorm, self).__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats self.freezed = False if self.affine: self.weight = Parameter(torch.Tensor(num_features)) self.bias = Parameter(torch.Tensor(num_features)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) if self.track_running_stats: self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) else: self.register_parameter('running_mean', None) self.register_parameter('running_var', None) self.reset_parameters() def reset_parameters(self): if self.track_running_stats: self.running_mean.zero_() self.running_var.fill_(1) if self.affine: self.weight.data.uniform_() self.bias.data.zero_() def _check_input_dim(self, input): return NotImplemented def forward(self, input): self._check_input_dim(input) compute_stats = not self.freezed and \ self.training and self.track_running_stats ret = F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, compute_stats, self.momentum, self.eps) return ret def extra_repr(self): return '{num_features}, eps={eps}, momentum={momentum}, '\ 'affine={affine}, ' \ 'track_running_stats={track_running_stats}'.format( **self.__dict__) class BatchNorm2dNoSync(_BatchNorm): """ Equivalent to nn.BatchNorm2d """ def _check_input_dim(self, input): if input.dim() != 4: raise ValueError('expected 4D input (got {}D input)' .format(input.dim())) class BatchNorm2dSync(BatchNorm2dNoSync): """ BatchNorm2d with automatic multi-GPU Sync """ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): super(BatchNorm2dSync, self).__init__( num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) self.sync_enabled = True self.devices = list(range(torch.cuda.device_count())) if len(self.devices) > 1: # Initialize queues self.worker_ids = self.devices[1:] self.master_queue = Queue(len(self.worker_ids)) self.worker_queues = [Queue(1) for _ in self.worker_ids] def forward(self, x): compute_stats = not self.freezed and \ self.training and self.track_running_stats if self.sync_enabled and compute_stats and len(self.devices) > 1: if x.get_device() == self.devices[0]: # Master mode extra = { "is_master": True, "master_queue": self.master_queue, "worker_queues": self.worker_queues, "worker_ids": self.worker_ids } else: # Worker mode extra = { "is_master": False, "master_queue": self.master_queue, "worker_queue": self.worker_queues[ self.worker_ids.index(x.get_device())] } return batchnorm2d_sync(x, self.weight, self.bias, self.running_mean, self.running_var, extra, compute_stats, self.momentum, self.eps) return super(BatchNorm2dSync, self).forward(x) def __repr__(self): """repr""" rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 'affine={affine}, ' \ 'track_running_stats={track_running_stats},' \ 'devices={devices})' return rep.format(name=self.__class__.__name__, **self.__dict__) #BatchNorm2d = BatchNorm2dNoSync BatchNorm2d = BatchNorm2dSync ================================================ FILE: XMem/inference/interact/fbrs/utils/__init__.py ================================================ ================================================ FILE: XMem/inference/interact/fbrs/utils/cython/__init__.py 
================================================ # noinspection PyUnresolvedReferences from .dist_maps import get_dist_maps ================================================ FILE: XMem/inference/interact/fbrs/utils/cython/_get_dist_maps.pyx ================================================ import numpy as np cimport cython cimport numpy as np from libc.stdlib cimport malloc, free ctypedef struct qnode: int row int col int layer int orig_row int orig_col @cython.infer_types(True) @cython.boundscheck(False) @cython.wraparound(False) @cython.nonecheck(False) def get_dist_maps(np.ndarray[np.float32_t, ndim=2, mode="c"] points, int height, int width, float norm_delimeter): cdef np.ndarray[np.float32_t, ndim=3, mode="c"] dist_maps = \ np.full((2, height, width), 1e6, dtype=np.float32, order="C") cdef int *dxy = [-1, 0, 0, -1, 0, 1, 1, 0] cdef int i, j, x, y, dx, dy cdef qnode v cdef qnode *q = malloc((4 * height * width + 1) * sizeof(qnode)) cdef int qhead = 0, qtail = -1 cdef float ndist for i in range(points.shape[0]): x, y = round(points[i, 0]), round(points[i, 1]) if x >= 0: qtail += 1 q[qtail].row = x q[qtail].col = y q[qtail].orig_row = x q[qtail].orig_col = y if i >= points.shape[0] / 2: q[qtail].layer = 1 else: q[qtail].layer = 0 dist_maps[q[qtail].layer, x, y] = 0 while qtail - qhead + 1 > 0: v = q[qhead] qhead += 1 for k in range(4): x = v.row + dxy[2 * k] y = v.col + dxy[2 * k + 1] ndist = ((x - v.orig_row)/norm_delimeter) ** 2 + ((y - v.orig_col)/norm_delimeter) ** 2 if (x >= 0 and y >= 0 and x < height and y < width and dist_maps[v.layer, x, y] > ndist): qtail += 1 q[qtail].orig_col = v.orig_col q[qtail].orig_row = v.orig_row q[qtail].layer = v.layer q[qtail].row = x q[qtail].col = y dist_maps[v.layer, x, y] = ndist free(q) return dist_maps ================================================ FILE: XMem/inference/interact/fbrs/utils/cython/_get_dist_maps.pyxbld ================================================ import numpy def make_ext(modname, pyxfilename): from distutils.extension import Extension return Extension(modname, [pyxfilename], include_dirs=[numpy.get_include()], extra_compile_args=['-O3'], language='c++') ================================================ FILE: XMem/inference/interact/fbrs/utils/cython/dist_maps.py ================================================ import pyximport; pyximport.install(pyximport=True, language_level=3) # noinspection PyUnresolvedReferences from ._get_dist_maps import get_dist_maps ================================================ FILE: XMem/inference/interact/fbrs/utils/misc.py ================================================ from functools import partial import torch import numpy as np def get_dims_with_exclusion(dim, exclude=None): dims = list(range(dim)) if exclude is not None: dims.remove(exclude) return dims def get_unique_labels(mask): return np.nonzero(np.bincount(mask.flatten() + 1))[0] - 1 def get_bbox_from_mask(mask): rows = np.any(mask, axis=1) cols = np.any(mask, axis=0) rmin, rmax = np.where(rows)[0][[0, -1]] cmin, cmax = np.where(cols)[0][[0, -1]] return rmin, rmax, cmin, cmax def expand_bbox(bbox, expand_ratio, min_crop_size=None): rmin, rmax, cmin, cmax = bbox rcenter = 0.5 * (rmin + rmax) ccenter = 0.5 * (cmin + cmax) height = expand_ratio * (rmax - rmin + 1) width = expand_ratio * (cmax - cmin + 1) if min_crop_size is not None: height = max(height, min_crop_size) width = max(width, min_crop_size) rmin = int(round(rcenter - 0.5 * height)) rmax = int(round(rcenter + 0.5 * height)) cmin = int(round(ccenter - 0.5 * width)) cmax = 
int(round(ccenter + 0.5 * width)) return rmin, rmax, cmin, cmax def clamp_bbox(bbox, rmin, rmax, cmin, cmax): return (max(rmin, bbox[0]), min(rmax, bbox[1]), max(cmin, bbox[2]), min(cmax, bbox[3])) def get_bbox_iou(b1, b2): h_iou = get_segments_iou(b1[:2], b2[:2]) w_iou = get_segments_iou(b1[2:4], b2[2:4]) return h_iou * w_iou def get_segments_iou(s1, s2): a, b = s1 c, d = s2 intersection = max(0, min(b, d) - max(a, c) + 1) union = max(1e-6, max(b, d) - min(a, c) + 1) return intersection / union ================================================ FILE: XMem/inference/interact/fbrs/utils/vis.py ================================================ from functools import lru_cache import cv2 import numpy as np def visualize_instances(imask, bg_color=255, boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8): num_objects = imask.max() + 1 palette = get_palette(num_objects) if bg_color is not None: palette[0] = bg_color result = palette[imask].astype(np.uint8) if boundaries_color is not None: boundaries_mask = get_boundaries(imask, boundaries_width=boundaries_width) tresult = result.astype(np.float32) tresult[boundaries_mask] = boundaries_color tresult = tresult * boundaries_alpha + (1 - boundaries_alpha) * result result = tresult.astype(np.uint8) return result @lru_cache(maxsize=16) def get_palette(num_cls): palette = np.zeros(3 * num_cls, dtype=np.int32) for j in range(0, num_cls): lab = j i = 0 while lab > 0: palette[j*3 + 0] |= (((lab >> 0) & 1) << (7-i)) palette[j*3 + 1] |= (((lab >> 1) & 1) << (7-i)) palette[j*3 + 2] |= (((lab >> 2) & 1) << (7-i)) i = i + 1 lab >>= 3 return palette.reshape((-1, 3)) def visualize_mask(mask, num_cls): palette = get_palette(num_cls) mask[mask == -1] = 0 return palette[mask].astype(np.uint8) def visualize_proposals(proposals_info, point_color=(255, 0, 0), point_radius=1): proposal_map, colors, candidates = proposals_info proposal_map = draw_probmap(proposal_map) for x, y in candidates: proposal_map = cv2.circle(proposal_map, (y, x), point_radius, point_color, -1) return proposal_map def draw_probmap(x): return cv2.applyColorMap((x * 255).astype(np.uint8), cv2.COLORMAP_HOT) def draw_points(image, points, color, radius=3): image = image.copy() for p in points: image = cv2.circle(image, (int(p[1]), int(p[0])), radius, color, -1) return image def draw_instance_map(x, palette=None): num_colors = x.max() + 1 if palette is None: palette = get_palette(num_colors) return palette[x].astype(np.uint8) def blend_mask(image, mask, alpha=0.6): if mask.min() == -1: mask = mask.copy() + 1 imap = draw_instance_map(mask) result = (image * (1 - alpha) + alpha * imap).astype(np.uint8) return result def get_boundaries(instances_masks, boundaries_width=1): boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=np.bool) for obj_id in np.unique(instances_masks.flatten()): if obj_id == 0: continue obj_mask = instances_masks == obj_id kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(np.bool) obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask)) boundaries = np.logical_or(boundaries, obj_boundary) return boundaries def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0), neg_color=(255, 0, 0), radius=4): result = img.copy() if mask is not None: palette = get_palette(np.max(mask) + 1) rgb_mask = palette[mask.astype(np.uint8)] mask_region = (mask > 0).astype(np.uint8) result = result * (1 
- mask_region[:, :, np.newaxis]) + \ (1 - alpha) * mask_region[:, :, np.newaxis] * result + \ alpha * rgb_mask result = result.astype(np.uint8) # result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8) if clicks_list is not None and len(clicks_list) > 0: pos_points = [click.coords for click in clicks_list if click.is_positive] neg_points = [click.coords for click in clicks_list if not click.is_positive] result = draw_points(result, pos_points, pos_color, radius=radius) result = draw_points(result, neg_points, neg_color, radius=radius) return result ================================================ FILE: XMem/inference/interact/fbrs_controller.py ================================================ import torch from .fbrs.controller import InteractiveController from .fbrs.inference import utils class FBRSController: def __init__(self, checkpoint_path, device='cuda:0', max_size=800): model = utils.load_is_model(checkpoint_path, device, cpu_dist_maps=True, norm_radius=260) # Predictor params zoomin_params = { 'skip_clicks': 1, 'target_size': 480, 'expansion_ratio': 1.4, } predictor_params = { 'brs_mode': 'f-BRS-B', 'prob_thresh': 0.5, 'zoom_in_params': zoomin_params, 'predictor_params': { 'net_clicks_limit': 8, 'max_size': 800, }, 'brs_opt_func_params': {'min_iou_diff': 1e-3}, 'lbfgs_params': {'maxfun': 20} } self.controller = InteractiveController(model, device, predictor_params) self.anchored = False self.device = device def unanchor(self): self.anchored = False def interact(self, image, x, y, is_positive): image = image.to(self.device, non_blocking=True) if not self.anchored: self.controller.set_image(image) self.controller.reset_predictor() self.anchored = True self.controller.add_click(x, y, is_positive) # return self.controller.result_mask # return self.controller.probs_history[-1][1] return (self.controller.probs_history[-1][1]>0.5).float() def undo(self): self.controller.undo_click() if len(self.controller.probs_history) == 0: return None else: return (self.controller.probs_history[-1][1]>0.5).float() ================================================ FILE: XMem/inference/interact/gui.py ================================================ """ Based on https://github.com/hkchengrex/MiVOS/tree/MiVOS-STCN (which is based on https://github.com/seoungwugoh/ivs-demo) This version is much simplified. 
In this repo, we don't have - local control - fusion module - undo - timers but with XMem as the backbone and is more memory (for both CPU and GPU) friendly """ import functools import os import cv2 # fix conflicts between qt5 and cv2 os.environ.pop("QT_QPA_PLATFORM_PLUGIN_PATH") import numpy as np import torch try: from torch import mps except: print('torch.MPS not available.') from PySide6.QtWidgets import (QWidget, QApplication, QComboBox, QCheckBox, QHBoxLayout, QLabel, QPushButton, QTextEdit, QSpinBox, QFileDialog, QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QRadioButton) from PySide6.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon, QShortcut from PySide6.QtCore import Qt, QTimer from model.network import XMem from inference.inference_core import InferenceCore from .s2m_controller import S2MController from .fbrs_controller import FBRSController from .interactive_utils import * from .interaction import * from .resource_manager import ResourceManager from .gui_utils import * class App(QWidget): def __init__(self, net: XMem, resource_manager: ResourceManager, s2m_ctrl:S2MController, fbrs_ctrl:FBRSController, config, device): super().__init__() self.initialized = False self.num_objects = config['num_objects'] self.s2m_controller = s2m_ctrl self.fbrs_controller = fbrs_ctrl self.config = config self.processor = InferenceCore(net, config) self.processor.set_all_labels(list(range(1, self.num_objects+1))) self.res_man = resource_manager self.device = device self.num_frames = len(self.res_man) self.height, self.width = self.res_man.h, self.res_man.w # set window self.setWindowTitle('XMem Demo') self.setGeometry(100, 100, self.width, self.height+100) self.setWindowIcon(QIcon('docs/icon.png')) # some buttons self.play_button = QPushButton('Play Video') self.play_button.clicked.connect(self.on_play_video) self.commit_button = QPushButton('Commit') self.commit_button.clicked.connect(self.on_commit) self.export_button = QPushButton('Export Overlays as Video') self.export_button.clicked.connect(self.on_export_visualization) self.forward_run_button = QPushButton('Forward Propagate') self.forward_run_button.clicked.connect(self.on_forward_propagation) self.forward_run_button.setMinimumWidth(150) self.backward_run_button = QPushButton('Backward Propagate') self.backward_run_button.clicked.connect(self.on_backward_propagation) self.backward_run_button.setMinimumWidth(150) self.reset_button = QPushButton('Reset Frame') self.reset_button.clicked.connect(self.on_reset_mask) # LCD self.lcd = QTextEdit() self.lcd.setReadOnly(True) self.lcd.setMaximumHeight(28) self.lcd.setMaximumWidth(120) self.lcd.setText('{: 4d} / {: 4d}'.format(0, self.num_frames-1)) # Current Mask LCD self.object_dial = QSpinBox() self.object_dial.setReadOnly(False) self.object_dial.setMaximumHeight(28) self.object_dial.setMaximumWidth(56) self.object_dial.setMinimum(1) self.object_dial.setMaximum(self.num_objects) self.object_dial.editingFinished.connect(self.on_object_dial_change) # timeline slider self.tl_slider = QSlider(Qt.Orientation.Horizontal) self.tl_slider.valueChanged.connect(self.tl_slide) self.tl_slider.setMinimum(0) self.tl_slider.setMaximum(self.num_frames-1) self.tl_slider.setValue(0) self.tl_slider.setTickPosition(QSlider.TickPosition.TicksBelow) self.tl_slider.setTickInterval(1) # brush size slider self.brush_label = QLabel() self.brush_label.setAlignment(Qt.AlignmentFlag.AlignCenter) self.brush_label.setMinimumWidth(150) self.brush_slider = QSlider(Qt.Orientation.Horizontal) 
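# The brush-size slider configured below only takes effect in 'Free' drawing mode; interaction_radio_clicked disables it for the Click and Scribble modes.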
self.brush_slider.valueChanged.connect(self.brush_slide) self.brush_slider.setMinimum(1) self.brush_slider.setMaximum(100) self.brush_slider.setValue(3) self.brush_slider.setTickPosition(QSlider.TickPosition.TicksBelow) self.brush_slider.setTickInterval(2) self.brush_slider.setMinimumWidth(300) # combobox self.combo = QComboBox(self) self.combo.addItem("davis") self.combo.addItem("fade") self.combo.addItem("light") self.combo.addItem("popup") self.combo.addItem("layered") self.combo.currentTextChanged.connect(self.set_viz_mode) self.save_visualization_checkbox = QCheckBox(self) self.save_visualization_checkbox.toggled.connect(self.on_save_visualization_toggle) self.save_visualization_checkbox.setChecked(False) self.save_visualization = False # Radio buttons for type of interactions self.curr_interaction = 'Click' self.interaction_group = QButtonGroup() self.radio_fbrs = QRadioButton('Click') self.radio_s2m = QRadioButton('Scribble') self.radio_free = QRadioButton('Free') self.interaction_group.addButton(self.radio_fbrs) self.interaction_group.addButton(self.radio_s2m) self.interaction_group.addButton(self.radio_free) self.radio_fbrs.toggled.connect(self.interaction_radio_clicked) self.radio_s2m.toggled.connect(self.interaction_radio_clicked) self.radio_free.toggled.connect(self.interaction_radio_clicked) self.radio_fbrs.toggle() # Main canvas -> QLabel self.main_canvas = QLabel() self.main_canvas.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) self.main_canvas.setAlignment(Qt.AlignmentFlag.AlignCenter) self.main_canvas.setMinimumSize(100, 100) self.main_canvas.mousePressEvent = self.on_mouse_press self.main_canvas.mouseMoveEvent = self.on_mouse_motion self.main_canvas.setMouseTracking(True) # Required for all-time tracking self.main_canvas.mouseReleaseEvent = self.on_mouse_release # Minimap -> Also a QLabel self.minimap = QLabel() self.minimap.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) self.minimap.setAlignment(Qt.AlignmentFlag.AlignTop) self.minimap.setMinimumSize(100, 100) # Zoom-in buttons self.zoom_p_button = QPushButton('Zoom +') self.zoom_p_button.clicked.connect(self.on_zoom_plus) self.zoom_m_button = QPushButton('Zoom -') self.zoom_m_button.clicked.connect(self.on_zoom_minus) # Parameters setting self.clear_mem_button = QPushButton('Clear memory') self.clear_mem_button.clicked.connect(self.on_clear_memory) self.work_mem_gauge, self.work_mem_gauge_layout = create_gauge('Working memory size') self.long_mem_gauge, self.long_mem_gauge_layout = create_gauge('Long-term memory size') self.gpu_mem_gauge, self.gpu_mem_gauge_layout = create_gauge('GPU mem. (all processes, w/ caching)') self.torch_mem_gauge, self.torch_mem_gauge_layout = create_gauge('GPU mem. (used by torch, w/o caching)') self.update_memory_size() self.update_gpu_usage() self.work_mem_min, self.work_mem_min_layout = create_parameter_box(1, 100, 'Min. working memory frames', callback=self.on_work_min_change) self.work_mem_max, self.work_mem_max_layout = create_parameter_box(2, 100, 'Max. working memory frames', callback=self.on_work_max_change) self.long_mem_max, self.long_mem_max_layout = create_parameter_box(1000, 100000, 'Max. 
long-term memory size', step=1000, callback=self.update_config) self.num_prototypes_box, self.num_prototypes_box_layout = create_parameter_box(32, 1280, 'Number of prototypes', step=32, callback=self.update_config) self.mem_every_box, self.mem_every_box_layout = create_parameter_box(1, 100, 'Memory frame every (r)', callback=self.update_config) self.work_mem_min.setValue(self.processor.memory.min_mt_frames) self.work_mem_max.setValue(self.processor.memory.max_mt_frames) self.long_mem_max.setValue(self.processor.memory.max_long_elements) self.num_prototypes_box.setValue(self.processor.memory.num_prototypes) self.mem_every_box.setValue(self.processor.mem_every) # import mask/layer self.import_mask_button = QPushButton('Import mask') self.import_mask_button.clicked.connect(self.on_import_mask) self.import_layer_button = QPushButton('Import layer') self.import_layer_button.clicked.connect(self.on_import_layer) # Console on the GUI self.console = QPlainTextEdit() self.console.setReadOnly(True) self.console.setMinimumHeight(100) self.console.setMaximumHeight(100) # navigator navi = QHBoxLayout() interact_subbox = QVBoxLayout() interact_topbox = QHBoxLayout() interact_botbox = QHBoxLayout() interact_topbox.setAlignment(Qt.AlignmentFlag.AlignCenter) interact_topbox.addWidget(self.lcd) interact_topbox.addWidget(self.play_button) interact_topbox.addWidget(self.radio_s2m) interact_topbox.addWidget(self.radio_fbrs) interact_topbox.addWidget(self.radio_free) interact_topbox.addWidget(self.reset_button) interact_botbox.addWidget(QLabel('Current Object ID:')) interact_botbox.addWidget(self.object_dial) interact_botbox.addWidget(self.brush_label) interact_botbox.addWidget(self.brush_slider) interact_subbox.addLayout(interact_topbox) interact_subbox.addLayout(interact_botbox) navi.addLayout(interact_subbox) apply_fixed_size_policy = lambda x: x.setSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) apply_to_all_children_widget(interact_topbox, apply_fixed_size_policy) apply_to_all_children_widget(interact_botbox, apply_fixed_size_policy) navi.addStretch(1) navi.addStretch(1) overlay_subbox = QVBoxLayout() overlay_topbox = QHBoxLayout() overlay_botbox = QHBoxLayout() overlay_botbox.setAlignment(Qt.AlignmentFlag.AlignRight) overlay_topbox.addWidget(QLabel('Overlay Mode')) overlay_topbox.addWidget(self.combo) overlay_topbox.addWidget(QLabel('Save overlay during propagation')) overlay_topbox.addWidget(self.save_visualization_checkbox) overlay_botbox.addWidget(self.export_button) overlay_subbox.addLayout(overlay_topbox) overlay_subbox.addLayout(overlay_botbox) navi.addLayout(overlay_subbox) apply_to_all_children_widget(overlay_topbox, apply_fixed_size_policy) apply_to_all_children_widget(overlay_botbox, apply_fixed_size_policy) navi.addStretch(1) navi.addWidget(self.commit_button) navi.addWidget(self.forward_run_button) navi.addWidget(self.backward_run_button) # Drawing area, main canvas and minimap draw_area = QHBoxLayout() draw_area.addWidget(self.main_canvas, 4) # Minimap area minimap_area = QVBoxLayout() minimap_area.setAlignment(Qt.AlignmentFlag.AlignTop) mini_label = QLabel('Minimap') mini_label.setAlignment(Qt.AlignmentFlag.AlignTop) minimap_area.addWidget(mini_label) # Minimap zooming minimap_ctrl = QHBoxLayout() minimap_ctrl.setAlignment(Qt.AlignmentFlag.AlignTop) minimap_ctrl.addWidget(self.zoom_p_button) minimap_ctrl.addWidget(self.zoom_m_button) minimap_area.addLayout(minimap_ctrl) minimap_area.addWidget(self.minimap) # Parameters minimap_area.addLayout(self.work_mem_gauge_layout) 
minimap_area.addLayout(self.long_mem_gauge_layout) minimap_area.addLayout(self.gpu_mem_gauge_layout) minimap_area.addLayout(self.torch_mem_gauge_layout) minimap_area.addWidget(self.clear_mem_button) minimap_area.addLayout(self.work_mem_min_layout) minimap_area.addLayout(self.work_mem_max_layout) minimap_area.addLayout(self.long_mem_max_layout) minimap_area.addLayout(self.num_prototypes_box_layout) minimap_area.addLayout(self.mem_every_box_layout) # import mask/layer import_area = QHBoxLayout() import_area.setAlignment(Qt.AlignmentFlag.AlignTop) import_area.addWidget(self.import_mask_button) import_area.addWidget(self.import_layer_button) minimap_area.addLayout(import_area) # console minimap_area.addWidget(self.console) draw_area.addLayout(minimap_area, 1) layout = QVBoxLayout() layout.addLayout(draw_area) layout.addWidget(self.tl_slider) layout.addLayout(navi) self.setLayout(layout) # timer to play video self.timer = QTimer() self.timer.setSingleShot(False) self.timer.timeout.connect(self.on_play_video_timer) # timer to update GPU usage self.gpu_timer = QTimer() self.gpu_timer.setSingleShot(False) self.gpu_timer.timeout.connect(self.on_gpu_timer) self.gpu_timer.setInterval(2000) self.gpu_timer.start() # current frame info self.curr_frame_dirty = False self.current_image = np.zeros((self.height, self.width, 3), dtype=np.uint8) self.current_image_torch = None self.current_mask = np.zeros((self.height, self.width), dtype=np.uint8) self.current_prob = torch.zeros((self.num_objects, self.height, self.width), dtype=torch.float).to(self.device) # initialize visualization self.viz_mode = 'davis' self.vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8) self.vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32) self.brush_vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8) self.brush_vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32) self.cursur = 0 self.on_showing = None # Zoom parameters self.zoom_pixels = 150 # initialize action self.interaction = None self.pressed = False self.right_click = False self.current_object = 1 self.last_ex = self.last_ey = 0 self.propagating = False # Objects shortcuts for i in range(1, self.num_objects+1): QShortcut(QKeySequence(str(i)), self).activated.connect(functools.partial(self.hit_number_key, i)) QShortcut(QKeySequence(f"Ctrl+{i}"), self).activated.connect(functools.partial(self.hit_number_key, i)) # <- and -> shortcuts QShortcut(QKeySequence(Qt.Key.Key_Left), self).activated.connect(self.on_prev_frame) QShortcut(QKeySequence(Qt.Key.Key_Right), self).activated.connect(self.on_next_frame) self.interacted_prob = None self.overlay_layer = None self.overlay_layer_torch = None # the object id used for popup/layered overlay self.vis_target_objects = [1] # try to load the default overlay self._try_load_layer('./docs/ECCV-logo.png') self.load_current_image_mask() self.show_current_frame() self.show() self.console_push_text('Initialized.') self.initialized = True def resizeEvent(self, event): self.show_current_frame() def console_push_text(self, text): self.console.moveCursor(QTextCursor.MoveOperation.End) self.console.insertPlainText(text+'\n') def interaction_radio_clicked(self, event): self.last_interaction = self.curr_interaction if self.radio_s2m.isChecked(): self.curr_interaction = 'Scribble' self.brush_size = 3 self.brush_slider.setDisabled(True) elif self.radio_fbrs.isChecked(): self.curr_interaction = 'Click' self.brush_size = 3 self.brush_slider.setDisabled(True) elif self.radio_free.isChecked(): 
self.brush_slider.setDisabled(False) self.brush_slide() self.curr_interaction = 'Free' if self.curr_interaction == 'Scribble': self.commit_button.setEnabled(True) else: self.commit_button.setEnabled(False) def load_current_image_mask(self, no_mask=False): self.current_image = self.res_man.get_image(self.cursur) self.current_image_torch = None if not no_mask: loaded_mask = self.res_man.get_mask(self.cursur) if loaded_mask is None: self.current_mask.fill(0) else: self.current_mask = loaded_mask.copy() self.current_prob = None def load_current_torch_image_mask(self, no_mask=False): if self.current_image_torch is None: self.current_image_torch, self.current_image_torch_no_norm = image_to_torch(self.current_image, self.device) if self.current_prob is None and not no_mask: self.current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).to(self.device) def compose_current_im(self): self.viz = get_visualization(self.viz_mode, self.current_image, self.current_mask, self.overlay_layer, self.vis_target_objects) def update_interact_vis(self): # Update the interactions without re-computing the overlay height, width, channel = self.viz.shape bytesPerLine = 3 * width vis_map = self.vis_map vis_alpha = self.vis_alpha brush_vis_map = self.brush_vis_map brush_vis_alpha = self.brush_vis_alpha self.viz_with_stroke = self.viz*(1-vis_alpha) + vis_map*vis_alpha self.viz_with_stroke = self.viz_with_stroke*(1-brush_vis_alpha) + brush_vis_map*brush_vis_alpha self.viz_with_stroke = self.viz_with_stroke.astype(np.uint8) qImg = QImage(self.viz_with_stroke.data, width, height, bytesPerLine, QImage.Format.Format_RGB888) self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.FastTransformation))) self.main_canvas_size = self.main_canvas.size() self.image_size = qImg.size() def update_minimap(self): ex, ey = self.last_ex, self.last_ey r = self.zoom_pixels//2 ex = int(round(max(r, min(self.width-r, ex)))) ey = int(round(max(r, min(self.height-r, ey)))) patch = self.viz_with_stroke[ey-r:ey+r, ex-r:ex+r, :].astype(np.uint8) height, width, channel = patch.shape bytesPerLine = 3 * width qImg = QImage(patch.data, width, height, bytesPerLine, QImage.Format.Format_RGB888) self.minimap.setPixmap(QPixmap(qImg.scaled(self.minimap.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.FastTransformation))) def update_current_image_fast(self): # fast path, uses gpu. 
Changes the image in-place to avoid copying self.viz = get_visualization_torch(self.viz_mode, self.current_image_torch_no_norm, self.current_prob, self.overlay_layer_torch, self.vis_target_objects) if self.save_visualization: self.res_man.save_visualization(self.cursur, self.viz) height, width, channel = self.viz.shape bytesPerLine = 3 * width qImg = QImage(self.viz.data, width, height, bytesPerLine, QImage.Format.Format_RGB888) self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.FastTransformation))) def show_current_frame(self, fast=False): # Re-compute overlay and show the image if fast: self.update_current_image_fast() else: self.compose_current_im() self.update_interact_vis() self.update_minimap() self.lcd.setText('{: 3d} / {: 3d}'.format(self.cursur, self.num_frames-1)) self.tl_slider.setValue(self.cursur) def pixel_pos_to_image_pos(self, x, y): # Un-scale and un-pad the label coordinates into image coordinates oh, ow = self.image_size.height(), self.image_size.width() nh, nw = self.main_canvas_size.height(), self.main_canvas_size.width() h_ratio = nh/oh w_ratio = nw/ow dominate_ratio = min(h_ratio, w_ratio) # Solve scale x /= dominate_ratio y /= dominate_ratio # Solve padding fh, fw = nh/dominate_ratio, nw/dominate_ratio x -= (fw-ow)/2 y -= (fh-oh)/2 return x, y def is_pos_out_of_bound(self, x, y): x, y = self.pixel_pos_to_image_pos(x, y) out_of_bound = ( (x < 0) or (y < 0) or (x > self.width-1) or (y > self.height-1) ) return out_of_bound def get_scaled_pos(self, x, y): x, y = self.pixel_pos_to_image_pos(x, y) x = max(0, min(self.width-1, x)) y = max(0, min(self.height-1, y)) return x, y def clear_visualization(self): self.vis_map.fill(0) self.vis_alpha.fill(0) def reset_this_interaction(self): self.complete_interaction() self.clear_visualization() self.interaction = None if self.fbrs_controller is not None: self.fbrs_controller.unanchor() def set_viz_mode(self): self.viz_mode = self.combo.currentText() self.show_current_frame() def save_current_mask(self): # save mask to hard disk self.res_man.save_mask(self.cursur, self.current_mask) def tl_slide(self): # if we are propagating, the on_run function will take care of everything # don't do duplicate work here if not self.propagating: if self.curr_frame_dirty: self.save_current_mask() self.curr_frame_dirty = False self.reset_this_interaction() self.cursur = self.tl_slider.value() self.load_current_image_mask() self.show_current_frame() def brush_slide(self): self.brush_size = self.brush_slider.value() self.brush_label.setText('Brush size (in free mode): %d' % self.brush_size) try: if type(self.interaction) == FreeInteraction: self.interaction.set_size(self.brush_size) except AttributeError: # Initialization, forget about it pass def on_forward_propagation(self): if self.propagating: # acts as a pause button self.propagating = False else: self.propagate_fn = self.on_next_frame self.backward_run_button.setEnabled(False) self.forward_run_button.setText('Pause Propagation') self.on_propagation() def on_backward_propagation(self): if self.propagating: # acts as a pause button self.propagating = False else: self.propagate_fn = self.on_prev_frame self.forward_run_button.setEnabled(False) self.backward_run_button.setText('Pause Propagation') self.on_propagation() def on_pause(self): self.propagating = False self.forward_run_button.setEnabled(True) self.backward_run_button.setEnabled(True) self.clear_mem_button.setEnabled(True) 
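# Restore the run-button captions that on_forward_propagation/on_backward_propagation changed to 'Pause Propagation'.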
self.forward_run_button.setText('Forward Propagate') self.backward_run_button.setText('Backward Propagate') self.console_push_text('Propagation stopped.') def on_propagation(self): # start to propagate self.load_current_torch_image_mask() self.show_current_frame(fast=True) self.console_push_text('Propagation started.') self.current_prob = self.processor.step(self.current_image_torch, self.current_prob[1:]) self.current_mask = torch_prob_to_numpy_mask(self.current_prob) # clear self.interacted_prob = None self.reset_this_interaction() self.propagating = True self.clear_mem_button.setEnabled(False) # propagate till the end while self.propagating: self.propagate_fn() self.load_current_image_mask(no_mask=True) self.load_current_torch_image_mask(no_mask=True) self.current_prob = self.processor.step(self.current_image_torch) self.current_mask = torch_prob_to_numpy_mask(self.current_prob) self.save_current_mask() self.show_current_frame(fast=True) self.update_memory_size() QApplication.processEvents() if self.cursur == 0 or self.cursur == self.num_frames-1: break self.propagating = False self.curr_frame_dirty = False self.on_pause() self.tl_slide() QApplication.processEvents() def pause_propagation(self): self.propagating = False def on_commit(self): self.complete_interaction() self.update_interacted_mask() def on_prev_frame(self): # self.tl_slide will trigger on setValue self.cursur = max(0, self.cursur-1) self.tl_slider.setValue(self.cursur) def on_next_frame(self): # self.tl_slide will trigger on setValue self.cursur = min(self.cursur+1, self.num_frames-1) self.tl_slider.setValue(self.cursur) def on_play_video_timer(self): self.cursur += 1 if self.cursur > self.num_frames-1: self.cursur = 0 self.tl_slider.setValue(self.cursur) def on_play_video(self): if self.timer.isActive(): self.timer.stop() self.play_button.setText('Play Video') else: self.timer.start(1000 // 30) self.play_button.setText('Stop Video') def on_export_visualization(self): # NOTE: Save visualization at the end of propagation image_folder = f"{self.config['workspace']}/visualization/" save_folder = self.config['workspace'] if os.path.exists(image_folder): # Sorted so frames will be in order self.console_push_text(f'Exporting visualization to {self.config["workspace"]}/visualization.mp4') images = [img for img in sorted(os.listdir(image_folder)) if img.endswith(".jpg")] frame = cv2.imread(os.path.join(image_folder, images[0])) height, width, layers = frame.shape # 10 is the FPS -- change if needed video = cv2.VideoWriter(f"{save_folder}/visualization.mp4", cv2.VideoWriter_fourcc(*'mp4v'), 10, (width,height)) for image in images: video.write(cv2.imread(os.path.join(image_folder, image))) video.release() self.console_push_text(f'Visualization exported to {self.config["workspace"]}/visualization.mp4') else: self.console_push_text(f'No visualization images found in {image_folder}') def on_object_dial_change(self): object_id = self.object_dial.value() self.hit_number_key(object_id) def on_reset_mask(self): self.current_mask.fill(0) if self.current_prob is not None: self.current_prob.fill_(0) self.curr_frame_dirty = True self.save_current_mask() self.reset_this_interaction() self.show_current_frame() def on_zoom_plus(self): self.zoom_pixels -= 25 self.zoom_pixels = max(50, self.zoom_pixels) self.update_minimap() def on_zoom_minus(self): self.zoom_pixels += 25 self.zoom_pixels = min(self.zoom_pixels, 300) self.update_minimap() def set_navi_enable(self, boolean): self.zoom_p_button.setEnabled(boolean) 
self.zoom_m_button.setEnabled(boolean) self.run_button.setEnabled(boolean) self.tl_slider.setEnabled(boolean) self.play_button.setEnabled(boolean) self.export_button.setEnabled(boolean) self.lcd.setEnabled(boolean) def hit_number_key(self, number): if number == self.current_object: return self.current_object = number self.object_dial.setValue(number) if self.fbrs_controller is not None: self.fbrs_controller.unanchor() self.console_push_text(f'Current object changed to {number}.') self.clear_brush() self.vis_brush(self.last_ex, self.last_ey) self.update_interact_vis() self.show_current_frame() def clear_brush(self): self.brush_vis_map.fill(0) self.brush_vis_alpha.fill(0) def vis_brush(self, ex, ey): self.brush_vis_map = cv2.circle(self.brush_vis_map, (int(round(ex)), int(round(ey))), self.brush_size//2+1, color_map[self.current_object], thickness=-1) self.brush_vis_alpha = cv2.circle(self.brush_vis_alpha, (int(round(ex)), int(round(ey))), self.brush_size//2+1, 0.5, thickness=-1) def on_mouse_press(self, event): if self.is_pos_out_of_bound(event.position().x(), event.position().y()): return # mid-click if (event.button() == Qt.MouseButton.MiddleButton): ex, ey = self.get_scaled_pos(event.position().x(), event.position().y()) target_object = self.current_mask[int(ey),int(ex)] if target_object in self.vis_target_objects: self.vis_target_objects.remove(target_object) else: self.vis_target_objects.append(target_object) self.console_push_text(f'Target objects for visualization changed to {self.vis_target_objects}') self.show_current_frame() return self.right_click = (event.button() == Qt.MouseButton.RightButton) self.pressed = True h, w = self.height, self.width self.load_current_torch_image_mask() image = self.current_image_torch last_interaction = self.interaction new_interaction = None if self.curr_interaction == 'Scribble': if last_interaction is None or type(last_interaction) != ScribbleInteraction: self.complete_interaction() new_interaction = ScribbleInteraction(image, torch.from_numpy(self.current_mask).float().to(self.device), (h, w), self.s2m_controller, self.num_objects) elif self.curr_interaction == 'Free': if last_interaction is None or type(last_interaction) != FreeInteraction: self.complete_interaction() new_interaction = FreeInteraction(image, self.current_mask, (h, w), self.num_objects) new_interaction.set_size(self.brush_size) elif self.curr_interaction == 'Click': if (last_interaction is None or type(last_interaction) != ClickInteraction or last_interaction.tar_obj != self.current_object): self.complete_interaction() self.fbrs_controller.unanchor() new_interaction = ClickInteraction(image, self.current_prob, (h, w), self.fbrs_controller, self.current_object) if new_interaction is not None: self.interaction = new_interaction # Just motion it as the first step self.on_mouse_motion(event) def on_mouse_motion(self, event): ex, ey = self.get_scaled_pos(event.position().x(), event.position().y()) self.last_ex, self.last_ey = ex, ey self.clear_brush() # Visualize self.vis_brush(ex, ey) if self.pressed: if self.curr_interaction == 'Scribble' or self.curr_interaction == 'Free': obj = 0 if self.right_click else self.current_object self.vis_map, self.vis_alpha = self.interaction.push_point( ex, ey, obj, (self.vis_map, self.vis_alpha) ) self.update_interact_vis() self.update_minimap() def update_interacted_mask(self): self.current_prob = self.interacted_prob self.current_mask = torch_prob_to_numpy_mask(self.interacted_prob) self.show_current_frame() self.save_current_mask() 
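# --- A minimal numpy/cv2 sketch of the two-buffer brush cursor that
# vis_brush() above maintains: one buffer holds the object color, the other a
# per-pixel alpha used when compositing onto the frame. Sizes, the click
# position, and the color are illustrative.
import numpy as np
import cv2

h, w, brush_size = 480, 854, 7
brush_vis_map = np.zeros((h, w, 3), dtype=np.uint8)
brush_vis_alpha = np.zeros((h, w, 1), dtype=np.float32)
brush_vis_map = cv2.circle(brush_vis_map, (100, 100), brush_size // 2 + 1, (255, 128, 0), thickness=-1)
brush_vis_alpha = cv2.circle(brush_vis_alpha, (100, 100), brush_size // 2 + 1, 0.5, thickness=-1)
# compositing then follows frame * (1 - alpha) + brush_vis_map * alpha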
self.curr_frame_dirty = False def complete_interaction(self): if self.interaction is not None: self.clear_visualization() self.interaction = None def on_mouse_release(self, event): if not self.pressed: # this can happen when the initial press is out-of-bound return ex, ey = self.get_scaled_pos(event.position().x(), event.position().y()) self.console_push_text('%s interaction at frame %d.' % (self.curr_interaction, self.cursur)) interaction = self.interaction if self.curr_interaction == 'Scribble' or self.curr_interaction == 'Free': self.on_mouse_motion(event) interaction.end_path() if self.curr_interaction == 'Free': self.clear_visualization() elif self.curr_interaction == 'Click': ex, ey = self.get_scaled_pos(event.position().x(), event.position().y()) self.vis_map, self.vis_alpha = interaction.push_point(ex, ey, self.right_click, (self.vis_map, self.vis_alpha)) self.interacted_prob = interaction.predict().to(self.device) self.update_interacted_mask() self.update_gpu_usage() self.pressed = self.right_click = False def wheelEvent(self, event): ex, ey = self.get_scaled_pos(event.position().x(), event.position().y()) if self.curr_interaction == 'Free': self.brush_slider.setValue(self.brush_slider.value() + event.angleDelta().y()//30) self.clear_brush() self.vis_brush(ex, ey) self.update_interact_vis() self.update_minimap() def update_gpu_usage(self): if self.device.type == 'cuda': info = torch.cuda.mem_get_info() elif self.device.type == 'mps': info = (0, mps.current_allocated_memory()) # NOTE: torch.mps does not support accessing free and total memory else: info = (0, 0) global_free, global_total = info global_free /= (2**30) global_total /= (2**30) global_used = global_total - global_free self.gpu_mem_gauge.setFormat(f'{global_used:.01f} GB / {global_total:.01f} GB') self.gpu_mem_gauge.setValue(round(global_used/global_total*100)) used_by_torch = torch.cuda.max_memory_allocated() / (2**20) self.torch_mem_gauge.setFormat(f'{used_by_torch:.0f} MB / {global_total:.01f} GB') self.torch_mem_gauge.setValue(round(used_by_torch/global_total*100/1024)) def on_gpu_timer(self): self.update_gpu_usage() def update_memory_size(self): try: max_work_elements = self.processor.memory.max_work_elements max_long_elements = self.processor.memory.max_long_elements curr_work_elements = self.processor.memory.work_mem.size curr_long_elements = self.processor.memory.long_mem.size self.work_mem_gauge.setFormat(f'{curr_work_elements} / {max_work_elements}') self.work_mem_gauge.setValue(round(curr_work_elements/max_work_elements*100)) self.long_mem_gauge.setFormat(f'{curr_long_elements} / {max_long_elements}') self.long_mem_gauge.setValue(round(curr_long_elements/max_long_elements*100)) except AttributeError: self.work_mem_gauge.setFormat('Unknown') self.long_mem_gauge.setFormat('Unknown') self.work_mem_gauge.setValue(0) self.long_mem_gauge.setValue(0) def on_work_min_change(self): if self.initialized: self.work_mem_min.setValue(min(self.work_mem_min.value(), self.work_mem_max.value()-1)) self.update_config() def on_work_max_change(self): if self.initialized: self.work_mem_max.setValue(max(self.work_mem_max.value(), self.work_mem_min.value()+1)) self.update_config() def update_config(self): if self.initialized: self.config['min_mid_term_frames'] = self.work_mem_min.value() self.config['max_mid_term_frames'] = self.work_mem_max.value() self.config['max_long_term_elements'] = self.long_mem_max.value() self.config['num_prototypes'] = self.num_prototypes_box.value() self.config['mem_every'] = self.mem_every_box.value() 
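# --- A small sketch of the memory readout behind update_gpu_usage() above:
# torch.cuda.mem_get_info() returns (free_bytes, total_bytes) for the current
# CUDA device; the printed format is illustrative.
import torch

if torch.cuda.is_available():
    free_b, total_b = torch.cuda.mem_get_info()
    used_gb = (total_b - free_b) / 2**30
    print(f'{used_gb:.1f} GB / {total_b / 2**30:.1f} GB used')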
self.processor.update_config(self.config) def on_clear_memory(self): self.processor.clear_memory() if self.device.type == 'cuda': torch.cuda.empty_cache() elif self.device.type == 'mps': mps.empty_cache() self.update_gpu_usage() self.update_memory_size() def _open_file(self, prompt): options = QFileDialog.Options() file_name, _ = QFileDialog.getOpenFileName(self, prompt, "", "Image files (*)", options=options) return file_name def on_import_mask(self): file_name = self._open_file('Mask') if len(file_name) == 0: return mask = self.res_man.read_external_image(file_name, size=(self.height, self.width)) shape_condition = ( (len(mask.shape) == 2) and (mask.shape[-1] == self.width) and (mask.shape[-2] == self.height) ) object_condition = ( mask.max() <= self.num_objects ) if not shape_condition: self.console_push_text(f'Expected ({self.height}, {self.width}). Got {mask.shape} instead.') elif not object_condition: self.console_push_text(f'Expected {self.num_objects} objects. Got {mask.max()} objects instead.') else: self.console_push_text(f'Mask file {file_name} loaded.') self.current_image_torch = self.current_prob = None self.current_mask = mask self.show_current_frame() self.save_current_mask() def on_import_layer(self): file_name = self._open_file('Layer') if len(file_name) == 0: return self._try_load_layer(file_name) def _try_load_layer(self, file_name): try: layer = self.res_man.read_external_image(file_name, size=(self.height, self.width)) if layer.shape[-1] == 3: layer = np.concatenate([layer, np.ones_like(layer[:,:,0:1])*255], axis=-1) condition = ( (len(layer.shape) == 3) and (layer.shape[-1] == 4) and (layer.shape[-2] == self.width) and (layer.shape[-3] == self.height) ) if not condition: self.console_push_text(f'Expected ({self.height}, {self.width}, 4). 
Got {layer.shape}.') else: self.console_push_text(f'Layer file {file_name} loaded.') self.overlay_layer = layer self.overlay_layer_torch = torch.from_numpy(layer).float().to(self.device)/255 self.show_current_frame() except FileNotFoundError: self.console_push_text(f'{file_name} not found.') def on_save_visualization_toggle(self): self.save_visualization = self.save_visualization_checkbox.isChecked() ================================================ FILE: XMem/inference/interact/gui_utils.py ================================================ from PySide6.QtCore import Qt from PySide6.QtWidgets import (QBoxLayout, QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar) def create_parameter_box(min_val, max_val, text, step=1, callback=None): layout = QHBoxLayout() dial = QSpinBox() dial.setMaximumHeight(28) dial.setMaximumWidth(150) dial.setMinimum(min_val) dial.setMaximum(max_val) dial.setAlignment(Qt.AlignmentFlag.AlignRight) dial.setSingleStep(step) dial.valueChanged.connect(callback) label = QLabel(text) label.setAlignment(Qt.AlignmentFlag.AlignRight) layout.addWidget(label) layout.addWidget(dial) return dial, layout def create_gauge(text): layout = QHBoxLayout() gauge = QProgressBar() gauge.setMaximumHeight(28) gauge.setMaximumWidth(200) gauge.setAlignment(Qt.AlignmentFlag.AlignCenter) label = QLabel(text) label.setAlignment(Qt.AlignmentFlag.AlignRight) layout.addWidget(label) layout.addWidget(gauge) return gauge, layout def apply_to_all_children_widget(layout, func): # deliberately non-recursive for i in range(layout.count()): func(layout.itemAt(i).widget()) ================================================ FILE: XMem/inference/interact/interaction.py ================================================ """ Contains all the types of interaction related to the GUI Not related to automatic evaluation in the DAVIS dataset You can inherit the Interaction class to create new interaction types undo is (sometimes partially) supported """ import torch import torch.nn.functional as F import numpy as np import cv2 import time from .interactive_utils import color_map, index_numpy_to_one_hot_torch def aggregate_sbg(prob, keep_bg=False, hard=False): device = prob.device k, h, w = prob.shape ex_prob = torch.zeros((k+1, h, w), device=device) ex_prob[0] = 0.5 ex_prob[1:] = prob ex_prob = torch.clamp(ex_prob, 1e-7, 1-1e-7) logits = torch.log((ex_prob /(1-ex_prob))) if hard: # Very low temperature o((⊙﹏⊙))o 🥶 logits *= 1000 if keep_bg: return F.softmax(logits, dim=0) else: return F.softmax(logits, dim=0)[1:] def aggregate_wbg(prob, keep_bg=False, hard=False): k, h, w = prob.shape new_prob = torch.cat([ torch.prod(1-prob, dim=0, keepdim=True), prob ], 0).clamp(1e-7, 1-1e-7) logits = torch.log((new_prob /(1-new_prob))) if hard: # Very low temperature o((⊙﹏⊙))o 🥶 logits *= 1000 if keep_bg: return F.softmax(logits, dim=0) else: return F.softmax(logits, dim=0)[1:] class Interaction: def __init__(self, image, prev_mask, true_size, controller): self.image = image self.prev_mask = prev_mask self.controller = controller self.start_time = time.time() self.h, self.w = true_size self.out_prob = None self.out_mask = None def predict(self): pass class FreeInteraction(Interaction): def __init__(self, image, prev_mask, true_size, num_objects): """ prev_mask should be index format numpy array """ super().__init__(image, prev_mask, true_size, None) self.K = num_objects self.drawn_map = self.prev_mask.copy() self.curr_path = [[] for _ in range(self.K + 1)] self.size = None def set_size(self, size): self.size = size """ k - 
object id vis - a tuple (visualization map, pass through alpha). None if not needed. """ def push_point(self, x, y, k, vis=None): if vis is not None: vis_map, vis_alpha = vis selected = self.curr_path[k] selected.append((x, y)) if len(selected) >= 2: cv2.line(self.drawn_map, (int(round(selected[-2][0])), int(round(selected[-2][1]))), (int(round(selected[-1][0])), int(round(selected[-1][1]))), k, thickness=self.size) # Plot visualization if vis is not None: # Visualization for drawing if k == 0: vis_map = cv2.line(vis_map, (int(round(selected[-2][0])), int(round(selected[-2][1]))), (int(round(selected[-1][0])), int(round(selected[-1][1]))), color_map[k], thickness=self.size) else: vis_map = cv2.line(vis_map, (int(round(selected[-2][0])), int(round(selected[-2][1]))), (int(round(selected[-1][0])), int(round(selected[-1][1]))), color_map[k], thickness=self.size) # Visualization on/off boolean filter vis_alpha = cv2.line(vis_alpha, (int(round(selected[-2][0])), int(round(selected[-2][1]))), (int(round(selected[-1][0])), int(round(selected[-1][1]))), 0.75, thickness=self.size) if vis is not None: return vis_map, vis_alpha def end_path(self): # Complete the drawing self.curr_path = [[] for _ in range(self.K + 1)] def predict(self): self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1) # self.out_prob = torch.from_numpy(self.drawn_map).float().cuda() # self.out_prob, _ = pad_divide_by(self.out_prob, 16, self.out_prob.shape[-2:]) # self.out_prob = aggregate_sbg(self.out_prob, keep_bg=True) return self.out_prob class ScribbleInteraction(Interaction): def __init__(self, image, prev_mask, true_size, controller, num_objects): """ prev_mask should be in an indexed form """ super().__init__(image, prev_mask, true_size, controller) self.K = num_objects self.drawn_map = np.empty((self.h, self.w), dtype=np.uint8) self.drawn_map.fill(255) # background + k self.curr_path = [[] for _ in range(self.K + 1)] self.size = 3 """ k - object id vis - a tuple (visualization map, pass through alpha). None if not needed. 
""" def push_point(self, x, y, k, vis=None): if vis is not None: vis_map, vis_alpha = vis selected = self.curr_path[k] selected.append((x, y)) if len(selected) >= 2: self.drawn_map = cv2.line(self.drawn_map, (int(round(selected[-2][0])), int(round(selected[-2][1]))), (int(round(selected[-1][0])), int(round(selected[-1][1]))), k, thickness=self.size) # Plot visualization if vis is not None: # Visualization for drawing if k == 0: vis_map = cv2.line(vis_map, (int(round(selected[-2][0])), int(round(selected[-2][1]))), (int(round(selected[-1][0])), int(round(selected[-1][1]))), color_map[k], thickness=self.size) else: vis_map = cv2.line(vis_map, (int(round(selected[-2][0])), int(round(selected[-2][1]))), (int(round(selected[-1][0])), int(round(selected[-1][1]))), color_map[k], thickness=self.size) # Visualization on/off boolean filter vis_alpha = cv2.line(vis_alpha, (int(round(selected[-2][0])), int(round(selected[-2][1]))), (int(round(selected[-1][0])), int(round(selected[-1][1]))), 0.75, thickness=self.size) # Optional vis return if vis is not None: return vis_map, vis_alpha def end_path(self): # Complete the drawing self.curr_path = [[] for _ in range(self.K + 1)] def predict(self): self.out_prob = self.controller.interact(self.image.unsqueeze(0), self.prev_mask, self.drawn_map) self.out_prob = aggregate_wbg(self.out_prob, keep_bg=True, hard=True) return self.out_prob class ClickInteraction(Interaction): def __init__(self, image, prev_mask, true_size, controller, tar_obj): """ prev_mask in a prob. form """ super().__init__(image, prev_mask, true_size, controller) self.tar_obj = tar_obj # negative/positive for each object self.pos_clicks = [] self.neg_clicks = [] self.out_prob = self.prev_mask.clone() """ neg - Negative interaction or not vis - a tuple (visualization map, pass through alpha). None if not needed. 
""" def push_point(self, x, y, neg, vis=None): # Clicks if neg: self.neg_clicks.append((x, y)) else: self.pos_clicks.append((x, y)) # Do the prediction self.obj_mask = self.controller.interact(self.image.unsqueeze(0), x, y, not neg) # Plot visualization if vis is not None: vis_map, vis_alpha = vis # Visualization for clicks if neg: vis_map = cv2.circle(vis_map, (int(round(x)), int(round(y))), 2, color_map[0], thickness=-1) else: vis_map = cv2.circle(vis_map, (int(round(x)), int(round(y))), 2, color_map[self.tar_obj], thickness=-1) vis_alpha = cv2.circle(vis_alpha, (int(round(x)), int(round(y))), 2, 1, thickness=-1) # Optional vis return return vis_map, vis_alpha def predict(self): self.out_prob = self.prev_mask.clone() # a small hack to allow the interacting object to overwrite existing masks # without remembering all the object probabilities self.out_prob = torch.clamp(self.out_prob, max=0.9) self.out_prob[self.tar_obj] = self.obj_mask self.out_prob = aggregate_wbg(self.out_prob[1:], keep_bg=True, hard=True) return self.out_prob ================================================ FILE: XMem/inference/interact/interactive_utils.py ================================================ # Modified from https://github.com/seoungwugoh/ivs-demo import numpy as np import torch import torch.nn.functional as F from util.palette import davis_palette from dataset.range_transform import im_normalization def image_to_torch(frame: np.ndarray, device='cuda'): # frame: H*W*3 numpy array frame = frame.transpose(2, 0, 1) frame = torch.from_numpy(frame).float().to(device)/255 frame_norm = im_normalization(frame) return frame_norm, frame def torch_prob_to_numpy_mask(prob): mask = torch.max(prob, dim=0).indices mask = mask.cpu().numpy().astype(np.uint8) return mask def index_numpy_to_one_hot_torch(mask, num_classes): mask = torch.from_numpy(mask).long() return F.one_hot(mask, num_classes=num_classes).permute(2, 0, 1).float() """ Some constants for visualization """ try: if torch.cuda.is_available(): device = torch.device("cuda") elif torch.backends.mps.is_available(): device = torch.device("mps") else: device = torch.device("cpu") except: device = torch.device("cpu") color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy() # scales for better visualization color_map_np = (color_map_np.astype(np.float32)*1.5).clip(0, 255).astype(np.uint8) color_map = color_map_np.tolist() color_map_torch = torch.from_numpy(color_map_np).to(device) / 255 grayscale_weights = np.array([[0.3,0.59,0.11]]).astype(np.float32) grayscale_weights_torch = torch.from_numpy(grayscale_weights).to(device).unsqueeze(0) def get_visualization(mode, image, mask, layer, target_object): if mode == 'fade': return overlay_davis(image, mask, fade=True) elif mode == 'davis': return overlay_davis(image, mask) elif mode == 'light': return overlay_davis(image, mask, 0.9) elif mode == 'popup': return overlay_popup(image, mask, target_object) elif mode == 'layered': if layer is None: print('Layer file not given. Defaulting to DAVIS.') return overlay_davis(image, mask) else: return overlay_layer(image, mask, layer, target_object) else: raise NotImplementedError
def get_visualization_torch(mode, image, prob, layer, target_object): if mode == 'fade': return overlay_davis_torch(image, prob, fade=True) elif mode == 'davis': return overlay_davis_torch(image, prob) elif mode == 'light': return overlay_davis_torch(image, prob, 0.9) elif mode == 'popup': return overlay_popup_torch(image, prob, target_object) elif mode == 'layered': if layer is None: print('Layer file not given. Defaulting to DAVIS.') return overlay_davis_torch(image, prob) else: return overlay_layer_torch(image, prob, layer, target_object) else: raise NotImplementedError def overlay_davis(image, mask, alpha=0.5, fade=False): """ Overlay segmentation on top of RGB image. from davis official""" im_overlay = image.copy() colored_mask = color_map_np[mask] foreground = image*alpha + (1-alpha)*colored_mask binary_mask = (mask > 0) # Compose image im_overlay[binary_mask] = foreground[binary_mask] if fade: im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6 return im_overlay.astype(image.dtype) def overlay_popup(image, mask, target_object): # Keep foreground colored. Convert background to grayscale. im_overlay = image.copy() binary_mask = ~(np.isin(mask, target_object)) grayscale_region = (im_overlay[binary_mask]*grayscale_weights).sum(-1, keepdims=True) im_overlay[binary_mask] = grayscale_region return im_overlay.astype(image.dtype) def overlay_layer(image, mask, layer, target_object): # insert a layer between foreground and background # The CPU version is less accurate because we are using the hard mask # The GPU version has softer edges as it uses soft probabilities obj_mask = (np.isin(mask, target_object)).astype(np.float32)[:, :, np.newaxis] layer_alpha = layer[:, :, 3].astype(np.float32)[:, :, np.newaxis] / 255 layer_rgb = layer[:, :, :3] background_alpha = np.maximum(obj_mask, layer_alpha) im_overlay = (image * (1 - background_alpha) + layer_rgb * (1 - obj_mask) * layer_alpha + image * obj_mask).clip(0, 255) return im_overlay.astype(image.dtype) def overlay_davis_torch(image, mask, alpha=0.5, fade=False): """ Overlay segmentation on top of RGB image. from davis official""" # Changes the image in-place to avoid copying image = image.permute(1, 2, 0) im_overlay = image mask = torch.max(mask, dim=0).indices colored_mask = color_map_torch[mask] foreground = image*alpha + (1-alpha)*colored_mask binary_mask = (mask > 0) # Compose image im_overlay[binary_mask] = foreground[binary_mask] if fade: im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6 im_overlay = (im_overlay*255).cpu().numpy() im_overlay = im_overlay.astype(np.uint8) return im_overlay def overlay_popup_torch(image, mask, target_object): # Keep foreground colored. Convert background to grayscale. image = image.permute(1, 2, 0) if len(target_object) == 0: obj_mask = torch.zeros_like(mask[0]).unsqueeze(2) else: # I should not need to convert this to numpy. # Using a list works most of the time but consistently fails # if I include first object -> exclude it -> include it again. # I check everywhere and it makes absolutely no sense.
# I am blaming this on PyTorch and calling it a day obj_mask = mask[np.array(target_object,dtype=np.int32)].sum(0).unsqueeze(2) gray_image = (image*grayscale_weights_torch).sum(-1, keepdim=True) im_overlay = obj_mask*image + (1-obj_mask)*gray_image im_overlay = (im_overlay*255).cpu().numpy() im_overlay = im_overlay.astype(np.uint8) return im_overlay def overlay_layer_torch(image, prob, layer, target_object): # insert a layer between foreground and background # The CPU version is less accurate because we are using the hard mask # The GPU version has softer edges as it uses soft probabilities image = image.permute(1, 2, 0) if len(target_object) == 0: obj_mask = torch.zeros_like(prob[0]).unsqueeze(2) else: # TODO: figure out why we need to convert this to numpy array obj_mask = prob[np.array(target_object, dtype=np.int32)].sum(0).unsqueeze(2) layer_alpha = layer[:, :, 3].unsqueeze(2) layer_rgb = layer[:, :, :3] background_alpha = torch.maximum(obj_mask, layer_alpha) im_overlay = (image * (1 - background_alpha) + layer_rgb * (1 - obj_mask) * layer_alpha + image * obj_mask).clip(0, 1) im_overlay = (im_overlay * 255).cpu().numpy() im_overlay = im_overlay.astype(np.uint8) return im_overlay ================================================ FILE: XMem/inference/interact/resource_manager.py ================================================ import os from os import path import shutil import collections import cv2 from PIL import Image if not hasattr(Image, 'Resampling'): # Pillow<9.0 Image.Resampling = Image import numpy as np from util.palette import davis_palette import progressbar # https://bugs.python.org/issue28178 # ah python ah why class LRU: def __init__(self, func, maxsize=128): self.cache = collections.OrderedDict() self.func = func self.maxsize = maxsize def __call__(self, *args): cache = self.cache if args in cache: cache.move_to_end(args) return cache[args] result = self.func(*args) cache[args] = result if len(cache) > self.maxsize: cache.popitem(last=False) return result def invalidate(self, key): self.cache.pop(key, None) class ResourceManager: def __init__(self, config): # determine inputs images = config['images'] video = config['video'] self.workspace = config['workspace'] self.size = config['size'] self.palette = davis_palette # create temporary workspace if not specified if self.workspace is None: if images is not None: basename = path.basename(images) elif video is not None: basename = path.basename(video)[:-4] else: raise NotImplementedError( 'Either images, video, or workspace has to be specified') self.workspace = path.join('./workspace', basename) print(f'Workspace is in: {self.workspace}') # determine the location of input images need_decoding = False need_resizing = False if path.exists(path.join(self.workspace, 'images')): pass elif images is not None: need_resizing = True elif video is not None: # will decode video into frames later need_decoding = True # create workspace subdirectories self.image_dir = path.join(self.workspace, 'images') self.mask_dir = path.join(self.workspace, 'masks') os.makedirs(self.image_dir, exist_ok=True) os.makedirs(self.mask_dir, exist_ok=True) # convert read functions to be buffered self.get_image = LRU(self._get_image_unbuffered, maxsize=config['buffer_size']) self.get_mask = LRU(self._get_mask_unbuffered, maxsize=config['buffer_size']) # extract frames from video if need_decoding: self._extract_frames(video) # copy/resize existing images to the workspace if need_resizing: self._copy_resize_frames(images) # read all frame names self.names = 
sorted(os.listdir(self.image_dir)) self.names = [f[:-4] for f in self.names] # remove extensions self.length = len(self.names) assert self.length > 0, f'No images found! Check {self.workspace}/images. Remove folder if necessary.' print(f'{self.length} images found.') self.height, self.width = self.get_image(0).shape[:2] self.visualization_init = False def _extract_frames(self, video): cap = cv2.VideoCapture(video) frame_index = 0 print(f'Extracting frames from {video} into {self.image_dir}...') bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength) while(cap.isOpened()): _, frame = cap.read() if frame is None: break if self.size > 0: h, w = frame.shape[:2] new_w = (w*self.size//min(w, h)) new_h = (h*self.size//min(w, h)) if new_w != w or new_h != h: frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA) cv2.imwrite(path.join(self.image_dir, f'{frame_index:07d}.jpg'), frame) frame_index += 1 bar.update(frame_index) bar.finish() print('Done!') def _copy_resize_frames(self, images): image_list = os.listdir(images) print(f'Copying/resizing frames into {self.image_dir}...') for image_name in progressbar.progressbar(image_list): if self.size < 0: # just copy shutil.copy2(path.join(images, image_name), self.image_dir) else: frame = cv2.imread(path.join(images, image_name)) h, w = frame.shape[:2] new_w = (w*self.size//min(w, h)) new_h = (h*self.size//min(w, h)) if new_w != w or new_h != h: frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA) cv2.imwrite(path.join(self.image_dir, image_name), frame) print('Done!') def save_mask(self, ti, mask): # mask should be uint8 H*W without channels assert 0 <= ti < self.length assert isinstance(mask, np.ndarray) mask = Image.fromarray(mask) mask.putpalette(self.palette) mask.save(path.join(self.mask_dir, self.names[ti]+'.png')) self.invalidate(ti) def save_visualization(self, ti, image): # image should be uint8 3*H*W assert 0 <= ti < self.length assert isinstance(image, np.ndarray) if not self.visualization_init: self.visualization_dir = path.join(self.workspace, 'visualization') os.makedirs(self.visualization_dir, exist_ok=True) self.visualization_init = True image = Image.fromarray(image) image.save(path.join(self.visualization_dir, self.names[ti]+'.jpg')) def _get_image_unbuffered(self, ti): # returns H*W*3 uint8 array assert 0 <= ti < self.length image = Image.open(path.join(self.image_dir, self.names[ti]+'.jpg')) image = np.array(image) return image def _get_mask_unbuffered(self, ti): # returns H*W uint8 array assert 0 <= ti < self.length mask_path = path.join(self.mask_dir, self.names[ti]+'.png') if path.exists(mask_path): mask = Image.open(mask_path) mask = np.array(mask) return mask else: return None def read_external_image(self, file_name, size=None): image = Image.open(file_name) is_mask = image.mode in ['L', 'P'] if size is not None: # PIL uses (width, height) image = image.resize((size[1], size[0]), resample=Image.Resampling.NEAREST if is_mask else Image.Resampling.BICUBIC) image = np.array(image) return image def invalidate(self, ti): # the image buffer is never invalidated self.get_mask.invalidate((ti,)) def __len__(self): return self.length @property def h(self): return self.height @property def w(self): return self.width ================================================ FILE: XMem/inference/interact/s2m/__init__.py ================================================ ================================================ FILE: XMem/inference/interact/s2m/_deeplab.py 
================================================ # Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch import torch from torch import nn from torch.nn import functional as F from .utils import _SimpleSegmentationModel __all__ = ["DeepLabV3"] class DeepLabV3(_SimpleSegmentationModel): """ Implements DeepLabV3 model from `"Rethinking Atrous Convolution for Semantic Image Segmentation" `_. Arguments: backbone (nn.Module): the network used to compute the features for the model. The backbone should return an OrderedDict[Tensor], with the key being "out" for the last feature map used, and "aux" if an auxiliary classifier is used. classifier (nn.Module): module that takes the "out" element returned from the backbone and returns a dense prediction. aux_classifier (nn.Module, optional): auxiliary classifier used during training """ pass class DeepLabHeadV3Plus(nn.Module): def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]): super(DeepLabHeadV3Plus, self).__init__() self.project = nn.Sequential( nn.Conv2d(low_level_channels, 48, 1, bias=False), nn.BatchNorm2d(48), nn.ReLU(inplace=True), ) self.aspp = ASPP(in_channels, aspp_dilate) self.classifier = nn.Sequential( nn.Conv2d(304, 256, 3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Conv2d(256, num_classes, 1) ) self._init_weight() def forward(self, feature): low_level_feature = self.project( feature['low_level'] ) output_feature = self.aspp(feature['out']) output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False) return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) ) def _init_weight(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) class DeepLabHead(nn.Module): def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]): super(DeepLabHead, self).__init__() self.classifier = nn.Sequential( ASPP(in_channels, aspp_dilate), nn.Conv2d(256, 256, 3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Conv2d(256, num_classes, 1) ) self._init_weight() def forward(self, feature): return self.classifier( feature['out'] ) def _init_weight(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) class AtrousSeparableConvolution(nn.Module): """ Atrous Separable Convolution """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True): super(AtrousSeparableConvolution, self).__init__() self.body = nn.Sequential( # Separable Conv nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=in_channels ), # PointWise Conv nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), ) self._init_weight() def forward(self, x): return self.body(x) def _init_weight(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) class ASPPConv(nn.Sequential): def __init__(self, in_channels, out_channels, dilation): modules = [ nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, 
bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ] super(ASPPConv, self).__init__(*modules) class ASPPPooling(nn.Sequential): def __init__(self, in_channels, out_channels): super(ASPPPooling, self).__init__( nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)) def forward(self, x): size = x.shape[-2:] x = super(ASPPPooling, self).forward(x) return F.interpolate(x, size=size, mode='bilinear', align_corners=False) class ASPP(nn.Module): def __init__(self, in_channels, atrous_rates): super(ASPP, self).__init__() out_channels = 256 modules = [] modules.append(nn.Sequential( nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))) rate1, rate2, rate3 = tuple(atrous_rates) modules.append(ASPPConv(in_channels, out_channels, rate1)) modules.append(ASPPConv(in_channels, out_channels, rate2)) modules.append(ASPPConv(in_channels, out_channels, rate3)) modules.append(ASPPPooling(in_channels, out_channels)) self.convs = nn.ModuleList(modules) self.project = nn.Sequential( nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Dropout(0.1),) def forward(self, x): res = [] for conv in self.convs: res.append(conv(x)) res = torch.cat(res, dim=1) return self.project(res) def convert_to_separable_conv(module): new_module = module if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1: new_module = AtrousSeparableConvolution(module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.bias) for name, child in module.named_children(): new_module.add_module(name, convert_to_separable_conv(child)) return new_module ================================================ FILE: XMem/inference/interact/s2m/s2m_network.py ================================================ # Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch from .utils import IntermediateLayerGetter from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3 from . import s2m_resnet def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone): if output_stride==8: replace_stride_with_dilation=[False, True, True] aspp_dilate = [12, 24, 36] else: replace_stride_with_dilation=[False, False, True] aspp_dilate = [6, 12, 18] backbone = s2m_resnet.__dict__[backbone_name]( pretrained=pretrained_backbone, replace_stride_with_dilation=replace_stride_with_dilation) inplanes = 2048 low_level_planes = 256 if name=='deeplabv3plus': return_layers = {'layer4': 'out', 'layer1': 'low_level'} classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) elif name=='deeplabv3': return_layers = {'layer4': 'out'} classifier = DeepLabHead(inplanes , num_classes, aspp_dilate) backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) model = DeepLabV3(backbone, classifier) return model def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone): if backbone.startswith('resnet'): model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) else: raise NotImplementedError return model # Deeplab v3 def deeplabv3_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False): """Constructs a DeepLabV3 model with a ResNet-50 backbone. Args: num_classes (int): number of classes. output_stride (int): output stride for deeplab. 
pretrained_backbone (bool): If True, use the pretrained backbone. """ return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) # Deeplab v3+ def deeplabv3plus_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False): """Constructs a DeepLabV3 model with a ResNet-50 backbone. Args: num_classes (int): number of classes. output_stride (int): output stride for deeplab. pretrained_backbone (bool): If True, use the pretrained backbone. """ return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) ================================================ FILE: XMem/inference/interact/s2m/s2m_resnet.py ================================================ import torch import torch.nn as nn try: from torchvision.models.utils import load_state_dict_from_url except ModuleNotFoundError: from torch.utils.model_zoo import load_url as load_state_dict_from_url __all__ = ['ResNet', 'resnet50'] model_urls = { 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', } def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class ResNet(nn.Module): def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None): super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: # each element in the tuple indicates if we should replace # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError("replace_stride_with_dilation should be None " "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d(6, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = 
self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) def _make_layer(self, block, planes, blocks, stride=1, dilate=False): norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer)) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer)) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.fc(x) return x def _resnet(arch, block, layers, pretrained, progress, **kwargs): model = ResNet(block, layers, **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) return model def resnet50(pretrained=False, progress=True, **kwargs): r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) ================================================ FILE: XMem/inference/interact/s2m/utils.py ================================================ # Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch import torch import torch.nn as nn import numpy as np import torch.nn.functional as F from collections import OrderedDict class _SimpleSegmentationModel(nn.Module): def __init__(self, backbone, classifier): super(_SimpleSegmentationModel, self).__init__() self.backbone = backbone self.classifier = classifier def forward(self, x): input_shape = x.shape[-2:] features = self.backbone(x) x = self.classifier(features) x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) return x class IntermediateLayerGetter(nn.ModuleDict): """ Module wrapper that returns intermediate layers from a model It has a strong assumption that the modules have been registered into the model in the same 
order as they are used. This means that one should **not** reuse the same nn.Module twice in the forward if you want this to work. Additionally, it is only able to query submodules that are directly assigned to the model. So if `model` is passed, `model.feature1` can be returned, but not `model.feature1.layer2`. Arguments: model (nn.Module): model on which we will extract the features return_layers (Dict[name, new_name]): a dict containing the names of the modules for which the activations will be returned as the key of the dict, and the value of the dict is the name of the returned activation (which the user can specify). Examples:: >>> m = torchvision.models.resnet18(pretrained=True) >>> # extract layer1 and layer3, giving as names `feat1` and feat2` >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, >>> {'layer1': 'feat1', 'layer3': 'feat2'}) >>> out = new_m(torch.rand(1, 3, 224, 224)) >>> print([(k, v.shape) for k, v in out.items()]) >>> [('feat1', torch.Size([1, 64, 56, 56])), >>> ('feat2', torch.Size([1, 256, 14, 14]))] """ def __init__(self, model, return_layers): if not set(return_layers).issubset([name for name, _ in model.named_children()]): raise ValueError("return_layers are not present in model") orig_return_layers = return_layers return_layers = {k: v for k, v in return_layers.items()} layers = OrderedDict() for name, module in model.named_children(): layers[name] = module if name in return_layers: del return_layers[name] if not return_layers: break super(IntermediateLayerGetter, self).__init__(layers) self.return_layers = orig_return_layers def forward(self, x): out = OrderedDict() for name, module in self.named_children(): x = module(x) if name in self.return_layers: out_name = self.return_layers[name] out[out_name] = x return out ================================================ FILE: XMem/inference/interact/s2m_controller.py ================================================ import torch import numpy as np from ..interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M from util.tensor_util import pad_divide_by, unpad class S2MController: """ A controller for Scribble-to-Mask (for user interaction, not for DAVIS) Takes the image, previous mask, and scribbles to produce a new mask ignore_class is usually 255 0 is NOT the ignore class -- it is the label for the background """ def __init__(self, s2m_net:S2M, num_objects, ignore_class, device='cuda:0'): self.s2m_net = s2m_net self.num_objects = num_objects self.ignore_class = ignore_class self.device = device def interact(self, image, prev_mask, scr_mask): print(self.device) image = image.to(self.device, non_blocking=True) prev_mask = prev_mask.unsqueeze(0) h, w = image.shape[-2:] unaggre_mask = torch.zeros((self.num_objects, h, w), dtype=torch.float32, device=image.device) for ki in range(1, self.num_objects+1): p_srb = (scr_mask==ki).astype(np.uint8) n_srb = ((scr_mask!=ki) * (scr_mask!=self.ignore_class)).astype(np.uint8) Rs = torch.from_numpy(np.stack([p_srb, n_srb], 0)).unsqueeze(0).float().to(image.device) inputs = torch.cat([image, (prev_mask==ki).float().unsqueeze(0), Rs], 1) inputs, pads = pad_divide_by(inputs, 16) unaggre_mask[ki-1] = unpad(torch.sigmoid(self.s2m_net(inputs)), pads) return unaggre_mask ================================================ FILE: XMem/inference/interact/timer.py ================================================ import time class Timer: def __init__(self): self._acc_time = 0 self._paused = True def start(self): if self._paused: self.last_time = time.time() self._paused = 
False return self def pause(self): self.count() self._paused = True return self def count(self): if self._paused: return self._acc_time t = time.time() self._acc_time += t - self.last_time self.last_time = t return self._acc_time def format(self): # count = int(self.count()*100) # return '%02d:%02d:%02d' % (count//6000, (count//100)%60, count%100) return '%03.2f' % self.count() def __str__(self): return self.format() ================================================ FILE: XMem/inference/kv_memory_store.py ================================================ import torch from typing import List class KeyValueMemoryStore: """ Works for key/value pairs type storage e.g., working and long-term memory """ """ An object group is created when new objects enter the video Objects in the same group share the same temporal extent i.e., objects initialized in the same frame are in the same group For DAVIS/interactive, there is only one object group For YouTubeVOS, there can be multiple object groups """ def __init__(self, count_usage: bool): self.count_usage = count_usage # keys are stored in a single tensor and are shared between groups/objects # values are stored as a list indexed by object groups self.k = None self.v = [] self.obj_groups = [] # for debugging only self.all_objects = [] # shrinkage and selection are also single tensors self.s = self.e = None # usage if self.count_usage: self.use_count = self.life_count = None def add(self, key, value, shrinkage, selection, objects: List[int]): new_count = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32) new_life = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32) + 1e-7 # add the key if self.k is None: self.k = key self.s = shrinkage self.e = selection if self.count_usage: self.use_count = new_count self.life_count = new_life else: self.k = torch.cat([self.k, key], -1) if shrinkage is not None: self.s = torch.cat([self.s, shrinkage], -1) if selection is not None: self.e = torch.cat([self.e, selection], -1) if self.count_usage: self.use_count = torch.cat([self.use_count, new_count], -1) self.life_count = torch.cat([self.life_count, new_life], -1) # add the value if objects is not None: # When objects is given, v is a tensor; used in working memory assert isinstance(value, torch.Tensor) # First consume objects that are already in the memory bank # cannot use set here because we need to preserve order # shift by one as background is not part of value remaining_objects = [obj-1 for obj in objects] for gi, group in enumerate(self.obj_groups): for obj in group: # should properly raise an error if there are overlaps in obj_groups remaining_objects.remove(obj) self.v[gi] = torch.cat([self.v[gi], value[group]], -1) # If there are remaining objects, add them as a new group if len(remaining_objects) > 0: new_group = list(remaining_objects) self.v.append(value[new_group]) self.obj_groups.append(new_group) self.all_objects.extend(new_group) assert sorted(self.all_objects) == self.all_objects, 'Objects MUST be inserted in sorted order ' else: # When objects is not given, v is a list that already has the object groups sorted # used in long-term memory assert isinstance(value, list) for gi, gv in enumerate(value): if gv is None: continue if gi < self.num_groups: self.v[gi] = torch.cat([self.v[gi], gv], -1) else: self.v.append(gv) def update_usage(self, usage): # increase all life count by 1 # increase use of indexed elements if not self.count_usage: return self.use_count += 
usage.view_as(self.use_count) self.life_count += 1 def sieve_by_range(self, start: int, end: int, min_size: int): # keep only the elements *outside* of this range (with some boundary conditions) # i.e., concat (a[:start], a[end:]) # min_size is only used for values, we do not sieve values under this size # (because they are not consolidated) if end == 0: # negative 0 would not work as the end index! self.k = self.k[:,:,:start] if self.count_usage: self.use_count = self.use_count[:,:,:start] self.life_count = self.life_count[:,:,:start] if self.s is not None: self.s = self.s[:,:,:start] if self.e is not None: self.e = self.e[:,:,:start] for gi in range(self.num_groups): if self.v[gi].shape[-1] >= min_size: self.v[gi] = self.v[gi][:,:,:start] else: self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1) if self.count_usage: self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1) self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1) if self.s is not None: self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1) if self.e is not None: self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1) for gi in range(self.num_groups): if self.v[gi].shape[-1] >= min_size: self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1) def remove_obsolete_features(self, max_size: int): # normalize with life duration usage = self.get_usage().flatten() values, _ = torch.topk(usage, k=(self.size-max_size), largest=False, sorted=True) survived = (usage > values[-1]) self.k = self.k[:, :, survived] self.s = self.s[:, :, survived] if self.s is not None else None # Long-term memory does not store ek so this should not be needed self.e = self.e[:, :, survived] if self.e is not None else None if self.num_groups > 1: raise NotImplementedError("""The current data structure does not support feature removal with multiple object groups (e.g., some objects start to appear later in the video) The indices for "survived" is based on keys but not all values are present for every key Basically we need to remap the indices for keys to values """) for gi in range(self.num_groups): self.v[gi] = self.v[gi][:, :, survived] self.use_count = self.use_count[:, :, survived] self.life_count = self.life_count[:, :, survived] def get_usage(self): # return normalized usage if not self.count_usage: raise RuntimeError('I did not count usage!') else: usage = self.use_count / self.life_count return usage def get_all_sliced(self, start: int, end: int): # return k, sk, ek, usage in order, sliced by start and end if end == 0: # negative 0 would not work as the end index! 
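# (illustrative note) The end == 0 special case exists because Python slices
# interpret negative indices from the back and -0 is just 0, so a[start:-0]
# is always empty instead of meaning "up to the end":
#   a = list(range(5)); a[1:-0] == []  while  a[1:] == [1, 2, 3, 4]
# Hence "drop nothing from the back" must be written as a[start:].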
k = self.k[:,:,start:] sk = self.s[:,:,start:] if self.s is not None else None ek = self.e[:,:,start:] if self.e is not None else None usage = self.get_usage()[:,:,start:] else: k = self.k[:,:,start:end] sk = self.s[:,:,start:end] if self.s is not None else None ek = self.e[:,:,start:end] if self.e is not None else None usage = self.get_usage()[:,:,start:end] return k, sk, ek, usage def get_v_size(self, ni: int): return self.v[ni].shape[2] def engaged(self): return self.k is not None @property def size(self): if self.k is None: return 0 else: return self.k.shape[-1] @property def num_groups(self): return len(self.v) @property def key(self): return self.k @property def value(self): return self.v @property def shrinkage(self): return self.s @property def selection(self): return self.e ================================================ FILE: XMem/inference/memory_manager.py ================================================ import torch import warnings from inference.kv_memory_store import KeyValueMemoryStore from model.memory_util import * class MemoryManager: """ Manages all three memory stores and the transition between working/long-term memory """ def __init__(self, config): self.hidden_dim = config['hidden_dim'] self.top_k = config['top_k'] self.enable_long_term = config['enable_long_term'] self.enable_long_term_usage = config['enable_long_term_count_usage'] if self.enable_long_term: self.max_mt_frames = config['max_mid_term_frames'] self.min_mt_frames = config['min_mid_term_frames'] self.num_prototypes = config['num_prototypes'] self.max_long_elements = config['max_long_term_elements'] # dimensions will be inferred from input later self.CK = self.CV = None self.H = self.W = None # The hidden state will be stored in a single tensor for all objects # B x num_objects x CH x H x W self.hidden = None self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term) if self.enable_long_term: self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage) self.reset_config = True def update_config(self, config): self.reset_config = True self.hidden_dim = config['hidden_dim'] self.top_k = config['top_k'] assert self.enable_long_term == config['enable_long_term'], 'cannot update this' assert self.enable_long_term_usage == config['enable_long_term_count_usage'], 'cannot update this' self.enable_long_term_usage = config['enable_long_term_count_usage'] if self.enable_long_term: self.max_mt_frames = config['max_mid_term_frames'] self.min_mt_frames = config['min_mid_term_frames'] self.num_prototypes = config['num_prototypes'] self.max_long_elements = config['max_long_term_elements'] def _readout(self, affinity, v): # this function is for a single object group return v @ affinity def match_memory(self, query_key, selection): # query_key: B x C^k x H x W # selection: B x C^k x H x W num_groups = self.work_mem.num_groups h, w = query_key.shape[-2:] query_key = query_key.flatten(start_dim=2) selection = selection.flatten(start_dim=2) if selection is not None else None """ Memory readout using keys """ if self.enable_long_term and self.long_mem.engaged(): # Use long-term memory long_mem_size = self.long_mem.size memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1) shrinkage = torch.cat([self.long_mem.shrinkage, self.work_mem.shrinkage], -1) similarity = get_similarity(memory_key, shrinkage, query_key, selection) work_mem_similarity = similarity[:, long_mem_size:] long_mem_similarity = similarity[:, :long_mem_size] # get the usage with the first group # the first group always 
have all the keys valid affinity, usage = do_softmax( torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(0):], work_mem_similarity], 1), top_k=self.top_k, inplace=True, return_usage=True) affinity = [affinity] # compute affinity group by group as later groups only have a subset of keys for gi in range(1, num_groups): if gi < self.long_mem.num_groups: # merge working and lt similarities before softmax affinity_one_group = do_softmax( torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):], work_mem_similarity[:, -self.work_mem.get_v_size(gi):]], 1), top_k=self.top_k, inplace=True) else: # no long-term memory for this group affinity_one_group = do_softmax(work_mem_similarity[:, -self.work_mem.get_v_size(gi):], top_k=self.top_k, inplace=(gi==num_groups-1)) affinity.append(affinity_one_group) all_memory_value = [] for gi, gv in enumerate(self.work_mem.value): # merge the working and lt values before readout if gi < self.long_mem.num_groups: all_memory_value.append(torch.cat([self.long_mem.value[gi], self.work_mem.value[gi]], -1)) else: all_memory_value.append(gv) """ Record memory usage for working and long-term memory """ # ignore the index return for long-term memory work_usage = usage[:, long_mem_size:] self.work_mem.update_usage(work_usage.flatten()) if self.enable_long_term_usage: # ignore the index return for working memory long_usage = usage[:, :long_mem_size] self.long_mem.update_usage(long_usage.flatten()) else: # No long-term memory similarity = get_similarity(self.work_mem.key, self.work_mem.shrinkage, query_key, selection) if self.enable_long_term: affinity, usage = do_softmax(similarity, inplace=(num_groups==1), top_k=self.top_k, return_usage=True) # Record memory usage for working memory self.work_mem.update_usage(usage.flatten()) else: affinity = do_softmax(similarity, inplace=(num_groups==1), top_k=self.top_k, return_usage=False) affinity = [affinity] # compute affinity group by group as later groups only have a subset of keys for gi in range(1, num_groups): affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):], top_k=self.top_k, inplace=(gi==num_groups-1)) affinity.append(affinity_one_group) all_memory_value = self.work_mem.value # Shared affinity within each group all_readout_mem = torch.cat([ self._readout(affinity[gi], gv) for gi, gv in enumerate(all_memory_value) ], 0) return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w) def add_memory(self, key, shrinkage, value, objects, selection=None): # key: 1*C*H*W # value: 1*num_objects*C*H*W # objects contain a list of object indices if self.H is None or self.reset_config: self.reset_config = False self.H, self.W = key.shape[-2:] self.HW = self.H*self.W if self.enable_long_term: # convert from num. frames to num. 
nodes self.min_work_elements = self.min_mt_frames*self.HW self.max_work_elements = self.max_mt_frames*self.HW # key: 1*C*N # value: num_objects*C*N key = key.flatten(start_dim=2) shrinkage = shrinkage.flatten(start_dim=2) value = value[0].flatten(start_dim=2) self.CK = key.shape[1] self.CV = value.shape[1] if selection is not None: if not self.enable_long_term: warnings.warn('the selection factor is only needed in long-term mode', UserWarning) selection = selection.flatten(start_dim=2) self.work_mem.add(key, value, shrinkage, selection, objects) # long-term memory cleanup if self.enable_long_term: # Do memory compressed if needed if self.work_mem.size >= self.max_work_elements: # Remove obsolete features if needed if self.long_mem.size >= (self.max_long_elements-self.num_prototypes): self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes) self.compress_features() def create_hidden_state(self, n, sample_key): # n is the TOTAL number of objects h, w = sample_key.shape[-2:] if self.hidden is None: self.hidden = torch.zeros((1, n, self.hidden_dim, h, w), device=sample_key.device) elif self.hidden.shape[1] != n: self.hidden = torch.cat([ self.hidden, torch.zeros((1, n-self.hidden.shape[1], self.hidden_dim, h, w), device=sample_key.device) ], 1) assert(self.hidden.shape[1] == n) def set_hidden(self, hidden): self.hidden = hidden def get_hidden(self): return self.hidden def compress_features(self): HW = self.HW candidate_value = [] total_work_mem_size = self.work_mem.size for gv in self.work_mem.value: # Some object groups might be added later in the video # So not all keys have values associated with all objects # We need to keep track of the key->value validity mem_size_in_this_group = gv.shape[-1] if mem_size_in_this_group == total_work_mem_size: # full LT candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW]) else: # mem_size is smaller than total_work_mem_size, but at least HW assert HW <= mem_size_in_this_group < total_work_mem_size if mem_size_in_this_group > self.min_work_elements+HW: # part of this object group still goes into LT candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW]) else: # this object group cannot go to the LT at all candidate_value.append(None) # perform memory consolidation prototype_key, prototype_value, prototype_shrinkage = self.consolidation( *self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value) # remove consolidated working memory self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW) # add to long-term memory self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None) def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value): # keys: 1*C*N # values: num_objects*C*N N = candidate_key.shape[-1] # find the indices with max usage _, max_usage_indices = torch.topk(usage, k=self.num_prototypes, dim=-1, sorted=True) prototype_indices = max_usage_indices.flatten() # Prototypes are invalid for out-of-bound groups validity = [prototype_indices >= (N-gv.shape[2]) if gv is not None else None for gv in candidate_value] prototype_key = candidate_key[:, :, prototype_indices] prototype_selection = candidate_selection[:, :, prototype_indices] if candidate_selection is not None else None """ Potentiation step """ similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key, prototype_selection) # convert similarity to affinity # need to do it group by group since the 
softmax normalization would be different affinity = [ do_softmax(similarity[:, -gv.shape[2]:, validity[gi]]) if gv is not None else None for gi, gv in enumerate(candidate_value) ] # some values can be have all False validity. Weed them out. affinity = [ aff if aff is None or aff.shape[-1] > 0 else None for aff in affinity ] # readout the values prototype_value = [ self._readout(affinity[gi], gv) if affinity[gi] is not None else None for gi, gv in enumerate(candidate_value) ] # readout the shrinkage term prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None return prototype_key, prototype_value, prototype_shrinkage ================================================ FILE: XMem/interactive_demo.py ================================================ """ A simple user interface for XMem """ import os from os import path # fix for Windows if 'QT_QPA_PLATFORM_PLUGIN_PATH' not in os.environ: os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = '' import signal signal.signal(signal.SIGINT, signal.SIG_DFL) import sys from argparse import ArgumentParser import torch from model.network import XMem from inference.interact.s2m_controller import S2MController from inference.interact.fbrs_controller import FBRSController from inference.interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M from PySide6.QtWidgets import QApplication from inference.interact.gui import App from inference.interact.resource_manager import ResourceManager from contextlib import nullcontext torch.set_grad_enabled(False) if torch.cuda.is_available(): device = torch.device("cuda") elif torch.backends.mps.is_available(): device = torch.device("mps") else: device = torch.device("cpu") if __name__ == '__main__': # Arguments parsing parser = ArgumentParser() parser.add_argument('--model', default='./saves/XMem.pth') parser.add_argument('--s2m_model', default='saves/s2m.pth') parser.add_argument('--fbrs_model', default='saves/fbrs.pth') """ Priority 1: If a "images" folder exists in the workspace, we will read from that directory Priority 2: If --images is specified, we will copy/resize those images to the workspace Priority 3: If --video is specified, we will extract the frames to the workspace (in an "images" folder) and read from there In any case, if a "masks" folder exists in the workspace, we will use that to initialize the mask That way, you can continue annotation from an interrupted run as long as the same workspace is used. """ parser.add_argument('--images', help='Folders containing input images.', default=None) parser.add_argument('--video', help='Video file readable by OpenCV.', default=None) parser.add_argument('--workspace', help='directory for storing buffered images (if needed) and output masks', default=None) parser.add_argument('--buffer_size', help='Correlate with CPU memory consumption', type=int, default=100) parser.add_argument('--num_objects', type=int, default=1) # Long-memory options # Defaults. Some can be changed in the GUI. 
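    # ------------------------------------------------------------------
    # Editor's note (not part of the original file): a quick sketch of how
    # the options below feed MemoryManager.add_memory above. Keys live at
    # stride 16, so for an 854x480 input the key map is roughly 54x30 (the
    # exact shape depends on the input resolution and is an assumption here):
    #
    #     HW = 54 * 30                  # ~1620 memory elements per frame
    #     max_work_elements = 10 * HW   # T_max frames of working memory
    #     min_work_elements = 5 * HW    # frames kept after compression (T_min)
    #
    # Once the working memory exceeds max_work_elements, everything except
    # the first frame and the most recent frames is consolidated into
    # num_prototypes (P, default 128 below) long-term entries, and the
    # long-term store is capped at max_long_term_elements (LT_max).
    # ------------------------------------------------------------------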
parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10) parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5) parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time', type=int, default=10000) parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128) parser.add_argument('--top_k', type=int, default=30) parser.add_argument('--mem_every', type=int, default=10) parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1) parser.add_argument('--no_amp', help='Turn off AMP', action='store_true') parser.add_argument('--size', default=480, type=int, help='Resize the shorter side to this size. -1 to use original resolution. ') args = parser.parse_args() # create temporary workspace if not specified config = vars(args) config['enable_long_term'] = True config['enable_long_term_count_usage'] = True if config["workspace"] is None: if config["images"] is not None: basename = path.basename(config["images"]) elif config["video"] is not None: basename = path.basename(config["video"])[:-4] else: raise NotImplementedError( 'Either images, video, or workspace has to be specified') config["workspace"] = path.join('./workspace', basename) with torch.cuda.amp.autocast(enabled=not args.no_amp) if device.type == 'cuda' else nullcontext(): # Load our checkpoint network = XMem(config, args.model, map_location=device).to(device).eval() # Loads the S2M model if args.s2m_model is not None: s2m_saved = torch.load(args.s2m_model, map_location=device) s2m_model = S2M().to(device).eval() s2m_model.load_state_dict(s2m_saved) else: s2m_model = None s2m_controller = S2MController(s2m_model, args.num_objects, ignore_class=255, device=device) if args.fbrs_model is not None: fbrs_controller = FBRSController(args.fbrs_model, device=device) else: fbrs_controller = None # Manages most IO resource_manager = ResourceManager(config) app = QApplication(sys.argv) ex = App(network, resource_manager, s2m_controller, fbrs_controller, config, device) sys.exit(app.exec()) ================================================ FILE: XMem/merge_multi_scale.py ================================================ import os from os import path from argparse import ArgumentParser import glob from collections import defaultdict import numpy as np import hickle as hkl from PIL import Image, ImagePalette from progressbar import progressbar from multiprocessing import Pool from util import palette from util.palette import davis_palette, youtube_palette import shutil def search_options(options, name): for option in options: if path.exists(path.join(option, name)): return path.join(option, name) else: return None def process_vid(vid): vid_path = search_options(all_options, vid) if vid_path is not None: backward_mapping = hkl.load(path.join(vid_path, 'backward.hkl')) else: backward_mapping = None frames = os.listdir(path.join(all_options[0], vid)) frames = [f for f in frames if 'backward' not in f] print(vid) if 'Y' in args.dataset: this_out_path = path.join(out_path, 'Annotations', vid) else: this_out_path = path.join(out_path, vid) os.makedirs(this_out_path, exist_ok=True) for f in frames: result_sum = None for option in all_options: if not path.exists(path.join(option, vid, f)): continue result = hkl.load(path.join(option, vid, f)) if result_sum is None: result_sum = result.astype(np.float32) else: 
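                # Editor's note (not part of the original file): result_sum
                # accumulates the raw per-class score maps saved by each
                # multi-scale run; the argmax below then picks, per pixel, the
                # class with the highest summed score before the indices are
                # remapped to the original label domain via backward.hkl.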
                result_sum += result

        # argmax and to idx
        result_sum = np.argmax(result_sum, axis=0)

        # Remap the indices to the original domain
        if backward_mapping is not None:
            idx_mask = np.zeros_like(result_sum, dtype=np.uint8)
            for l, i in backward_mapping.items():
                idx_mask[result_sum==i] = l
        else:
            idx_mask = result_sum.astype(np.uint8)

        # Save the results
        img_E = Image.fromarray(idx_mask)
        img_E.putpalette(palette)
        img_E.save(path.join(this_out_path, f[:-4]+'.png'))


if __name__ == '__main__':
    """
    Arguments loading
    """
    parser = ArgumentParser()
    parser.add_argument('--dataset', default='Y', help='D/Y, D for DAVIS; Y for YouTubeVOS')
    parser.add_argument('--list', nargs="+")
    parser.add_argument('--pattern', default=None, help='Glob pattern. Can be used in place of list.')
    parser.add_argument('--output')
    parser.add_argument('--num_proc', default=4, type=int)
    args = parser.parse_args()

    out_path = args.output

    # Find the input candidates
    if args.pattern is None:
        all_options = args.list
    else:
        assert args.list is None, 'cannot specify both list and pattern'
        all_options = glob.glob(args.pattern)

    # Get the correct palette
    if 'D' in args.dataset:
        palette = ImagePalette.ImagePalette(mode='P', palette=davis_palette)
    elif 'Y' in args.dataset:
        palette = ImagePalette.ImagePalette(mode='P', palette=youtube_palette)
    else:
        raise NotImplementedError

    # Count the number of videos in each candidate
    all_options = [path.join(o, 'Scores') for o in all_options]
    vid_count = defaultdict(int)
    for option in all_options:
        vid_in_here = sorted(os.listdir(option))
        for vid in vid_in_here:
            vid_count[vid] += 1

    all_vid = []
    count_to_vid = defaultdict(int)
    for k, v in vid_count.items():
        count_to_vid[v] += 1
        all_vid.append(k)

    for k, v in count_to_vid.items():
        print('Videos with count %d: %d' % (k, v))

    all_vid = sorted(all_vid)
    print('Total number of videos: ', len(all_vid))

    pool = Pool(processes=args.num_proc)
    for _ in progressbar(pool.imap_unordered(process_vid, all_vid), max_value=len(all_vid)):
        pass
    pool.close()
    pool.join()

    if 'D' in args.dataset:
        print('Making zip for DAVIS test-dev...')
        shutil.make_archive(args.output, 'zip', args.output)

    if 'Y' in args.dataset:
        print('Making zip for YouTubeVOS...')
        shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')


================================================
FILE: XMem/merge_results.py
================================================
import glob
import os
from PIL import Image
import numpy as np
import tqdm
import multiprocessing

multi_dir = "mevis_val/vis_output/"
outdir = "mevis_val_merge/vis_output/"

all_obj_list = []
video_list = glob.glob(os.path.join(multi_dir, "0/*"))
for video in video_list:
    obj_list = glob.glob(video + "/*")
    all_obj_list = all_obj_list + ['/'.join(i.split('/')[-2:]) for i in obj_list]

def merge(obj):
    obj_output_dir = os.path.join(outdir, obj)
    os.makedirs(obj_output_dir, exist_ok=True)
    img_list = [i.split('/')[-1] for i in glob.glob(os.path.join(multi_dir, "0", obj, "*.png"))]
    for img_name in img_list:
        agg_img = None
        # majority vote over the 7 runs' binary masks (a pixel is kept
        # if at least 4 of the 7 runs predict it as foreground)
        for i in range(7):
            img_path = os.path.join(multi_dir, str(i), obj, img_name)
            tmp_img = (np.array(Image.open(img_path)) > 0).astype(np.uint8)
            if agg_img is not None:
                agg_img = agg_img + tmp_img
            else:
                agg_img = tmp_img
        agg_img = (agg_img >= 4).astype(np.uint8)
        agg_img = Image.fromarray(agg_img)
        img_output_path = os.path.join(obj_output_dir, img_name)
        agg_img.save(img_output_path)

print('start')
cpu_num = multiprocessing.cpu_count()-1
print("cpu_num:", cpu_num)
pool = multiprocessing.Pool(cpu_num)
for obj in
all_obj_list: pool.apply_async(merge, args=(obj,)) pool.close() pool.join() ================================================ FILE: XMem/model/__init__.py ================================================ ================================================ FILE: XMem/model/aggregate.py ================================================ import torch import torch.nn.functional as F # Soft aggregation from STM def aggregate(prob, dim, return_logits=False): new_prob = torch.cat([ torch.prod(1-prob, dim=dim, keepdim=True), prob ], dim).clamp(1e-7, 1-1e-7) logits = torch.log((new_prob /(1-new_prob))) prob = F.softmax(logits, dim=dim) if return_logits: return logits, prob else: return prob ================================================ FILE: XMem/model/cbam.py ================================================ # Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py import torch import torch.nn as nn import torch.nn.functional as F class BasicConv(nn.Module): def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super(BasicConv, self).__init__() self.out_channels = out_planes self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) def forward(self, x): x = self.conv(x) return x class Flatten(nn.Module): def forward(self, x): return x.view(x.size(0), -1) class ChannelGate(nn.Module): def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): super(ChannelGate, self).__init__() self.gate_channels = gate_channels self.mlp = nn.Sequential( Flatten(), nn.Linear(gate_channels, gate_channels // reduction_ratio), nn.ReLU(), nn.Linear(gate_channels // reduction_ratio, gate_channels) ) self.pool_types = pool_types def forward(self, x): channel_att_sum = None for pool_type in self.pool_types: if pool_type=='avg': avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( avg_pool ) elif pool_type=='max': max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( max_pool ) if channel_att_sum is None: channel_att_sum = channel_att_raw else: channel_att_sum = channel_att_sum + channel_att_raw scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) return x * scale class ChannelPool(nn.Module): def forward(self, x): return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) class SpatialGate(nn.Module): def __init__(self): super(SpatialGate, self).__init__() kernel_size = 7 self.compress = ChannelPool() self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2) def forward(self, x): x_compress = self.compress(x) x_out = self.spatial(x_compress) scale = torch.sigmoid(x_out) # broadcasting return x * scale class CBAM(nn.Module): def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): super(CBAM, self).__init__() self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) self.no_spatial=no_spatial if not no_spatial: self.SpatialGate = SpatialGate() def forward(self, x): x_out = self.ChannelGate(x) if not self.no_spatial: x_out = self.SpatialGate(x_out) return x_out ================================================ FILE: XMem/model/group_modules.py ================================================ """ Group-specific modules They handle features that also depends on the mask. 
Features are typically of shape batch_size * num_objects * num_channels * H * W All of them are permutation equivariant w.r.t. to the num_objects dimension """ import torch import torch.nn as nn import torch.nn.functional as F def interpolate_groups(g, ratio, mode, align_corners): batch_size, num_objects = g.shape[:2] g = F.interpolate(g.flatten(start_dim=0, end_dim=1), scale_factor=ratio, mode=mode, align_corners=align_corners) g = g.view(batch_size, num_objects, *g.shape[1:]) return g def upsample_groups(g, ratio=2, mode='bilinear', align_corners=False): return interpolate_groups(g, ratio, mode, align_corners) def downsample_groups(g, ratio=1/2, mode='area', align_corners=None): return interpolate_groups(g, ratio, mode, align_corners) class GConv2D(nn.Conv2d): def forward(self, g): batch_size, num_objects = g.shape[:2] g = super().forward(g.flatten(start_dim=0, end_dim=1)) return g.view(batch_size, num_objects, *g.shape[1:]) class GroupResBlock(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() if in_dim == out_dim: self.downsample = None else: self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1) self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1) self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1) def forward(self, g): out_g = self.conv1(F.relu(g)) out_g = self.conv2(F.relu(out_g)) if self.downsample is not None: g = self.downsample(g) return out_g + g class MainToGroupDistributor(nn.Module): def __init__(self, x_transform=None, method='cat', reverse_order=False): super().__init__() self.x_transform = x_transform self.method = method self.reverse_order = reverse_order def forward(self, x, g): num_objects = g.shape[1] if self.x_transform is not None: x = self.x_transform(x) if self.method == 'cat': if self.reverse_order: g = torch.cat([g, x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1)], 2) else: g = torch.cat([x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1), g], 2) elif self.method == 'add': g = x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1) + g else: raise NotImplementedError return g ================================================ FILE: XMem/model/losses.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from collections import defaultdict def dice_loss(input_mask, cls_gt): num_objects = input_mask.shape[1] losses = [] for i in range(num_objects): mask = input_mask[:,i].flatten(start_dim=1) # background not in mask, so we add one to cls_gt gt = (cls_gt==(i+1)).float().flatten(start_dim=1) numerator = 2 * (mask * gt).sum(-1) denominator = mask.sum(-1) + gt.sum(-1) loss = 1 - (numerator + 1) / (denominator + 1) losses.append(loss) return torch.cat(losses).mean() # https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch class BootstrappedCE(nn.Module): def __init__(self, start_warm, end_warm, top_p=0.15): super().__init__() self.start_warm = start_warm self.end_warm = end_warm self.top_p = top_p def forward(self, input, target, it): if it < self.start_warm: return F.cross_entropy(input, target), 1.0 raw_loss = F.cross_entropy(input, target, reduction='none').view(-1) num_pixels = raw_loss.numel() if it > self.end_warm: this_p = self.top_p else: this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm)) loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False) return loss.mean(), this_p class LossComputer: def __init__(self, config): super().__init__() self.config = config 
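        # Editor's note (not part of the original file): BootstrappedCE
        # anneals the kept-pixel fraction this_p from 1.0 down to top_p
        # between start_warm and end_warm iterations. A worked example,
        # assuming start_warm=20000, end_warm=70000, top_p=0.15 (assumed
        # values; the real ones come from the training config):
        #     it = 10000 -> plain cross-entropy (this_p effectively 1.0)
        #     it = 45000 -> this_p = 0.15 + 0.85*(70000-45000)/50000 = 0.575
        #     it = 90000 -> this_p = 0.15 (only the hardest 15% of pixels)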
self.bce = BootstrappedCE(config['start_warm'], config['end_warm']) def compute(self, data, num_objects, it): losses = defaultdict(int) b, t = data['rgb'].shape[:2] losses['total_loss'] = 0 for ti in range(1, t): for bi in range(b): loss, p = self.bce(data[f'logits_{ti}'][bi:bi+1, :num_objects[bi]+1], data['cls_gt'][bi:bi+1,ti,0], it) losses['p'] += p / b / (t-1) losses[f'ce_loss_{ti}'] += loss / b losses['total_loss'] += losses['ce_loss_%d'%ti] losses[f'dice_loss_{ti}'] = dice_loss(data[f'masks_{ti}'], data['cls_gt'][:,ti,0]) losses['total_loss'] += losses[f'dice_loss_{ti}'] return losses ================================================ FILE: XMem/model/memory_util.py ================================================ import math import numpy as np import torch from typing import Optional def get_similarity(mk, ms, qk, qe): # used for training/inference and memory reading/memory potentiation # mk: B x CK x [N] - Memory keys # ms: B x 1 x [N] - Memory shrinkage # qk: B x CK x [HW/P] - Query keys # qe: B x CK x [HW/P] - Query selection # Dimensions in [] are flattened CK = mk.shape[1] mk = mk.flatten(start_dim=2) ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None qk = qk.flatten(start_dim=2) qe = qe.flatten(start_dim=2) if qe is not None else None if qe is not None: # See appendix for derivation # or you can just trust me ヽ(ー_ー )ノ mk = mk.transpose(1, 2) a_sq = (mk.pow(2) @ qe) two_ab = 2 * (mk @ (qk * qe)) b_sq = (qe * qk.pow(2)).sum(1, keepdim=True) similarity = (-a_sq+two_ab-b_sq) else: # similar to STCN if we don't have the selection term a_sq = mk.pow(2).sum(1).unsqueeze(2) two_ab = 2 * (mk.transpose(1, 2) @ qk) similarity = (-a_sq+two_ab) if ms is not None: similarity = similarity * ms / math.sqrt(CK) # B*N*HW else: similarity = similarity / math.sqrt(CK) # B*N*HW return similarity def do_softmax(similarity, top_k: Optional[int]=None, inplace=False, return_usage=False): # normalize similarity with top-k softmax # similarity: B x N x [HW/P] # use inplace with care if top_k is not None: values, indices = torch.topk(similarity, k=top_k, dim=1) x_exp = values.exp_() x_exp /= torch.sum(x_exp, dim=1, keepdim=True) if inplace: similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW affinity = similarity else: affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW else: maxes = torch.max(similarity, dim=1, keepdim=True)[0] x_exp = torch.exp(similarity - maxes) x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) affinity = x_exp / x_exp_sum indices = None if return_usage: return affinity, affinity.sum(dim=2) return affinity def get_affinity(mk, ms, qk, qe): # shorthand used in training with no top-k similarity = get_similarity(mk, ms, qk, qe) affinity = do_softmax(similarity) return affinity def readout(affinity, mv): B, CV, T, H, W = mv.shape mo = mv.view(B, CV, T*H*W) mem = torch.bmm(mo, affinity) mem = mem.view(B, CV, H, W) return mem ================================================ FILE: XMem/model/modules.py ================================================ """ modules.py - This file stores the rather boring network blocks. x - usually means features that only depends on the image g - usually means features that also depends on the mask. 
They might have an extra "group" or "num_objects" dimension, hence batch_size * num_objects * num_channels * H * W The trailing number of a variable usually denote the stride """ import torch import torch.nn as nn import torch.nn.functional as F from model.group_modules import * from model import resnet from model.cbam import CBAM class FeatureFusionBlock(nn.Module): def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim): super().__init__() self.distributor = MainToGroupDistributor() self.block1 = GroupResBlock(x_in_dim+g_in_dim, g_mid_dim) self.attention = CBAM(g_mid_dim) self.block2 = GroupResBlock(g_mid_dim, g_out_dim) def forward(self, x, g): batch_size, num_objects = g.shape[:2] g = self.distributor(x, g) g = self.block1(g) r = self.attention(g.flatten(start_dim=0, end_dim=1)) r = r.view(batch_size, num_objects, *r.shape[1:]) g = self.block2(g+r) return g class HiddenUpdater(nn.Module): # Used in the decoder, multi-scale feature + GRU def __init__(self, g_dims, mid_dim, hidden_dim): super().__init__() self.hidden_dim = hidden_dim self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1) self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1) self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1) self.transform = GConv2D(mid_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1) nn.init.xavier_normal_(self.transform.weight) def forward(self, g, h): g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ self.g4_conv(downsample_groups(g[2], ratio=1/4)) g = torch.cat([g, h], 2) # defined slightly differently than standard GRU, # namely the new value is generated before the forget gate. # might provide better gradient but frankly it was initially just an # implementation error that I never bothered fixing values = self.transform(g) forget_gate = torch.sigmoid(values[:,:,:self.hidden_dim]) update_gate = torch.sigmoid(values[:,:,self.hidden_dim:self.hidden_dim*2]) new_value = torch.tanh(values[:,:,self.hidden_dim*2:]) new_h = forget_gate*h*(1-update_gate) + update_gate*new_value return new_h class HiddenReinforcer(nn.Module): # Used in the value encoder, a single GRU def __init__(self, g_dim, hidden_dim): super().__init__() self.hidden_dim = hidden_dim self.transform = GConv2D(g_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1) nn.init.xavier_normal_(self.transform.weight) def forward(self, g, h): g = torch.cat([g, h], 2) # defined slightly differently than standard GRU, # namely the new value is generated before the forget gate. 
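        # Editor's note (not part of the original file): concretely, the
        # update computed below is
        #     new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
        # whereas a textbook GRU would compute (1 - z) * h + z * h_tilde;
        # here the retained part of h is additionally scaled by the forget gate.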
# might provide better gradient but frankly it was initially just an # implementation error that I never bothered fixing values = self.transform(g) forget_gate = torch.sigmoid(values[:,:,:self.hidden_dim]) update_gate = torch.sigmoid(values[:,:,self.hidden_dim:self.hidden_dim*2]) new_value = torch.tanh(values[:,:,self.hidden_dim*2:]) new_h = forget_gate*h*(1-update_gate) + update_gate*new_value return new_h class ValueEncoder(nn.Module): def __init__(self, value_dim, hidden_dim, single_object=False): super().__init__() self.single_object = single_object network = resnet.resnet18(pretrained=True, extra_dim=1 if single_object else 2) self.conv1 = network.conv1 self.bn1 = network.bn1 self.relu = network.relu # 1/2, 64 self.maxpool = network.maxpool self.layer1 = network.layer1 # 1/4, 64 self.layer2 = network.layer2 # 1/8, 128 self.layer3 = network.layer3 # 1/16, 256 self.distributor = MainToGroupDistributor() self.fuser = FeatureFusionBlock(1024, 256, value_dim, value_dim) if hidden_dim > 0: self.hidden_reinforce = HiddenReinforcer(value_dim, hidden_dim) else: self.hidden_reinforce = None def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True): # image_feat_f16 is the feature from the key encoder if not self.single_object: g = torch.stack([masks, others], 2) else: g = masks.unsqueeze(2) g = self.distributor(image, g) batch_size, num_objects = g.shape[:2] g = g.flatten(start_dim=0, end_dim=1) g = self.conv1(g) g = self.bn1(g) # 1/2, 64 g = self.maxpool(g) # 1/4, 64 g = self.relu(g) g = self.layer1(g) # 1/4 g = self.layer2(g) # 1/8 g = self.layer3(g) # 1/16 g = g.view(batch_size, num_objects, *g.shape[1:]) g = self.fuser(image_feat_f16, g) if is_deep_update and self.hidden_reinforce is not None: h = self.hidden_reinforce(g, h) return g, h class KeyEncoder(nn.Module): def __init__(self): super().__init__() network = resnet.resnet50(pretrained=True) self.conv1 = network.conv1 self.bn1 = network.bn1 self.relu = network.relu # 1/2, 64 self.maxpool = network.maxpool self.res2 = network.layer1 # 1/4, 256 self.layer2 = network.layer2 # 1/8, 512 self.layer3 = network.layer3 # 1/16, 1024 def forward(self, f): x = self.conv1(f) x = self.bn1(x) x = self.relu(x) # 1/2, 64 x = self.maxpool(x) # 1/4, 64 f4 = self.res2(x) # 1/4, 256 f8 = self.layer2(f4) # 1/8, 512 f16 = self.layer3(f8) # 1/16, 1024 return f16, f8, f4 class UpsampleBlock(nn.Module): def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2): super().__init__() self.skip_conv = nn.Conv2d(skip_dim, g_up_dim, kernel_size=3, padding=1) self.distributor = MainToGroupDistributor(method='add') self.out_conv = GroupResBlock(g_up_dim, g_out_dim) self.scale_factor = scale_factor def forward(self, skip_f, up_g): skip_f = self.skip_conv(skip_f) g = upsample_groups(up_g, ratio=self.scale_factor) g = self.distributor(skip_f, g) g = self.out_conv(g) return g class KeyProjection(nn.Module): def __init__(self, in_dim, keydim): super().__init__() self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1) # shrinkage self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1) # selection self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1) nn.init.orthogonal_(self.key_proj.weight.data) nn.init.zeros_(self.key_proj.bias.data) def forward(self, x, need_s, need_e): shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None return self.key_proj(x), shrinkage, selection class Decoder(nn.Module): def __init__(self, val_dim, hidden_dim): super().__init__() 
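        # Editor's note (not part of the original file): the decoder fuses
        # the stride-16 image feature with the memory readout (plus the
        # hidden state when hidden_dim > 0), upsamples 1/16 -> 1/8 -> 1/4
        # with skip connections, predicts one logit map per object, and
        # forward() below finishes with a 4x bilinear upsample back to the
        # input resolution.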
self.fuser = FeatureFusionBlock(1024, val_dim+hidden_dim, 512, 512) if hidden_dim > 0: self.hidden_update = HiddenUpdater([512, 256, 256+1], 256, hidden_dim) else: self.hidden_update = None self.up_16_8 = UpsampleBlock(512, 512, 256) # 1/16 -> 1/8 self.up_8_4 = UpsampleBlock(256, 256, 256) # 1/8 -> 1/4 self.pred = nn.Conv2d(256, 1, kernel_size=3, padding=1, stride=1) def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True): batch_size, num_objects = memory_readout.shape[:2] if self.hidden_update is not None: g16 = self.fuser(f16, torch.cat([memory_readout, hidden_state], 2)) else: g16 = self.fuser(f16, memory_readout) g8 = self.up_16_8(f8, g16) g4 = self.up_8_4(f4, g8) logits = self.pred(F.relu(g4.flatten(start_dim=0, end_dim=1))) if h_out and self.hidden_update is not None: g4 = torch.cat([g4, logits.view(batch_size, num_objects, 1, *logits.shape[-2:])], 2) hidden_state = self.hidden_update([g16, g8, g4], hidden_state) else: hidden_state = None logits = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False) logits = logits.view(batch_size, num_objects, *logits.shape[-2:]) return hidden_state, logits ================================================ FILE: XMem/model/network.py ================================================ """ This file defines XMem, the highest level nn.Module interface During training, it is used by trainer.py During evaluation, it is used by inference_core.py It further depends on modules.py which gives more detailed implementations of sub-modules """ import torch import torch.nn as nn from loguru import logger from model.aggregate import aggregate from model.modules import * from model.memory_util import * class XMem(nn.Module): def __init__(self, config, model_path=None, map_location=None): """ model_path/map_location are used in evaluation only map_location is for converting models saved in cuda to cpu """ super().__init__() model_weights = self.init_hyperparameters(config, model_path, map_location) self.single_object = config.get('single_object', False) print(f'Single object mode: {self.single_object}') self.key_encoder = KeyEncoder() self.value_encoder = ValueEncoder(self.value_dim, self.hidden_dim, self.single_object) # Projection from f16 feature space to key/value space self.key_proj = KeyProjection(1024, self.key_dim) self.decoder = Decoder(self.value_dim, self.hidden_dim) if model_weights is not None: self.load_weights(model_weights, init_as_zero_if_needed=True) def encode_key(self, frame, need_sk=True, need_ek=True): # Determine input shape if len(frame.shape) == 5: # shape is b*t*c*h*w need_reshape = True b, t = frame.shape[:2] # flatten so that we can feed them into a 2D CNN frame = frame.flatten(start_dim=0, end_dim=1) elif len(frame.shape) == 4: # shape is b*c*h*w need_reshape = False else: raise NotImplementedError f16, f8, f4 = self.key_encoder(frame) key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek) if need_reshape: # B*C*T*H*W key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous() if shrinkage is not None: shrinkage = shrinkage.view(b, t, *shrinkage.shape[-3:]).transpose(1, 2).contiguous() if selection is not None: selection = selection.view(b, t, *selection.shape[-3:]).transpose(1, 2).contiguous() # B*T*C*H*W f16 = f16.view(b, t, *f16.shape[-3:]) f8 = f8.view(b, t, *f8.shape[-3:]) f4 = f4.view(b, t, *f4.shape[-3:]) return key, shrinkage, selection, f16, f8, f4 def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True): num_objects = masks.shape[1] if num_objects != 
1: others = torch.cat([ torch.sum( masks[:, [j for j in range(num_objects) if i!=j]] , dim=1, keepdim=True) for i in range(num_objects)], 1) else: others = torch.zeros_like(masks) g16, h16 = self.value_encoder(frame, image_feat_f16, h16, masks, others, is_deep_update) return g16, h16 # Used in training only. # This step is replaced by MemoryManager in test time def read_memory(self, query_key, query_selection, memory_key, memory_shrinkage, memory_value): """ query_key : B * CK * H * W query_selection : B * CK * H * W memory_key : B * CK * T * H * W memory_shrinkage: B * 1 * T * H * W memory_value : B * num_objects * CV * T * H * W """ batch_size, num_objects = memory_value.shape[:2] memory_value = memory_value.flatten(start_dim=1, end_dim=2) affinity = get_affinity(memory_key, memory_shrinkage, query_key, query_selection) memory = readout(affinity, memory_value) memory = memory.view(batch_size, num_objects, self.value_dim, *memory.shape[-2:]) return memory def segment(self, multi_scale_features, memory_readout, hidden_state, selector=None, h_out=True, strip_bg=True): hidden_state, logits = self.decoder(*multi_scale_features, hidden_state, memory_readout, h_out=h_out) prob = torch.sigmoid(logits) if selector is not None: prob = prob * selector logits, prob = aggregate(prob, dim=1, return_logits=True) if strip_bg: # Strip away the background prob = prob[:, 1:] return hidden_state, logits, prob def forward(self, mode, *args, **kwargs): if mode == 'encode_key': return self.encode_key(*args, **kwargs) elif mode == 'encode_value': return self.encode_value(*args, **kwargs) elif mode == 'read_memory': return self.read_memory(*args, **kwargs) elif mode == 'segment': return self.segment(*args, **kwargs) else: raise NotImplementedError @logger.catch() def init_hyperparameters(self, config, model_path=None, map_location=None): """ Init three hyperparameters: key_dim, value_dim, and hidden_dim If model_path is provided, we load these from the model weights The actual parameters are then updated to the config in-place Otherwise we load it either from the config or default """ if model_path is not None: # load the model and key/value/hidden dimensions with some hacks # config is updated with the loaded parameters model_weights = torch.load(model_path, map_location=map_location) self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0] self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0] self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights if self.disable_hidden: self.hidden_dim = 0 else: self.hidden_dim = model_weights['decoder.hidden_update.transform.weight'].shape[0]//3 print(f'Hyperparameters read from the model weights: ' f'C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}') else: model_weights = None # load dimensions from config or default if 'key_dim' not in config: self.key_dim = 64 print(f'key_dim not found in config. Set to default {self.key_dim}') else: self.key_dim = config['key_dim'] if 'value_dim' not in config: self.value_dim = 512 print(f'value_dim not found in config. Set to default {self.value_dim}') else: self.value_dim = config['value_dim'] if 'hidden_dim' not in config: self.hidden_dim = 64 print(f'hidden_dim not found in config. 
Set to default {self.hidden_dim}') else: self.hidden_dim = config['hidden_dim'] self.disable_hidden = (self.hidden_dim <= 0) config['key_dim'] = self.key_dim config['value_dim'] = self.value_dim config['hidden_dim'] = self.hidden_dim return model_weights @logger.catch() def load_weights(self, src_dict, init_as_zero_if_needed=False): # Maps SO weight (without other_mask) to MO weight (with other_mask) for k in list(src_dict.keys()): if k == 'value_encoder.conv1.weight': if src_dict[k].shape[1] == 4: print('Converting weights from single object to multiple objects.') pads = torch.zeros((64,1,7,7), device=src_dict[k].device) if not init_as_zero_if_needed: print('Randomly initialized padding.') nn.init.orthogonal_(pads) else: print('Zero-initialized padding.') src_dict[k] = torch.cat([src_dict[k], pads], 1) self.load_state_dict(src_dict) ================================================ FILE: XMem/model/resnet.py ================================================ """ resnet.py - A modified ResNet structure We append extra channels to the first conv by some network surgery """ from collections import OrderedDict import math import torch import torch.nn as nn from torch.utils import model_zoo def load_weights_add_extra_dim(target, source_state, extra_dim=1): new_dict = OrderedDict() for k1, v1 in target.state_dict().items(): if not 'num_batches_tracked' in k1: if k1 in source_state: tar_v = source_state[k1] if v1.shape != tar_v.shape: # Init the new segmentation channel with zeros # print(v1.shape, tar_v.shape) c, _, w, h = v1.shape pads = torch.zeros((c,extra_dim,w,h), device=tar_v.device) nn.init.orthogonal_(pads) tar_v = torch.cat([tar_v, pads], 1) new_dict[k1] = tar_v target.load_state_dict(new_dict) model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', } def conv3x3(in_planes, out_planes, stride=1, dilation=1): return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): super(BasicBlock, self).__init__() self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = 
self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3+extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

def resnet18(pretrained=True, extra_dim=0):
    model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
    if pretrained:
        load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim)
    return model

def resnet50(pretrained=True, extra_dim=0):
    model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
    if pretrained:
        load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim)
    return model


================================================
FILE: XMem/model/trainer.py
================================================
"""
trainer.py - wrapper and utility functions for network training
Compute loss, back-prop, update parameters, logging, etc.
""" import os import time import numpy as np import torch import torch.nn as nn import torch.optim as optim from model.network import XMem from model.losses import LossComputer from util.log_integrator import Integrator from util.image_saver import pool_pairs class XMemTrainer: def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1): self.config = config self.num_frames = config['num_frames'] self.num_ref_frames = config['num_ref_frames'] self.deep_update_prob = config['deep_update_prob'] self.local_rank = local_rank self.XMem = nn.parallel.DistributedDataParallel( XMem(config).cuda(), device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False) # Set up logger when local_rank=0 self.logger = logger self.save_path = save_path if logger is not None: self.last_time = time.time() self.logger.log_string('model_size', str(sum([param.nelement() for param in self.XMem.parameters()]))) self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size) self.loss_computer = LossComputer(config) self.train() self.optimizer = optim.AdamW(filter( lambda p: p.requires_grad, self.XMem.parameters()), lr=config['lr'], weight_decay=config['weight_decay']) self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, config['steps'], config['gamma']) if config['amp']: self.scaler = torch.cuda.amp.GradScaler() # Logging info self.log_text_interval = config['log_text_interval'] self.log_image_interval = config['log_image_interval'] self.save_network_interval = config['save_network_interval'] self.save_checkpoint_interval = config['save_checkpoint_interval'] if config['debug']: self.log_text_interval = self.log_image_interval = 1 def do_pass(self, data, it=0): # No need to store the gradient outside training torch.set_grad_enabled(self._is_train) for k, v in data.items(): if type(v) != list and type(v) != dict and type(v) != int: data[k] = v.cuda(non_blocking=True) out = {} frames = data['rgb'] first_frame_gt = data['first_frame_gt'].float() b = frames.shape[0] num_filled_objects = [o.item() for o in data['info']['num_objects']] num_objects = first_frame_gt.shape[2] selector = data['selector'].unsqueeze(2).unsqueeze(2) with torch.cuda.amp.autocast(enabled=self.config['amp']): # image features never change, compute once key, shrinkage, selection, f16, f8, f4 = self.XMem('encode_key', frames) filler_one = torch.zeros(1, dtype=torch.int64) hidden = torch.zeros((b, num_objects, self.config['hidden_dim'], *key.shape[-2:])) v16, hidden = self.XMem('encode_value', frames[:,0], f16[:,0], hidden, first_frame_gt[:,0]) values = v16.unsqueeze(3) # add the time dimension for ti in range(1, self.num_frames): if ti <= self.num_ref_frames: ref_values = values ref_keys = key[:,:,:ti] ref_shrinkage = shrinkage[:,:,:ti] if shrinkage is not None else None else: # pick num_ref_frames random frames # this is not very efficient but I think we would # need broadcasting in gather which we don't have indices = [ torch.cat([filler_one, torch.randperm(ti-1)[:self.num_ref_frames-1]+1]) for _ in range(b)] ref_values = torch.stack([ values[bi, :, :, indices[bi]] for bi in range(b) ], 0) ref_keys = torch.stack([ key[bi, :, indices[bi]] for bi in range(b) ], 0) ref_shrinkage = torch.stack([ shrinkage[bi, :, indices[bi]] for bi in range(b) ], 0) if shrinkage is not None else None # Segment frame ti memory_readout = self.XMem('read_memory', key[:,:,ti], selection[:,:,ti] if selection is not None else None, ref_keys, ref_shrinkage, ref_values) hidden, 
logits, masks = self.XMem('segment', (f16[:,ti], f8[:,ti], f4[:,ti]), memory_readout, hidden, selector, h_out=(ti < (self.num_frames-1))) # No need to encode the last frame if ti < (self.num_frames-1): is_deep_update = np.random.rand() < self.deep_update_prob v16, hidden = self.XMem('encode_value', frames[:,ti], f16[:,ti], hidden, masks, is_deep_update=is_deep_update) values = torch.cat([values, v16.unsqueeze(3)], 3) out[f'masks_{ti}'] = masks out[f'logits_{ti}'] = logits if self._do_log or self._is_train: losses = self.loss_computer.compute({**data, **out}, num_filled_objects, it) # Logging if self._do_log: self.integrator.add_dict(losses) if self._is_train: if it % self.log_image_interval == 0 and it != 0: if self.logger is not None: images = {**data, **out} size = (384, 384) self.logger.log_cv2('train/pairs', pool_pairs(images, size, num_filled_objects), it) if self._is_train: if (it) % self.log_text_interval == 0 and it != 0: if self.logger is not None: self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it) self.logger.log_metrics('train', 'time', (time.time()-self.last_time)/self.log_text_interval, it) self.last_time = time.time() self.train_integrator.finalize('train', it) self.train_integrator.reset_except_hooks() if it % self.save_network_interval == 0 and it != 0: if self.logger is not None: self.save_network(it) if it % self.save_checkpoint_interval == 0 and it != 0: if self.logger is not None: self.save_checkpoint(it) # Backward pass self.optimizer.zero_grad(set_to_none=True) if self.config['amp']: self.scaler.scale(losses['total_loss']).backward() self.scaler.step(self.optimizer) self.scaler.update() else: losses['total_loss'].backward() self.optimizer.step() self.scheduler.step() def save_network(self, it): if self.save_path is None: print('Saving has been disabled.') return os.makedirs(os.path.dirname(self.save_path), exist_ok=True) model_path = f'{self.save_path}_{it}.pth' torch.save(self.XMem.module.state_dict(), model_path) print(f'Network saved to {model_path}.') def save_checkpoint(self, it): if self.save_path is None: print('Saving has been disabled.') return os.makedirs(os.path.dirname(self.save_path), exist_ok=True) checkpoint_path = f'{self.save_path}_checkpoint_{it}.pth' checkpoint = { 'it': it, 'network': self.XMem.module.state_dict(), 'optimizer': self.optimizer.state_dict(), 'scheduler': self.scheduler.state_dict()} torch.save(checkpoint, checkpoint_path) print(f'Checkpoint saved to {checkpoint_path}.') def load_checkpoint(self, path): # This method loads everything and should be used to resume training map_location = 'cuda:%d' % self.local_rank checkpoint = torch.load(path, map_location={'cuda:0': map_location}) it = checkpoint['it'] network = checkpoint['network'] optimizer = checkpoint['optimizer'] scheduler = checkpoint['scheduler'] map_location = 'cuda:%d' % self.local_rank self.XMem.module.load_state_dict(network) self.optimizer.load_state_dict(optimizer) self.scheduler.load_state_dict(scheduler) print('Network weights, optimizer states, and scheduler states loaded.') return it def load_network_in_memory(self, src_dict): self.XMem.module.load_weights(src_dict) print('Network weight loaded from memory.') def load_network(self, path): # This method loads only the network weight and should be used to load a pretrained model map_location = 'cuda:%d' % self.local_rank src_dict = torch.load(path, map_location={'cuda:0': map_location}) self.load_network_in_memory(src_dict) print(f'Network weight loaded from {path}') def train(self): self._is_train 
= True self._do_log = True self.integrator = self.train_integrator self.XMem.eval() return self def val(self): self._is_train = False self._do_log = True self.XMem.eval() return self def test(self): self._is_train = False self._do_log = False self.XMem.eval() return self ================================================ FILE: XMem/requirements.txt ================================================ progressbar2 gdown hickle tensorboard numpy ================================================ FILE: XMem/scripts/__init__.py ================================================ ================================================ FILE: XMem/scripts/download_bl30k.py ================================================ import os import gdown import tarfile LICENSE = """ This dataset is a derivative of ShapeNet. Please read and respect their licenses and terms before use. Textures and skybox image are obtained from Google image search with the "non-commercial reuse" flag. Do not use this dataset for commercial purposes. You should cite both ShapeNet and our paper if you use this dataset. """ print(LICENSE) print('Datasets will be downloaded and extracted to ../BL30K') print('The script will download and extract the segment one by one') print('You are going to need ~1TB of free disk space') reply = input('[y] to confirm, others to exit: ') if reply != 'y': exit() links = [ 'https://drive.google.com/uc?id=1z9V5zxLOJLNt1Uj7RFqaP2FZWKzyXvVc', 'https://drive.google.com/uc?id=11-IzgNwEAPxgagb67FSrBdzZR7OKAEdJ', 'https://drive.google.com/uc?id=1ZfIv6GTo-OGpXpoKen1fUvDQ0A_WoQ-Q', 'https://drive.google.com/uc?id=1G4eXgYS2kL7_Cc0x3N1g1x7Zl8D_aU_-', 'https://drive.google.com/uc?id=1Y8q0V_oBwJIY27W_6-8CD1dRqV2gNTdE', 'https://drive.google.com/uc?id=1nawBAazf_unMv46qGBHhWcQ4JXZ5883r', ] names = [ 'BL30K_a.tar', 'BL30K_b.tar', 'BL30K_c.tar', 'BL30K_d.tar', 'BL30K_e.tar', 'BL30K_f.tar', ] for i, link in enumerate(links): print('Downloading segment %d/%d ...' % (i, len(links))) gdown.download(link, output='../%s' % names[i], quiet=False) print('Extracting...') with tarfile.open('../%s' % names[i], 'r') as tar_file: tar_file.extractall('../%s' % names[i]) print('Cleaning up...') os.remove('../%s' % names[i]) print('Done.') ================================================ FILE: XMem/scripts/download_datasets.py ================================================ import os import gdown import zipfile from scripts import resize_youtube LICENSE = """ These are either re-distribution of the original datasets or derivatives (through simple processing) of the original datasets. Please read and respect their licenses and terms before use. You should cite the original papers if you use any of the datasets. 
For BL30K, see download_bl30k.py Links: DUTS: http://saliencydetection.net/duts HRSOD: https://github.com/yi94code/HRSOD FSS: https://github.com/HKUSTCV/FSS-1000 ECSSD: https://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html BIG: https://github.com/hkchengrex/CascadePSP YouTubeVOS: https://youtube-vos.org DAVIS: https://davischallenge.org/ BL30K: https://github.com/hkchengrex/MiVOS Long-Time Video: https://github.com/xmlyqing00/AFB-URR """ print(LICENSE) print('Datasets will be downloaded and extracted to ../YouTube, ../YouTube2018, ../static, ../DAVIS, ../long_video_set') reply = input('[y] to confirm, others to exit: ') if reply != 'y': exit() """ Static image data """ os.makedirs('../static', exist_ok=True) print('Downloading static datasets...') gdown.download('https://drive.google.com/uc?id=1wUJq3HcLdN-z1t4CsUhjeZ9BVDb9YKLd', output='../static/static_data.zip', quiet=False) print('Extracting static datasets...') with zipfile.ZipFile('../static/static_data.zip', 'r') as zip_file: zip_file.extractall('../static/') print('Cleaning up static datasets...') os.remove('../static/static_data.zip') """ DAVIS dataset """ # Google drive mirror: https://drive.google.com/drive/folders/1hEczGHw7qcMScbCJukZsoOW4Q9byx16A?usp=sharing os.makedirs('../DAVIS/2017', exist_ok=True) print('Downloading DAVIS 2016...') gdown.download('https://drive.google.com/uc?id=198aRlh5CpAoFz0hfRgYbiNenn_K8DxWD', output='../DAVIS/DAVIS-data.zip', quiet=False) print('Downloading DAVIS 2017 trainval...') gdown.download('https://drive.google.com/uc?id=1kiaxrX_4GuW6NmiVuKGSGVoKGWjOdp6d', output='../DAVIS/2017/DAVIS-2017-trainval-480p.zip', quiet=False) print('Downloading DAVIS 2017 testdev...') gdown.download('https://drive.google.com/uc?id=1fmkxU2v9cQwyb62Tj1xFDdh2p4kDsUzD', output='../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', quiet=False) print('Downloading DAVIS 2017 scribbles...') gdown.download('https://drive.google.com/uc?id=1JzIQSu36h7dVM8q0VoE4oZJwBXvrZlkl', output='../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip', quiet=False) print('Extracting DAVIS datasets...') with zipfile.ZipFile('../DAVIS/DAVIS-data.zip', 'r') as zip_file: zip_file.extractall('../DAVIS/') os.rename('../DAVIS/DAVIS', '../DAVIS/2016') with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-trainval-480p.zip', 'r') as zip_file: zip_file.extractall('../DAVIS/2017/') with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip', 'r') as zip_file: zip_file.extractall('../DAVIS/2017/') os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/trainval') with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', 'r') as zip_file: zip_file.extractall('../DAVIS/2017/') os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/test-dev') print('Cleaning up DAVIS datasets...') os.remove('../DAVIS/2017/DAVIS-2017-trainval-480p.zip') os.remove('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip') os.remove('../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip') os.remove('../DAVIS/DAVIS-data.zip') """ YouTubeVOS dataset """ os.makedirs('../YouTube', exist_ok=True) os.makedirs('../YouTube/all_frames', exist_ok=True) print('Downloading YouTubeVOS train...') gdown.download('https://drive.google.com/uc?id=13Eqw0gVK-AO5B-cqvJ203mZ2vzWck9s4', output='../YouTube/train.zip', quiet=False) print('Downloading YouTubeVOS val...') gdown.download('https://drive.google.com/uc?id=1o586Wjya-f2ohxYf9C1RlRH-gkrzGS8t', output='../YouTube/valid.zip', quiet=False) print('Downloading YouTubeVOS all frames valid...') 
gdown.download('https://drive.google.com/uc?id=1rWQzZcMskgpEQOZdJPJ7eTmLCBEIIpEN', output='../YouTube/all_frames/valid.zip', quiet=False)

print('Extracting YouTube datasets...')
with zipfile.ZipFile('../YouTube/train.zip', 'r') as zip_file:
    zip_file.extractall('../YouTube/')
with zipfile.ZipFile('../YouTube/valid.zip', 'r') as zip_file:
    zip_file.extractall('../YouTube/')
with zipfile.ZipFile('../YouTube/all_frames/valid.zip', 'r') as zip_file:
    zip_file.extractall('../YouTube/all_frames')

print('Cleaning up YouTubeVOS datasets...')
os.remove('../YouTube/train.zip')
os.remove('../YouTube/valid.zip')
os.remove('../YouTube/all_frames/valid.zip')

print('Resizing YouTubeVOS to 480p...')
resize_youtube.resize_all('../YouTube/train', '../YouTube/train_480p')

# YouTubeVOS 2018
os.makedirs('../YouTube2018', exist_ok=True)
os.makedirs('../YouTube2018/all_frames', exist_ok=True)

print('Downloading YouTubeVOS2018 val...')
gdown.download('https://drive.google.com/uc?id=1-QrceIl5sUNTKz7Iq0UsWC6NLZq7girr', output='../YouTube2018/valid.zip', quiet=False)
print('Downloading YouTubeVOS2018 all frames valid...')
gdown.download('https://drive.google.com/uc?id=1yVoHM6zgdcL348cFpolFcEl4IC1gorbV', output='../YouTube2018/all_frames/valid.zip', quiet=False)

print('Extracting YouTube2018 datasets...')
with zipfile.ZipFile('../YouTube2018/valid.zip', 'r') as zip_file:
    zip_file.extractall('../YouTube2018/')
with zipfile.ZipFile('../YouTube2018/all_frames/valid.zip', 'r') as zip_file:
    zip_file.extractall('../YouTube2018/all_frames')

print('Cleaning up YouTubeVOS2018 datasets...')
os.remove('../YouTube2018/valid.zip')
os.remove('../YouTube2018/all_frames/valid.zip')

"""
Long-Time Video dataset
"""
os.makedirs('../long_video_set', exist_ok=True)
print('Downloading long video dataset...')
gdown.download('https://drive.google.com/uc?id=100MxAuV0_UL20ca5c-5CNpqQ5QYPDSoz', output='../long_video_set/LongTimeVideo.zip', quiet=False)
print('Extracting long video dataset...')
with zipfile.ZipFile('../long_video_set/LongTimeVideo.zip', 'r') as zip_file:
    zip_file.extractall('../long_video_set/')
print('Cleaning up long video dataset...')
os.remove('../long_video_set/LongTimeVideo.zip')

print('Done.')

================================================
FILE: XMem/scripts/download_models.sh
================================================
wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth
wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth

================================================
FILE: XMem/scripts/download_models_demo.sh
================================================
wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth
wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/fbrs.pth
wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/s2m.pth
================================================
FILE: XMem/scripts/expand_long_vid.py
================================================
import sys
import os
from os import path
from shutil import copy2

input_path = sys.argv[1]
output_path = sys.argv[2]
multiplier = int(sys.argv[3])

image_path = path.join(input_path, 'JPEGImages')
gt_path = path.join(input_path, 'Annotations')

videos = sorted(os.listdir(image_path))

for vid in videos:
    os.makedirs(path.join(output_path, 'JPEGImages', vid), exist_ok=True)
    os.makedirs(path.join(output_path, 'Annotations', vid), exist_ok=True)

    frames = sorted(os.listdir(path.join(image_path, vid)))
    num_frames = len(frames)

    counter = 0
    output_counter = 0
    direction = 1
    for _ in range(multiplier):
        for _ in range(num_frames):
            copy2(path.join(image_path, vid, frames[counter]),
                  path.join(output_path, 'JPEGImages', vid, f'{output_counter:05d}.jpg'))
            mask_path = path.join(gt_path, vid, frames[counter].replace('.jpg', '.png'))
            if path.exists(mask_path):
                copy2(mask_path,
                      path.join(output_path, 'Annotations', vid, f'{output_counter:05d}.png'))
            counter += direction
            output_counter += 1
            if counter == 0 or counter == len(frames) - 1:
                direction *= -1
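To see what the ping-pong loop in expand_long_vid.py produces, the sketch below replays its index arithmetic for a hypothetical 5-frame video with multiplier 2; the real script copies JPEGs and PNGs in exactly this order (invoked as, e.g., python scripts/expand_long_vid.py <input> <output> <multiplier>):

# Sketch: the frame indices expand_long_vid.py would copy, for 5 frames, multiplier 2.
frames = list(range(5))           # stand-ins for sorted frame filenames
order, counter, direction = [], 0, 1
for _ in range(2 * len(frames)):  # multiplier * num_frames copies in total
    order.append(counter)
    counter += direction
    if counter == 0 or counter == len(frames) - 1:
        direction *= -1           # bounce off either end of the clip
print(order)                      # [0, 1, 2, 3, 4, 3, 2, 1, 0, 1]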
print(f"Running: {cmd}") os.system(cmd) _GPU_QUEUE.put(gpu_id) def generate(obj, temp_xmem_anno, final_xmem_anno): obj_dir, video_name, obj_id, tp = obj img_list = glob.glob(obj_dir + '/*.png') # Mask img_list.sort() frame_id = int(len(img_list) * tp) if frame_id == len(img_list): frame_id -= 1 used_img = img_list[frame_id] img_output_path = osp.join(temp_xmem_anno, video_name, obj_id, osp.basename(used_img)) final_img_output_dir = osp.join(final_xmem_anno, video_name, obj_id) img_output_dir = osp.dirname(img_output_path) os.makedirs(img_output_dir, exist_ok=True) os.makedirs(final_img_output_dir, exist_ok=True) os.system('cp {} {}'.format(used_img, img_output_path)) img = cv2.imread(img_output_path) if img.sum() == 0: target_img_list = [i.split('/')[-1] for i in img_list] for img_ in target_img_list: print(os.path.join(final_img_output_dir, img_)) os.system('cp {} {}'.format(img_output_path, os.path.join(img_output_dir, img_))) os.system('cp {} {}'.format(img_output_path, os.path.join(final_img_output_dir, img_))) return 0 def prepare(args): video_root = args.video_root temp_xmem_anno = args.temp_xmem_anno final_xmem_anno = args.final_xmem_anno os.makedirs(temp_xmem_anno, exist_ok=True) data = json.load(open(args.llama_vid_meta, 'r'))['videos'] all_obj_list = [] for video_name in data.keys(): exps = data[video_name]['expressions'] for obj_id in exps.keys(): tp = exps[obj_id]['tp'] obj_dir = os.path.join(video_root, video_name, obj_id) all_obj_list.append([obj_dir, video_name, obj_id, tp]) print('start') cpu_num = mp.cpu_count()-1 print("cpu_num:", cpu_num) pool = mp.Pool(cpu_num) pbar = tqdm(total=len(all_obj_list)) for obj in all_obj_list: pool.apply_async( generate, args = (obj, temp_xmem_anno, final_xmem_anno ), callback = lambda *a: pbar.update(1), error_callback = lambda e: print(colored(e, "red")) ) pool.close() pool.join() pbar.close() def inference(args): p = mp.Pool(8) for split_part in [0, 1, 2, 3]: for cfgs in [" ", " --reversed "]: p.apply_async( run_eval, args=(args.llama_vid_meta, args.temp_xmem_anno, args.final_xmem_anno, args.img_dir, split_part, args.xmem_weight, cfgs), error_callback=lambda e: print(colored(e, "red")) ) p.close() p.join() """ python XMem/tracking.py \ --video_root /mnt/public03/dataset/ovis/rgvos/visa7b/val_7b/revos_valid/Annotations \ --temp_xmem_anno /mnt/public03/dataset/ovis/rgvos/visa7b/val_7b/revos_valid/revos_valid_XMem_temp/Annotations \ --final_xmem_anno /mnt/public03/dataset/ovis/rgvos/visa7b/val_7b/revos_valid/revos_valid_XMem_final/Annotations \ --llama_vid_meta /mnt/public02/usr/yancilin/clyan_data/other-datasets/ReVOS/meta_expressions_valid__llamavid.json \ --img_dir /mnt/public02/usr/yancilin/clyan_data/other-datasets/ReVOS/JPEGImages \ --xmem_weight /mnt/public02/usr/yancilin/VISA/XMem/weights/XMem.pth """ def main(): parser = argparse.ArgumentParser(description='rgvos') parser.add_argument('--video_root', type=str, help='/PATH/TO/VISA_exp/revos_valid/Annotations', ) parser.add_argument('--temp_xmem_anno', type=str, help='/PATH/TO/VISA_exp/revos_valid_XMem_temp/Annotations', ) # 保存单帧 Mask 的路径 parser.add_argument('--final_xmem_anno', type=str, help='/PATH/TO/VISA_exp/revos_valid_XMem_final/Annotations', ) # 保存 XMem 最后输出结果的路径 parser.add_argument("--llama_vid_meta", type=str, help='/PATH/TO/ReVOS/meta_expressions_valid__llamavid.json', ) parser.add_argument("--img_dir", type=str, help='/PATH/TO/ReVOS/JPEGImages') parser.add_argument("--xmem_weight", type=str, help='/PATH/TO/XMEM_WEIGHT') args = parser.parse_args() prepare(args) 
inference(args) print('Done.') if __name__ == '__main__': main() ================================================ FILE: XMem/train.py ================================================ import datetime from os import path import math import git import random import numpy as np import torch from torch.utils.data import DataLoader, ConcatDataset import torch.distributed as distributed from model.trainer import XMemTrainer from dataset.static_dataset import StaticTransformDataset from dataset.vos_dataset import VOSDataset from util.logger import TensorboardLogger from util.configuration import Configuration from util.load_subset import load_sub_davis, load_sub_yv """ Initial setup """ # Init distributed environment distributed.init_process_group(backend="nccl") print(f'CUDA Device count: {torch.cuda.device_count()}') # Parse command line arguments raw_config = Configuration() raw_config.parse() if raw_config['benchmark']: torch.backends.cudnn.benchmark = True # Get current git info repo = git.Repo(".") git_info = str(repo.active_branch)+' '+str(repo.head.commit.hexsha) local_rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() torch.cuda.set_device(local_rank) print(f'I am rank {local_rank} in this world of size {world_size}!') network_in_memory = None stages = raw_config['stages'] stages_to_perform = list(stages) for si, stage in enumerate(stages_to_perform): # Set seed to ensure the same initialization torch.manual_seed(14159265) np.random.seed(14159265) random.seed(14159265) # Pick stage specific hyperparameters out stage_config = raw_config.get_stage_parameters(stage) config = dict(**raw_config.args, **stage_config) if config['exp_id'] != 'NULL': config['exp_id'] = config['exp_id']+'_s%s'%stages[:si+1] config['single_object'] = (stage == '0') config['num_gpus'] = world_size if config['batch_size']//config['num_gpus']*config['num_gpus'] != config['batch_size']: raise ValueError('Batch size must be divisible by the number of GPUs.') config['batch_size'] //= config['num_gpus'] config['num_workers'] //= config['num_gpus'] print(f'We are assuming {config["num_gpus"]} GPUs.') print(f'We are now starting stage {stage}') """ Model related """ if local_rank == 0: # Logging if config['exp_id'].lower() != 'null': print('I will take the role of logging!') long_id = '%s_%s' % (datetime.datetime.now().strftime('%b%d_%H.%M.%S'), config['exp_id']) else: long_id = None logger = TensorboardLogger(config['exp_id'], long_id, git_info) logger.log_string('hyperpara', str(config)) # Construct the rank 0 model model = XMemTrainer(config, logger=logger, save_path=path.join('saves', long_id, long_id) if long_id is not None else None, local_rank=local_rank, world_size=world_size).train() else: # Construct model for other ranks model = XMemTrainer(config, local_rank=local_rank, world_size=world_size).train() # Load pertrained model if needed if raw_config['load_checkpoint'] is not None: total_iter = model.load_checkpoint(raw_config['load_checkpoint']) raw_config['load_checkpoint'] = None print('Previously trained model loaded!') else: total_iter = 0 if network_in_memory is not None: print('I am loading network from the previous stage') model.load_network_in_memory(network_in_memory) network_in_memory = None elif raw_config['load_network'] is not None: print('I am loading network from a disk, as listed in configuration') model.load_network(raw_config['load_network']) raw_config['load_network'] = None """ Dataloader related """ # To re-seed the randomness everytime we start a worker def 
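tracking.py hands out GPUs to its eval subprocesses through the shared _GPU_QUEUE: a worker blocks on get() until a device is free, pins the child process to it via CUDA_VISIBLE_DEVICES, and put()s the id back when done. A self-contained sketch of the same leasing pattern (hypothetical two-GPU setup and dummy workload; a Manager queue is used so the queue can be passed to Pool workers as an argument):

import multiprocessing as mp
import os
import time

def worker(job, gpu_queue):
    gpu_id = gpu_queue.get()                      # block until a GPU is free
    try:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
        time.sleep(0.1)                           # stand-in for the real eval.py run
        print(f'job {job} ran on GPU {gpu_id}')
    finally:
        gpu_queue.put(gpu_id)                     # release the lease even on failure

if __name__ == '__main__':
    gpu_queue = mp.Manager().Queue()
    for gpu_id in [0, 1]:                         # assume two visible GPUs
        gpu_queue.put(gpu_id)
    with mp.Pool(4) as pool:
        pool.starmap(worker, [(j, gpu_queue) for j in range(8)])

Unlike run_eval above, the sketch returns the GPU id in a finally block, so a crashed job cannot permanently shrink the pool.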
================================================
FILE: XMem/train.py
================================================
import datetime
from os import path
import math
import git

import random
import numpy as np
import torch
from torch.utils.data import DataLoader, ConcatDataset
import torch.distributed as distributed

from model.trainer import XMemTrainer
from dataset.static_dataset import StaticTransformDataset
from dataset.vos_dataset import VOSDataset

from util.logger import TensorboardLogger
from util.configuration import Configuration
from util.load_subset import load_sub_davis, load_sub_yv

"""
Initial setup
"""
# Init distributed environment
distributed.init_process_group(backend="nccl")
print(f'CUDA Device count: {torch.cuda.device_count()}')

# Parse command line arguments
raw_config = Configuration()
raw_config.parse()

if raw_config['benchmark']:
    torch.backends.cudnn.benchmark = True

# Get current git info
repo = git.Repo(".")
git_info = str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha)

local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(local_rank)

print(f'I am rank {local_rank} in this world of size {world_size}!')

network_in_memory = None
stages = raw_config['stages']
stages_to_perform = list(stages)
for si, stage in enumerate(stages_to_perform):

    # Set seed to ensure the same initialization
    torch.manual_seed(14159265)
    np.random.seed(14159265)
    random.seed(14159265)

    # Pick stage-specific hyperparameters out
    stage_config = raw_config.get_stage_parameters(stage)
    config = dict(**raw_config.args, **stage_config)
    if config['exp_id'] != 'NULL':
        config['exp_id'] = config['exp_id'] + '_s%s' % stages[:si + 1]

    config['single_object'] = (stage == '0')

    config['num_gpus'] = world_size
    if config['batch_size'] // config['num_gpus'] * config['num_gpus'] != config['batch_size']:
        raise ValueError('Batch size must be divisible by the number of GPUs.')
    config['batch_size'] //= config['num_gpus']
    config['num_workers'] //= config['num_gpus']
    print(f'We are assuming {config["num_gpus"]} GPUs.')

    print(f'We are now starting stage {stage}')

    """
    Model related
    """
    if local_rank == 0:
        # Logging
        if config['exp_id'].lower() != 'null':
            print('I will take the role of logging!')
            long_id = '%s_%s' % (datetime.datetime.now().strftime('%b%d_%H.%M.%S'), config['exp_id'])
        else:
            long_id = None
        logger = TensorboardLogger(config['exp_id'], long_id, git_info)
        logger.log_string('hyperpara', str(config))

        # Construct the rank 0 model
        model = XMemTrainer(config, logger=logger,
                            save_path=path.join('saves', long_id, long_id) if long_id is not None else None,
                            local_rank=local_rank, world_size=world_size).train()
    else:
        # Construct model for other ranks
        model = XMemTrainer(config, local_rank=local_rank, world_size=world_size).train()

    # Load pretrained model if needed
    if raw_config['load_checkpoint'] is not None:
        total_iter = model.load_checkpoint(raw_config['load_checkpoint'])
        raw_config['load_checkpoint'] = None
        print('Previously trained model loaded!')
    else:
        total_iter = 0

    if network_in_memory is not None:
        print('I am loading network from the previous stage')
        model.load_network_in_memory(network_in_memory)
        network_in_memory = None
    elif raw_config['load_network'] is not None:
        print('I am loading network from a disk, as listed in configuration')
        model.load_network(raw_config['load_network'])
        raw_config['load_network'] = None

    """
    Dataloader related
    """
    # To re-seed the randomness every time we start a worker
    def worker_init_fn(worker_id):
        worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 100
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def construct_loader(dataset):
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, rank=local_rank, shuffle=True)
        train_loader = DataLoader(dataset, config['batch_size'], sampler=train_sampler,
                                  num_workers=config['num_workers'], worker_init_fn=worker_init_fn, drop_last=True)
        return train_sampler, train_loader

    def renew_vos_loader(max_skip, finetune=False):
        # //5 because we only have annotation for every five frames
        yv_dataset = VOSDataset(path.join(yv_root, 'JPEGImages'),
                                path.join(yv_root, 'Annotations'), max_skip//5, is_bl=False,
                                subset=load_sub_yv(), num_frames=config['num_frames'], finetune=finetune)
        davis_dataset = VOSDataset(path.join(davis_root, 'JPEGImages', '480p'),
                                   path.join(davis_root, 'Annotations', '480p'), max_skip, is_bl=False,
                                   subset=load_sub_davis(), num_frames=config['num_frames'], finetune=finetune)
        train_dataset = ConcatDataset([davis_dataset]*5 + [yv_dataset])

        print(f'YouTube dataset size: {len(yv_dataset)}')
        print(f'DAVIS dataset size: {len(davis_dataset)}')
        print(f'Concat dataset size: {len(train_dataset)}')
        print(f'Renewed with {max_skip=}')

        return construct_loader(train_dataset)

    def renew_bl_loader(max_skip, finetune=False):
        train_dataset = VOSDataset(path.join(bl_root, 'JPEGImages'),
                                   path.join(bl_root, 'Annotations'), max_skip, is_bl=True,
                                   num_frames=config['num_frames'], finetune=finetune)

        print(f'Blender dataset size: {len(train_dataset)}')
        print(f'Renewed with {max_skip=}')

        return construct_loader(train_dataset)

    """
    Dataset related
    """

    """
    These define the training schedule of the distance between frames
    We will switch to max_skip_values[i] once we pass the percentage specified by increase_skip_fraction[i]
    Not effective for stage 0 training
    The initial value is not listed here but in renew_vos_loader(X)
    """
    max_skip_values = [10, 15, 5, 5]

    if stage == '0':
        static_root = path.expanduser(config['static_root'])
        # format: path, method (style of storing images), multiplier
        train_dataset = StaticTransformDataset(
            [
                (path.join(static_root, 'fss'), 0, 1),
                (path.join(static_root, 'DUTS-TR'), 1, 1),
                (path.join(static_root, 'DUTS-TE'), 1, 1),
                (path.join(static_root, 'ecssd'), 1, 1),
                (path.join(static_root, 'BIG_small'), 1, 5),
                (path.join(static_root, 'HRSOD_small'), 1, 5),
            ], num_frames=config['num_frames'])
        train_sampler, train_loader = construct_loader(train_dataset)

        print(f'Static dataset size: {len(train_dataset)}')
    elif stage == '1':
        increase_skip_fraction = [0.1, 0.3, 0.8, 100]
        bl_root = path.join(path.expanduser(config['bl_root']))

        train_sampler, train_loader = renew_bl_loader(5)
        renew_loader = renew_bl_loader
    else:
        # stage 2 or 3
        increase_skip_fraction = [0.1, 0.3, 0.9, 100]
        # VOS dataset, 480p is used for both datasets
        yv_root = path.join(path.expanduser(config['yv_root']), 'train_480p')
        davis_root = path.join(path.expanduser(config['davis_root']), '2017', 'trainval')

        train_sampler, train_loader = renew_vos_loader(5)
        renew_loader = renew_vos_loader

    """
    Determine max epoch
    """
    total_epoch = math.ceil(config['iterations'] / len(train_loader))
    current_epoch = total_iter // len(train_loader)
    print(f'We approximately use {total_epoch} epochs.')
    if stage != '0':
        change_skip_iter = [round(config['iterations'] * f) for f in increase_skip_fraction]
        # Skip will only change after an epoch, not in the middle
        print(f'The skip value will change approximately at the following iterations: {change_skip_iter[:-1]}')

    """
    Starts training
    """
    finetuning = False
    # Need this to select random bases in different workers
    np.random.seed(np.random.randint(2**30 - 1) + local_rank * 100)
    try:
        while total_iter < config['iterations'] + config['finetune']:

            # Crucial for randomness!
            train_sampler.set_epoch(current_epoch)
            current_epoch += 1
            print(f'Current epoch: {current_epoch}')

            # Train loop
            model.train()
            for data in train_loader:
                # Update skip if needed
                if stage != '0' and total_iter >= change_skip_iter[0]:
                    while total_iter >= change_skip_iter[0]:
                        cur_skip = max_skip_values[0]
                        max_skip_values = max_skip_values[1:]
                        change_skip_iter = change_skip_iter[1:]
                    print(f'Changing skip to {cur_skip=}')
                    train_sampler, train_loader = renew_loader(cur_skip)
                    break

                # fine-tune means fewer augmentations to train the sensory memory
                if config['finetune'] > 0 and not finetuning and total_iter >= config['iterations']:
                    train_sampler, train_loader = renew_loader(cur_skip, finetune=True)
                    finetuning = True
                    model.save_network_interval = 1000
                    break

                model.do_pass(data, total_iter)
                total_iter += 1

                if total_iter >= config['iterations'] + config['finetune']:
                    break
    finally:
        if not config['debug'] and model.logger is not None and total_iter > 5000:
            model.save_network(total_iter)
            model.save_checkpoint(total_iter)

    network_in_memory = model.XMem.module.state_dict()

distributed.destroy_process_group()
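To make the skip schedule concrete: with the stage 2 defaults from util/configuration.py below (s2_iterations=150000) and increase_skip_fraction=[0.1, 0.3, 0.9, 100], the loader is renewed at the iterations computed here; the trailing 100 is a sentinel fraction that is never reached:

# Worked example of the max_skip schedule for stage 2.
iterations = 150000                                   # --s2_iterations default
increase_skip_fraction = [0.1, 0.3, 0.9, 100]
change_skip_iter = [round(iterations * f) for f in increase_skip_fraction]
print(change_skip_iter[:-1])                          # [15000, 45000, 135000]
# max_skip: 5 (initial) -> 10 at 15k -> 15 at 45k -> back to 5 at 135k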
================================================
FILE: XMem/util/__init__.py
================================================

================================================
FILE: XMem/util/configuration.py
================================================
from argparse import ArgumentParser

def none_or_default(x, default):
    return x if x is not None else default

class Configuration():
    def parse(self, unknown_arg_ok=False):
        parser = ArgumentParser()

        # Enable torch.backends.cudnn.benchmark -- faster in some cases, test in your own environment
        parser.add_argument('--benchmark', action='store_true')
        parser.add_argument('--no_amp', action='store_true')

        # Data parameters
        parser.add_argument('--static_root', help='Static training data root', default='../static')
        parser.add_argument('--bl_root', help='Blender training data root', default='../BL30K')
        parser.add_argument('--yv_root', help='YouTubeVOS data root', default='../YouTube')
        parser.add_argument('--davis_root', help='DAVIS data root', default='../DAVIS')
        parser.add_argument('--num_workers', help='Total number of dataloader workers across all GPU processes', type=int, default=16)

        parser.add_argument('--key_dim', default=64, type=int)
        parser.add_argument('--value_dim', default=512, type=int)
        parser.add_argument('--hidden_dim', default=64, help='Set to =0 to disable', type=int)

        parser.add_argument('--deep_update_prob', default=0.2, type=float)

        parser.add_argument('--stages', help='Training stage (0-static images, 1-Blender dataset, 2-DAVIS+YouTubeVOS)', default='02')

        """
        Stage-specific learning parameters
        Batch sizes are effective -- you don't have to scale them when you scale the number of processes
        """
        # Stage 0, static images
        parser.add_argument('--s0_batch_size', default=16, type=int)
        parser.add_argument('--s0_iterations', default=150000, type=int)
        parser.add_argument('--s0_finetune', default=0, type=int)
        parser.add_argument('--s0_steps', nargs="*", default=[], type=int)
        parser.add_argument('--s0_lr', help='Initial learning rate', default=1e-5, type=float)
        parser.add_argument('--s0_num_ref_frames', default=2, type=int)
        parser.add_argument('--s0_num_frames', default=3, type=int)
        parser.add_argument('--s0_start_warm', default=20000, type=int)
        parser.add_argument('--s0_end_warm', default=70000, type=int)

        # Stage 1, BL30K
        parser.add_argument('--s1_batch_size', default=8, type=int)
        parser.add_argument('--s1_iterations', default=250000, type=int)
        # fine-tune means fewer augmentations to train the sensory memory
        parser.add_argument('--s1_finetune', default=0, type=int)
        parser.add_argument('--s1_steps', nargs="*", default=[200000], type=int)
        parser.add_argument('--s1_lr', help='Initial learning rate', default=1e-5, type=float)
        parser.add_argument('--s1_num_ref_frames', default=3, type=int)
        parser.add_argument('--s1_num_frames', default=8, type=int)
        parser.add_argument('--s1_start_warm', default=20000, type=int)
        parser.add_argument('--s1_end_warm', default=70000, type=int)

        # Stage 2, DAVIS+YoutubeVOS, longer
        parser.add_argument('--s2_batch_size', default=8, type=int)
        parser.add_argument('--s2_iterations', default=150000, type=int)
        # fine-tune means fewer augmentations to train the sensory memory
        parser.add_argument('--s2_finetune', default=10000, type=int)
        parser.add_argument('--s2_steps', nargs="*", default=[120000], type=int)
        parser.add_argument('--s2_lr', help='Initial learning rate', default=1e-5, type=float)
        parser.add_argument('--s2_num_ref_frames', default=3, type=int)
        parser.add_argument('--s2_num_frames', default=8, type=int)
        parser.add_argument('--s2_start_warm', default=20000, type=int)
        parser.add_argument('--s2_end_warm', default=70000, type=int)

        # Stage 3, DAVIS+YoutubeVOS, shorter
        parser.add_argument('--s3_batch_size', default=8, type=int)
        parser.add_argument('--s3_iterations', default=100000, type=int)
        # fine-tune means fewer augmentations to train the sensory memory
        parser.add_argument('--s3_finetune', default=10000, type=int)
        parser.add_argument('--s3_steps', nargs="*", default=[80000], type=int)
        parser.add_argument('--s3_lr', help='Initial learning rate', default=1e-5, type=float)
        parser.add_argument('--s3_num_ref_frames', default=3, type=int)
        parser.add_argument('--s3_num_frames', default=8, type=int)
        parser.add_argument('--s3_start_warm', default=20000, type=int)
        parser.add_argument('--s3_end_warm', default=70000, type=int)

        parser.add_argument('--gamma', help='LR := LR*gamma at every decay step', default=0.1, type=float)
        parser.add_argument('--weight_decay', default=0.05, type=float)

        # Loading
        parser.add_argument('--load_network', help='Path to pretrained network weight only')
        parser.add_argument('--load_checkpoint', help='Path to the checkpoint file, including network, optimizer and such')

        # Logging information
        parser.add_argument('--log_text_interval', default=100, type=int)
        parser.add_argument('--log_image_interval', default=1000, type=int)
        parser.add_argument('--save_network_interval', default=25000, type=int)
        parser.add_argument('--save_checkpoint_interval', default=50000, type=int)
        parser.add_argument('--exp_id', help='Experiment UNIQUE id, use NULL to disable logging to tensorboard', default='NULL')
        parser.add_argument('--debug', help='Debug mode which logs information more often', action='store_true')

        # # Multiprocessing parameters, not set by users
        # parser.add_argument('--local_rank', default=0, type=int, help='Local rank of this process')

        if unknown_arg_ok:
            args, _ = parser.parse_known_args()
            self.args = vars(args)
        else:
            self.args = vars(parser.parse_args())

        self.args['amp'] = not self.args['no_amp']

        # check if the stages are valid
        stage_to_perform = list(self.args['stages'])
        for s in stage_to_perform:
            if s not in ['0', '1', '2', '3']:
                raise NotImplementedError

    def get_stage_parameters(self, stage):
        parameters = {
            'batch_size': self.args['s%s_batch_size' % stage],
            'iterations': self.args['s%s_iterations' % stage],
            'finetune': self.args['s%s_finetune' % stage],
            'steps': self.args['s%s_steps' % stage],
            'lr': self.args['s%s_lr' % stage],
            'num_ref_frames': self.args['s%s_num_ref_frames' % stage],
            'num_frames': self.args['s%s_num_frames' % stage],
            'start_warm': self.args['s%s_start_warm' % stage],
            'end_warm': self.args['s%s_end_warm' % stage],
        }
        return parameters

    def __getitem__(self, key):
        return self.args[key]

    def __setitem__(self, key, value):
        self.args[key] = value

    def __str__(self):
        return str(self.args)
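A short sketch of how train.py consumes Configuration (hypothetical invocation; parse() reads sys.argv, so unset flags fall back to the defaults above):

# Sketch: python train.py --exp_id retrain --stages 2 would yield roughly this.
from util.configuration import Configuration

raw_config = Configuration()
raw_config.parse()                                   # fills raw_config.args from sys.argv

stage_config = raw_config.get_stage_parameters('2')  # strips the 's2_' prefix
config = dict(**raw_config.args, **stage_config)
print(config['batch_size'], config['iterations'])    # 8 150000 with the defaults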
================================================
FILE: XMem/util/davis_subset.txt
================================================
bear bmx-bumps boat boxing-fisheye breakdance-flare bus car-turn cat-girl classic-car color-run crossing dance-jump dancing disc-jockey dog-agility dog-gooses dogs-scale drift-turn drone elephant flamingo hike hockey horsejump-low kid-football kite-walk koala lady-running lindy-hop longboard lucia mallard-fly mallard-water miami-surf motocross-bumps motorbike night-race paragliding planes-water rallye rhino rollerblade schoolgirls scooter-board scooter-gray sheep skate-park snowboard soccerball stroller stunt surf swing tennis tractor-sand train tuk-tuk upside-down varanus-cage walking

================================================
FILE: XMem/util/image_saver.py
================================================
import cv2
import numpy as np

import torch
from dataset.range_transform import inv_im_trans
from collections import defaultdict

def tensor_to_numpy(image):
    image_np = (image.numpy() * 255).astype('uint8')
    return image_np

def tensor_to_np_float(image):
    image_np = image.numpy().astype('float32')
    return image_np

def detach_to_cpu(x):
    return x.detach().cpu()

def transpose_np(x):
    return np.transpose(x, [1,2,0])

def tensor_to_gray_im(x):
    x = detach_to_cpu(x)
    x = tensor_to_numpy(x)
    x = transpose_np(x)
    return x

def tensor_to_im(x):
    x = detach_to_cpu(x)
    x = inv_im_trans(x).clamp(0, 1)
    x = tensor_to_numpy(x)
    x = transpose_np(x)
    return x

# Predefined key <-> caption dict
key_captions = {
    'im': 'Image',
    'gt': 'GT',
}

"""
Return an image array with captions
Keys in the dictionary are used as captions if none are provided
Values should contain lists of cv2 images
"""
def get_image_array(images, grid_shape, captions={}):
    h, w = grid_shape
    cate_counts = len(images)
    rows_counts = len(next(iter(images.values())))

    font = cv2.FONT_HERSHEY_SIMPLEX

    output_image = np.zeros([w*cate_counts, h*(rows_counts+1), 3], dtype=np.uint8)
    col_cnt = 0
    for k, v in images.items():

        # Default as key value itself
        caption = captions.get(k, k)

        # Handles new line character
        dy = 40
        for i, line in enumerate(caption.split('\n')):
            cv2.putText(output_image, line, (10, col_cnt*w+100+i*dy),
                        font, 0.8, (255,255,255), 2, cv2.LINE_AA)

        # Put images
        for row_cnt, img in enumerate(v):
            im_shape = img.shape
            if len(im_shape) == 2:
                img = img[..., np.newaxis]

            img = (img * 255).astype('uint8')

            output_image[(col_cnt+0)*w:(col_cnt+1)*w,
                         (row_cnt+1)*h:(row_cnt+2)*h, :] = img

        col_cnt += 1

    return output_image

def base_transform(im, size):
    im = tensor_to_np_float(im)
    if len(im.shape) == 3:
        im = im.transpose((1, 2, 0))
    else:
        im = im[:, :, None]

    # Resize
    if im.shape[1] != size:
        im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)

    return im.clip(0, 1)

def im_transform(im, size):
    return base_transform(inv_im_trans(detach_to_cpu(im)), size=size)

def mask_transform(mask, size):
    return base_transform(detach_to_cpu(mask), size=size)

def out_transform(mask, size):
    return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size)

def pool_pairs(images, size, num_objects):
    req_images = defaultdict(list)

    b, t = images['rgb'].shape[:2]

    # limit the number of images saved
    b = min(2, b)

    # find max num objects
    max_num_objects = max(num_objects[:b])

    GT_suffix = ''
    for bi in range(b):
        GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4]

    for bi in range(b):
        for ti in range(t):
            req_images['RGB'].append(im_transform(images['rgb'][bi,ti], size))
            for oi in range(max_num_objects):
                if ti == 0 or oi >= num_objects[bi]:
                    req_images['Mask_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
                    # req_images['Mask_X8_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
                    # req_images['Mask_X16_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
                else:
                    req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi], size))
                    # req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][2], size))
                    # req_images['Mask_X8_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][1], size))
                    # req_images['Mask_X16_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][0], size))
                req_images['GT_%d_%s'%(oi, GT_suffix)].append(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size))
                # print((images['cls_gt'][bi,ti,0]==(oi+1)).shape)
                # print(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size).shape)

    return get_image_array(req_images, size, key_captions)
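get_image_array lays out one row of cells per dictionary key, with a leading caption column on the left. A sketch of driving it directly with dummy float images in [0, 1] (hypothetical data; during training it is reached through pool_pairs):

import numpy as np
from util.image_saver import get_image_array, key_captions

h = w = 100
frames = [np.random.rand(h, w, 3) for _ in range(4)]
grid = get_image_array({'im': frames, 'gt': frames}, (h, w), key_captions)
print(grid.shape)    # (200, 500, 3): one row per key, one extra column for captions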
================================================
FILE: XMem/util/load_subset.py
================================================
"""
load_subset.py - Presents a subset of data
DAVIS - only the training set
YouTubeVOS - I manually filtered some erroneous ones out but I haven't checked all
"""

def load_sub_davis(path='util/davis_subset.txt'):
    with open(path, mode='r') as f:
        subset = set(f.read().splitlines())
    return subset

def load_sub_yv(path='util/yv_subset.txt'):
    with open(path, mode='r') as f:
        subset = set(f.read().splitlines())
    return subset

================================================
FILE: XMem/util/log_integrator.py
================================================
"""
Integrate numerical values for some iterations
Typically used for loss computation / logging to tensorboard
Call finalize and create a new Integrator when you want to display/log
"""

import torch

class Integrator:
    def __init__(self, logger, distributed=True, local_rank=0, world_size=1):
        self.values = {}
        self.counts = {}
        self.hooks = []  # List is used here to maintain insertion order

        self.logger = logger

        self.distributed = distributed
        self.local_rank = local_rank
        self.world_size = world_size

    def add_tensor(self, key, tensor):
        if key not in self.values:
            self.counts[key] = 1
            if type(tensor) == float or type(tensor) == int:
                self.values[key] = tensor
            else:
                self.values[key] = tensor.mean().item()
        else:
            self.counts[key] += 1
            if type(tensor) == float or type(tensor) == int:
                self.values[key] += tensor
            else:
                self.values[key] += tensor.mean().item()

    def add_dict(self, tensor_dict):
        for k, v in tensor_dict.items():
            self.add_tensor(k, v)

    def add_hook(self, hook):
        """
        Adds a custom hook, i.e. compute new metrics using values in the dict
        The hook takes the dict as argument, and returns a (k, v) tuple
        e.g. for computing IoU
        """
        if type(hook) == list:
            self.hooks.extend(hook)
        else:
            self.hooks.append(hook)

    def reset_except_hooks(self):
        self.values = {}
        self.counts = {}

    # Average and output the metrics
    def finalize(self, prefix, it, f=None):

        for hook in self.hooks:
            k, v = hook(self.values)
            self.add_tensor(k, v)

        for k, v in self.values.items():

            if k[:4] == 'hide':
                continue

            avg = v / self.counts[k]

            if self.distributed:
                # Inplace operation
                avg = torch.tensor(avg).cuda()
                torch.distributed.reduce(avg, dst=0)

                if self.local_rank == 0:
                    avg = (avg / self.world_size).cpu().item()
                    self.logger.log_metrics(prefix, k, avg, it, f)
            else:
                # Simple does it
                self.logger.log_metrics(prefix, k, avg, it, f)
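A sketch of the intended Integrator lifecycle in a single process (hypothetical loss values; the stand-in logger mimics the log_metrics signature of the TensorboardLogger defined next):

# Sketch: accumulate per-iteration scalars, then average and log them once.
from util.log_integrator import Integrator

class PrintLogger:                                   # stand-in for util.logger.TensorboardLogger
    def log_metrics(self, l1_tag, l2_tag, val, step, f=None):
        print(f'{l1_tag}/{l2_tag} @ {step}: {val:.4f}')

integrator = Integrator(PrintLogger(), distributed=False)
integrator.add_hook(lambda values: ('loss_x2', values['total_loss'] * 2))
for loss in [0.9, 0.7, 0.5]:
    integrator.add_dict({'total_loss': loss, 'hide_debug': 1.0})  # 'hide'-prefixed keys are skipped
integrator.finalize('train', 100)                    # total_loss averages to 0.7
# note: hooks see the running sums (2.1 here), so loss_x2 is logged as 4.2
integrator.reset_except_hooks()                      # start a fresh accumulation window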
================================================
FILE: XMem/util/logger.py
================================================
"""
Dumps things to tensorboard and console
"""

import os
import warnings

import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

def tensor_to_numpy(image):
    image_np = (image.numpy() * 255).astype('uint8')
    return image_np

def detach_to_cpu(x):
    return x.detach().cpu()

def fix_width_trunc(x):
    return ('{:.9s}'.format('{:0.9f}'.format(x)))

class TensorboardLogger:
    def __init__(self, short_id, id, git_info):
        self.short_id = short_id
        if self.short_id == 'NULL':
            self.short_id = 'DEBUG'

        if id is None:
            self.no_log = True
            warnings.warn('Logging has been disabled.')
        else:
            self.no_log = False

            self.inv_im_trans = transforms.Normalize(
                mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
                std=[1/0.229, 1/0.224, 1/0.225])

            self.inv_seg_trans = transforms.Normalize(
                mean=[-0.5/0.5],
                std=[1/0.5])

            log_path = os.path.join('.', 'saves', '%s' % id)
            self.logger = SummaryWriter(log_path)

            self.log_string('git', git_info)

    def log_scalar(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        self.logger.add_scalar(tag, x, step)

    def log_metrics(self, l1_tag, l2_tag, val, step, f=None):
        tag = l1_tag + '/' + l2_tag
        text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), l2_tag, fix_width_trunc(val))
        print(text)
        if f is not None:
            f.write(text + '\n')
            f.flush()
        self.log_scalar(tag, val, step)

    def log_im(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = detach_to_cpu(x)
        x = self.inv_im_trans(x)
        x = tensor_to_numpy(x)
        self.logger.add_image(tag, x, step)

    def log_cv2(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = x.transpose((2, 0, 1))
        self.logger.add_image(tag, x, step)

    def log_seg(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = detach_to_cpu(x)
        x = self.inv_seg_trans(x)
        x = tensor_to_numpy(x)
        self.logger.add_image(tag, x, step)

    def log_gray(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = detach_to_cpu(x)
        x = tensor_to_numpy(x)
        self.logger.add_image(tag, x, step)

    def log_string(self, tag, x):
        print(tag, x)
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        self.logger.add_text(tag, x)

================================================
FILE: XMem/util/palette.py
================================================
davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0'
youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f'
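davis_palette and youtube_palette are RGB color tables packed as raw bytes (three bytes per palette index), matching the indexed PNGs used by DAVIS and YouTubeVOS annotations. A sketch of attaching one to an object-id mask (hypothetical mask; putpalette on an 'L'-mode PIL image converts it to 'P' in place):

# Sketch: color an object-id mask with the DAVIS palette, save as an indexed PNG.
import numpy as np
from PIL import Image
from util.palette import davis_palette

mask = np.zeros((480, 854), dtype=np.uint8)   # 0 = background
mask[100:200, 100:300] = 1                    # object id 1
im = Image.fromarray(mask)                    # mode 'L'
im.putpalette(davis_palette)                  # becomes mode 'P' with this color table
im.save('mask_vis.png')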
================================================
FILE: XMem/util/tensor_util.py
================================================
import torch.nn.functional as F

def compute_tensor_iu(seg, gt):
    intersection = (seg & gt).float().sum()
    union = (seg | gt).float().sum()
    return intersection, union

def compute_tensor_iou(seg, gt):
    intersection, union = compute_tensor_iu(seg, gt)
    iou = (intersection + 1e-6) / (union + 1e-6)
    return iou

# STM
def pad_divide_by(in_img, d):
    h, w = in_img.shape[-2:]

    if h % d > 0:
        new_h = h + d - h % d
    else:
        new_h = h
    if w % d > 0:
        new_w = w + d - w % d
    else:
        new_w = w
    lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
    lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
    pad_array = (int(lw), int(uw), int(lh), int(uh))
    out = F.pad(in_img, pad_array)
    return out, pad_array

def unpad(img, pad):
    if len(img.shape) == 4:
        if pad[2]+pad[3] > 0:
            img = img[:,:,pad[2]:-pad[3],:]
        if pad[0]+pad[1] > 0:
            img = img[:,:,:,pad[0]:-pad[1]]
    elif len(img.shape) == 3:
        if pad[2]+pad[3] > 0:
            img = img[:,pad[2]:-pad[3],:]
        if pad[0]+pad[1] > 0:
            img = img[:,:,pad[0]:-pad[1]]
    else:
        raise NotImplementedError
    return img

================================================
FILE: XMem/util/yv_subset.txt
================================================
003234408d 0043f083b5 0044fa5fba 005a527edd 0065b171f9 00917dcfc4 00a23ccf53 00ad5016a4 01082ae388 011ac0a06f 013099c098 0155498c85 01694ad9c8 017ac35701 01b80e8e1a 01baa5a4e1 01c3111683 01c4cb5ffe 01c76f0a82 01c783268c 01ed275c6e 01ff60d1fa 020cd28cd2 02264db755 0248626d9a 02668dbffa 0274193026 02d28375aa 02f3a5c4df 031ccc99b1 0321b18c10 0348a45bca 0355e92655 0358b938c1 0368107cf1 0379ddf557 038b2cc71d 038c15a5dd 03a06cc98a 03a63e187f 03c95b4dae 03e2b57b0e 04194e1248 0444918a5f 04460a7a52 04474174a4 0450095513 045f00aed2 04667fabaa 04735c5030 04990d1915 04d62d9d98 04f21da964 04fbad476e 04fe256562 0503bf89c9 0536c9eed0 054acb238f 05579ca250 056c200404 05774f3a2c 058a7592c8 05a0a513df 05a569d8aa 05aa652648 05d7715782 05e0b0f28f 05fdbbdd7a 05ffcfed85 0630391881 06840b2bbe 068f7dce6f 0693719753 06ce2b51fb 06e224798e 06ee361788 06fbb3fa2c 0700264286 070c918ca7 07129e14a4 07177017e9 07238ffc58 07353b2a89 0738493cbf 075926c651 075c701292 0762ea9a30 07652ee4af 076f206928 077d32af19 079049275c 07913cdda7 07a11a35e8 07ac33b6df 07b6e8fda8 07c62c3d11 07cc1c7d74 080196ef01 081207976e 081ae4fa44 081d8250cb 082900c5d4 0860df21e2 0866d4c5e3 0891ac2eb6 08931bc458 08aa2705d5 08c8450db7 08d50b926c 08e1e4de15 08e48c1a48 08f561c65e 08feb87790 09049f6fe3 092e4ff450 09338adea8 093c335ccc 0970d28339 0974a213dc 097b471ed8 0990941758 09a348f4fa 09a6841288 09c5bad17b 09c9ce80c7 09ff54fef4 0a23765d15 0a275e7f12 0a2f2bd294 0a7a2514aa 0a7b27fde9 0a8c467cc3 0ac8c560ae 0b1627e896 0b285c47f6 0b34ec1d55 0b5b5e8e5a 0b68535614 0b6f9105fc 0b7dbfa3cb 0b9cea51ca 0b9d012be8 0bcfc4177d 0bd37b23c1 0bd864064c 0c11c6bf7b 0c26bc77ac 0c3a04798c 0c44a9d545 0c817cc390 0ca839ee9a 0cd7ac0ac0 0ce06e0121 0cfe974a89 0d2fcc0dcd 0d3aad05d2 0d40b015f4 0d97fba242 0d9cc80d7e 0dab85b6d3 0db5c427a5 0dbaf284f1 0de4923598 0df28a9101 0e04f636c4 0e05f0e232 0e0930474b 0e27472bea 0e30020549 0e621feb6c 0e803c7d73 0e9ebe4e3c 0e9f2785ec 0ea68d418b 0eb403a222 0ee92053d6 0eefca067f 0f17fa6fcb 0f1ac8e9a3 0f202e9852 0f2ab8b1ff 0f51a78756 0f5fbe16b0 0f6072077b 0f6b69b2f4 0f6c2163de 0f74ec5599 0f9683715b 0fa7b59356 0fb173695b 0fc958cde2 0fe7b1a621 0ffcdb491c 101caff7d4 1022fe8417 1032e80b37 103f501680 104e64565f 104f1ab997 106242403f 10b31f5431 10eced835e 110d26fa3a 1122c1d16a 1145b49a5f 11485838c2 114e7676ec 1157472b95 115ee1072c 1171141012 117757b4b8 1178932d2f 117cc76bda 1180cbf814 1187bbd0e3 1197e44b26 119cf20728 119dd54871 11a0c3b724 11a6ba8c94 11c722a456 11cbcb0b4d 11ccf5e99d 11ce6f452e 11e53de6f2 11feabe596 120cb9514d 12156b25b3 122896672d 1232b2f1d4 1233ac8596 1239c87234 1250423f7c 1257a1bc67 125d1b19dd 126d203967 1295e19071 12ad198c54 12bddb2bcb 12ec9b93ee 12eebedc35 132852e094 1329409f2a 13325cfa14 134d06dbf9 135625b53d 13870016f9 13960b3c84 13adaad9d9 13ae097e20 13e3070469 13f6a8c20d 1416925cf2 142d2621f5 145d5d7c03 145fdc3ac5 1471274fa7 14a6b5a139 14c21cea0d 14dae0dc93 14f9bd22b5 14fd28ae99 15097d5d4e 150ea711f2 1514e3563f 152aaa3a9e 152b7d3bd7 15617297cc 15abbe0c52 15d1fb3de5 15f67b0fab 161eb59aad 16288ea47f 164410ce62 165c3c8cd4 165c42b41b 165ec9e22b 1669502269 16763cccbb 16adde065e 16af445362 16afd538ad 16c3fa4d5d 16d1d65c27 16e8599e94 16fe9fb444 1705796b02 1724db7671 17418e81ea 175169edbb 17622326fd 17656bae77 17b0d94172 17c220e4f6 17c7bcd146 17cb4afe89 17cd79a434 17d18604c3 17d8ca1a37 17e33f4330 17f7a6d805 180abc8378 183ba3d652 185bf64702
18913cc690 1892651815 189ac8208a 189b44e92c 18ac264b76 18b245ab49 18b5cebc34 18bad52083 18bb5144d5 18c6f205c5 1903f9ea15 1917b209f2 191e74c01d 19367bb94e 193ffaa217 19696b67d3 197f3ab6f3 1981e763cc 198afe39ae 19a6e62b9b 19b60d5335 19c00c11f9 19e061eb88 19e8bc6178 19ee80dac6 1a25a9170a 1a359a6c1a 1a3e87c566 1a5fe06b00 1a6c0fbd1e 1a6f3b5a4b 1a8afbad92 1a8bdc5842 1a95752aca 1a9c131cb7 1aa3da3ee3 1ab27ec7ea 1abf16d21d 1acd0f993b 1ad202e499 1af8d2395d 1afd39a1fa 1b2d31306f 1b3fa67f0e 1b43fa74b4 1b73ea9fc2 1b7e8bb255 1b8680f8cd 1b883843c0 1b8898785b 1b88ba1aa4 1b96a498e5 1bbc4c274f 1bd87fe9ab 1c4090c75b 1c41934f84 1c72b04b56 1c87955a3a 1c9f9eb792 1ca240fede 1ca5673803 1cada35274 1cb44b920d 1cd10e62be 1d3087d5e5 1d3685150a 1d6ff083aa 1d746352a6 1da256d146 1da4e956b1 1daf812218 1dba687bce 1dce57d05d 1de4a9e537 1dec5446c8 1dfbe6f586 1e1a18c45a 1e1e42529d 1e4be70796 1eb60959c8 1ec8b2566b 1ecdc2941c 1ee0ac70ff 1ef8e17def 1f1a2a9fc0 1f1beb8daa 1f2609ee13 1f3876f8d0 1f4ec0563d 1f64955634 1f7d31b5b2 1f8014b7fd 1f9c7d10f1 1fa350df76 1fc9538993 1fe2f0ec59 2000c02f9d 20142b2f05 201a8d75e5 2023b3ee4f 202b767bbc 203594a418 2038987336 2039c3aecb 204a90d81f 207bc6cf01 208833d1d1 20c6d8b362 20e3e52e0a 2117fa0c14 211bc5d102 2120d9c3c3 2125235a49 21386f5978 2142af8795 215dfc0f73 217bae91e5 217c0d44e4 219057c87b 21d0edbf81 21df87ad76 21f1d089f5 21f4019116 222597030f 222904eb5b 223a0e0657 223bd973ab 22472f7395 224e7c833e 225aba51d9 2261d421ea 2263a8782b 2268cb1ffd 2268e93b0a 2293c99f3f 22a1141970 22b13084b2 22d9f5ab0c 22f02efe3a 232c09b75b 2350d71b4b 2376440551 2383d8aafd 238b84e67f 238d4b86f6 238d947c6b 23993ce90d 23b0c8a9ab 23b3beafcc 23d80299fe 23f404a9fc 240118e58a 2431dec2fd 24440e0ac7 2457274dbc 2465bf515d 246b142c4d 247d729e36 2481ceafeb 24866b4e6a 2489d78320 24ab0b83e8 24b0868d92 24b5207cd9 24ddf05c03 250116161c 256ad2e3fc 256bd83d5e 256dcc8ab8 2589956baa 258b3b33c6 25ad437e29 25ae395636 25c750c6db 25d2c3fe5d 25dc80db7c 25f97e926f 26011bc28b 260846ffbe 260dd9ad33 267964ee57 2680861931 268ac7d3fc 26b895d91e 26bc786d4f 26ddd2ef12 26de3d18ca 26f7784762 2703e52a6a 270ed80c12 2719b742ab 272f4163d0 27303333e1 27659fa7d6 279214115d 27a5f92a9c 27cf2af1f3 27f0d5f8a2 28075f33c1 281629cb41 282b0d51f5 282fcab00b 28449fa0dc 28475208ca 285580b7c4 285b69e223 288c117201 28a8eb9623 28bf9c3cf3 28c6b8f86a 28c972dacd 28d9fa6016 28e392de91 28f4a45190 298c844fc9 29a0356a2b 29d779f9e3 29dde5f12b 29de7b6579 29e630bdd0 29f2332d30 2a18873352 2a3824ff31 2a559dd27f 2a5c09acbd 2a63eb1524 2a6a30a4ea 2a6d9099d1 2a821394e3 2a8c5b1342 2abc8d66d2 2ac9ef904a 2b08f37364 2b351bfd7d 2b659a49d7 2b69ee5c26 2b6c30bbbd 2b88561cf2 2b8b14954e 2ba621c750 2bab50f9a7 2bb00c2434 2bbde474ef 2bdd82fb86 2be06fb855 2bf545c2f5 2bffe4cf9a 2c04b887b7 2c05209105 2c0ad8cf39 2c11fedca8 2c1a94ebfb 2c1e8c8e2f 2c29fabcf1 2c2c076c01 2c3ea7ee7d 2c41fa0648 2c44bb6d1c 2c54cfbb78 2c5537eddf 2c6e63b7de 2cb10c6a7e 2cbcd5ccd1 2cc5d9c5f6 2cd01cf915 2cdbf5f0a7 2ce660f123 2cf114677e 2d01eef98e 2d03593bdc 2d183ac8c4 2d33ad3935 2d3991d83e 2d4333577b 2d4d015c64 2d8f5e5025 2d900bdb8e 2d9a1a1d49 2db0576a5c 2dc0838721 2dcc417f82 2df005b843 2df356de14 2e00393d96 2e03b8127a 2e0f886168 2e2bf37e6d 2e42410932 2ea78f46e4 2ebb017a26 2ee2edba2a 2efb07554a 2f17e4fc1e 2f2c65c2f3 2f2d9b33be 2f309c206b 2f53822e88 2f53998171 2f5b0c89b1 2f680909e6 2f710f66bd 2f724132b9 2f7e3517ae 2f96f5fc6f 2f97d9fecb 2fbfa431ec 2fc9520b53 2fcd9f4c62 2feb30f208 2ff7f5744f 30085a2cc6 30176e3615 301f72ee11 3026bb2f61 30318465dc 3054ca937d 306121e726 3064ad91e8 307444a47f 307bbb7409 30a20194ab 30c35c64a4 30dbdb2cd6 
30fc77d72f 310021b58b 3113140ee8 3150b2ee57 31539918c4 318dfe2ce2 3193da4835 319f725ad9 31bbd0d793 322505c47f 322b237865 322da43910 3245e049fb 324c4c38f6 324e35111a 3252398f09 327dc4cabf 328d918c7d 3290c0de97 3299ae3116 32a7cd687b 33098cedb4 3332334ac4 334cb835ac 3355e056eb 33639a2847 3373891cdc 337975816b 33e29d7e91 34046fe4f2 3424f58959 34370a710f 343bc6a65a 3450382ef7 3454303a08 346aacf439 346e92ff37 34a5ece7dd 34b109755a 34d1b37101 34dd2c70a7 34efa703df 34fbee00a6 3504df2fda 35195a56a1 351c822748 351cfd6bc5 3543d8334c 35573455c7 35637a827f 357a710863 358bf16f9e 35ab34cc34 35c6235b8d 35d01a438a 3605019d3b 3609bc3f88 360e25da17 36299c687c 362c5bc56e 3649228783 365b0501ea 365f459863 369893f3ad 369c9977e1 369dde050a 36c7dac02f 36d5b1493b 36f5cc68fd 3735480d18 374b479880 375a49d38f 375a5c0e09 376bda9651 377db65f60 37c19d1087 37d4ae24fc 37ddce7f8b 37e10d33af 37e45c6247 37fa0001e8 3802d458c0 382caa3cb4 383bb93111 388843df90 38924f4a7f 38b00f93d7 38c197c10e 38c9c3d801 38eb2bf67f 38fe9b3ed1 390352cced 390c51b987 390ca6f1d6 392bc0f8a1 392ecb43bd 3935291688 3935e63b41 394454fa9c 394638fc8b 39545e20b7 397abeae8f 3988074b88 398f5d5f19 39bc49a28c 39befd99fb 39c3c7bf55 39d584b09f 39f6f6ffb1 3a079fb484 3a0d3a81b7 3a1d55d22b 3a20a7583e 3a2c1f66e5 3a33f4d225 3a3bf84b13 3a4565e5ec 3a4e32ed5e 3a7ad86ce0 3a7bdde9b8 3a98867cbe 3aa3f1c9e8 3aa7fce8b6 3aa876887d 3ab807ded6 3ab9b1a85a 3adac8d7da 3ae1a4016f 3ae2deaec2 3ae81609d6 3af847e62f 3b23792b84 3b3b0af2ee 3b512dad74 3b6c7988f6 3b6e983b5b 3b74a0fc20 3b7a50b80d 3b96d3492f 3b9ad0c5a9 3b9ba0894a 3bb4e10ed7 3bd9a9b515 3beef45388 3c019c0a24 3c090704aa 3c2784fc0d 3c47ab95f8 3c4db32d74 3c5ff93faf 3c700f073e 3c713cbf2f 3c8320669c 3c90d225ee 3cadbcc404 3cb9be84a5 3cc37fd487 3cc6f90cb2 3cd5e035ef 3cdf03531b 3cdf828f59 3d254b0bca 3d5aeac5ba 3d690473e1 3d69fed2fb 3d8997aeb6 3db0d6b07e 3db1ddb8cf 3db907ac77 3dcbc0635b 3dd48ed55f 3de4ac4ec4 3decd63d88 3e04a6be11 3e108fb65a 3e1448b01c 3e16c19634 3e2845307e 3e38336da5 3e3a819865 3e3e4be915 3e680622d7 3e7d2aeb07 3e7d8f363d 3e91f10205 3ea4c49bbe 3eb39d11ab 3ec273c8d5 3ed3f91271 3ee062a2fd 3eede9782c 3ef2fa99cb 3efc6e9892 3f0b0dfddd 3f0c860359 3f18728586 3f3b15f083 3f45a470ad 3f4f3bc803 3fd96c5267 3fea675fab 3fee8cbc9f 3fff16d112 401888b36c 4019231330 402316532d 402680df52 404d02e0c0 40709263a8 4083cfbe15 40a96c5cb1 40b8e50f82 40f4026bf5 4100b57a3a 41059fdd0b 41124e36de 4122aba5f9 413bab0f0d 4164faee0b 418035eec9 4182d51532 418bb97e10 41a34c20e7 41dab05200 41ff6d5e2a 420caf0859 42264230ba 425a0c96e0 42da96b87c 42eb5a5b0f 42f17cd14d 42f5c61c49 42ffdcdee9 432f9884f9 43326d9940 4350f3ab60 4399ffade3 43a6c21f37 43b5555faa 43d63b752a 4416bdd6ac 4444753edd 444aa274e7 444d4e0596 446b8b5f7a 4478f694bb 44b1da0d87 44b4dad8c9 44b5ece1b9 44d239b24e 44eaf8f51e 44f4f57099 44f7422af2 450787ac97 4523656564 4536c882e5 453b65daa4 454f227427 45636d806a 456fb9362e 457e717a14 45a89f35e1 45bf0e947d 45c36a9eab 45d9fc1357 45f8128b97 4607f6c03c 46146dfd39 4620e66b1e 4625f3f2d3 462b22f263 4634736113 463c0f4fdd 46565a75f8 46630b55ae 466839cb37 466ba4ae0c 4680236c9d 46bf4e8709 46e18e42f1 46f5093c59 47269e0499 472da1c484 47354fab09 4743bb84a7 474a796272 4783d2ab87 479cad5da3 479f5d7ef6 47a05fbd1d 4804ee2767 4810c3fbca 482fb439c2 48375af288 484ab44de4 485f3944cd 4867b84887 486a8ac57e 486e69c5bd 48812cf33e 4894b3b9ea 48bd66517d 48d83b48a4 49058178b8 4918d10ff0 4932911f80 49405b7900 49972c2d14 499bf07002 49b16e9377 49c104258e 49c879f82d 49e7326789 49ec3e406a 49fbf0c98a 4a0255c865 4a088fe99a 4a341402d0 4a3471bdf5 4a4b50571c 4a50f3d2e9 4a6e3faaa1 4a7191f08a 
4a86fcfc30 4a885fa3ef 4a8af115de 4aa2e0f865 4aa9d6527f 4abb74bb52 4ae13de1cd 4af8cb323f 4b02c272b3 4b19c529fb 4b2974eff4 4b3154c159 4b54d2587f 4b556740ff 4b67aa9ef6 4b97cc7b8d 4baa1ed4aa 4bc8c676bb 4beaea4dbe 4bf5763d24 4bffa92b67 4c25dfa8ec 4c397b6fd4 4c51e75d66 4c7710908f 4c9b5017be 4ca2ffc361 4cad2e93bc 4cd427b535 4cd9a4b1ef 4cdfe3c2b2 4cef87b649 4cf208e9b3 4cf5bc3e60 4cfdd73249 4cff5c9e42 4d26d41091 4d5c23c554 4d67c59727 4d983cad9f 4da0d00b55 4daa179861 4dadd57153 4db117e6c5 4de4ce4dea 4dfaee19e5 4dfdd7fab0 4e3f346aa5 4e49c2a9c7 4e4e06a749 4e70279712 4e72856cc7 4e752f8075 4e7a28907f 4e824b9247 4e82b1df57 4e87a639bc 4ea77bfd15 4eb6fc23a2 4ec9da329e 4efb9a0720 4f062fbc63 4f35be0e0b 4f37e86797 4f414dd6e7 4f424abded 4f470cc3ae 4f601d255a 4f7386a1ab 4f824d3dcd 4f827b0751 4f8db33a13 4fa160f8a3 4fa9c30a45 4facd8f0e8 4fca07ad01 4fded94004 4fdfef4dea 4feb3ac01f 4fffec8479 500c835a86 50168342bf 50243cffdc 5031d5a036 504dd9c0fd 50568fbcfb 5069c7c5b3 508189ac91 50b6b3d4b7 50c6f4fe3e 50cce40173 50efbe152f 50f290b95d 5104aa1fea 5110dc72c0 511e8ecd7f 513aada14e 5158d6e985 5161e1fa57 51794ddd58 517d276725 51a597ee04 51b37b6d97 51b5dc30a0 51e85b347b 51eea1fdac 51eef778af 51f384721c 521cfadcb4 52355da42f 5247d4b160 524b470fd0 524cee1534 5252195e8a 5255c9ca97 525928f46f 526df007a7 529b12de78 52c7a3d653 52c8ec0373 52d225ed52 52ee406d9e 52ff1ccd4a 53143511e8 5316d11eb7 53253f2362 534a560609 5352c4a70e 536096501f 536b17bcea 5380eaabff 5390a43a54 53af427bb2 53bf5964ce 53c30110b5 53cad8e44a 53d9c45013 53e274f1b5 53e32d21ea 540850e1c7 540cb31cfe 541c4da30f 541d7935d7 545468262b 5458647306 54657855cd 547b3fb23b 5497dc3712 549c56f1d4 54a4260bb1 54b98b8d5e 54e1054b0f 54e8867b83 54ebe34f6e 5519b4ad13 551acbffd5 55341f42da 5566ab97e1 556c79bbf2 5589637cc4 558aa072f0 559824b6f6 55c1764e90 55eda6c77e 562d173565 5665c024cb 566cef4959 5675d78833 5678a91bd8 567a2b4bd0 569c282890 56cc449917 56e71f3e07 56f09b9d92 56fc0e8cf9 571ca79c71 57243657cf 57246af7d1 57427393e9 574b682c19 578f211b86 5790ac295d 579393912d 57a344ab1a 57bd3bcda4 57bfb7fa4c 57c010175e 57c457cc75 57c7fc2183 57d5289a01 58045fde85 58163c37cd 582d463e5c 5851739c15 585dd0f208 587250f3c3 589e4cc1de 589f65f5d5 58a07c17d5 58adc6d8b6 58b9bcf656 58c374917e 58fc75fd42 5914c30f05 59323787d5 5937b08d69 594065ddd7 595a0ceea6 59623ec40b 597ff7ef78 598935ef05 598c2ad3b2 59a6459751 59b175e138 59bf0a149f 59d53d1649 59e3e6fae7 59fe33e560 5a13a73fe5 5a25c22770 5a4a785006 5a50640995 5a75f7a1cf 5a841e59ad 5a91c5ab6d 5ab49d9de0 5aba1057fe 5abe46ba6d 5ac7c88d0c 5aeb95cc7d 5af15e4fc3 5afe381ae4 5b07b4229d 5b1001cc4f 5b1df237d2 5b263013bf 5b27d19f0b 5b48ae16c5 5b5babc719 5baaebdf00 5bab55cdbe 5bafef6e79 5bd1f84545 5bddc3ba25 5bdf7c20d2 5bf23bc9d3 5c01f6171a 5c021681b7 5c185cff1d 5c42aba280 5c44bf8ab6 5c4c574894 5c52fa4662 5c6ea7dac3 5c74315dc2 5c7668855e 5c83e96778 5ca36173e4 5cac477371 5cb0cb1b2f 5cb0cfb98f 5cb49a19cf 5cbf7dc388 5d0e07d126 5d1e24b6e3 5d663000ff 5da6b2dc5d 5de9b90f24 5e08de0ed7 5e1011df9a 5e1ce354fd 5e35512dd7 5e418b25f9 5e4849935a 5e4ee19663 5e886ef78f 5e8d00b974 5e8d59dc31 5ed838bd5c 5edda6ee5a 5ede4d2f7a 5ede9767da 5eec4d9fe5 5eecf07824 5eef7ed4f4 5ef5860ac6 5ef6573a99 5f1193e72b 5f29ced797 5f32cf521e 5f51876986 5f6ebe94a9 5f6f14977c 5f808d0d2d 5fb8aded6a 5fba90767d 5fd1c7a3df 5fd3da9f68 5fee2570ae 5ff66140d6 5ff8b85b53 600803c0f6 600be7f53e 6024888af8 603189a03c 6057307f6e 6061ddbb65 606c86c455 60c61cc2e5 60e51ff1ae 610e38b751 61344be2f6 6135e27185 614afe7975 614e571886 614e7078db 619812a1a7 61b481a78b 61c7172650 61cf7e40d2 61d08ef5a1 61da008958 61ed178ecb 
61f5d1282c 61fd977e49 621584cffe 625817a927 625892cf0b 625b89d28a 629995af95 62a0840bb5 62ad6e121c 62d6ece152 62ede7b2da 62f025e1bc 6316faaebc 63281534dc 634058dda0 6353f09384 6363c87314 636e4872e0 637681cd6b 6376d49f31 6377809ec2 63936d7de5 639bddef11 63d37e9fd3 63d90c2bae 63e544a5d6 63ebbcf874 63fff40b31 6406c72e4d 64148128be 6419386729 643092bc41 644081b88d 64453cf61d 644bad9729 6454f548fd 645913b63a 64750b825f 64a43876b7 64dd6c83e3 64e05bf46e 64f55f1478 650b0165e4 651066ed39 652b67d960 653821d680 6538d00d73 65866dce22 6589565c8c 659832db64 65ab7e1d98 65b7dda462 65bd5eb4f5 65dcf115ab 65e9825801 65f9afe51c 65ff12bcb5 666b660284 6671643f31 668364b372 66852243cb 6693a52081 669b572898 66e98e78f5 670f12e88f 674c12c92d 675c27208a 675ed3e1ca 67741db50a 678a2357eb 67b0f4d562 67cfbff9b1 67e717d6bd 67ea169a3b 67ea809e0e 681249baa3 683de643d9 6846ac20df 6848e012ef 684bcd8812 684dc1c40c 685a1fa9cf 686dafaac9 68807d8601 6893778c77 6899d2dabe 68a2fad4ab 68cb45fda3 68cc4a1970 68dcb40675 68ea4a8c3d 68f6e7fbf0 68fa8300b4 69023db81f 6908ccf557 691a111e7c 6927723ba5 692ca0e1a2 692eb57b63 69340faa52 693cbf0c9d 6942f684ad 6944fc833b 69491c0ebf 695b61a2b0 6979b4d83f 697d4fdb02 69910460a4 6997636670 69a436750b 69aebf7669 69b8c17047 69c67f109f 69e0e7b868 69ea9c09d1 69f0af42a6 6a078cdcc7 6a37a91708 6a42176f2e 6a48e4aea8 6a5977be3a 6a5de0535f 6a80d2e2e5 6a96c8815d 6a986084e2 6aa8e50445 6ab9dce449 6abf0ba6b2 6acc6049d9 6adb31756c 6ade215eb0 6afb7d50e4 6afd692f1a 6b0b1044fe 6b17c67633 6b1b6ef28b 6b1e04d00d 6b2261888d 6b25d6528a 6b3a24395c 6b685eb75b 6b79be238c 6b928b7ba6 6b9c43c25a 6ba99cc41f 6bdab62bcd 6bf2e853b1 6bf584200f 6bf95df2b9 6c0949c51c 6c11a5f11f 6c23d89189 6c4387daf5 6c4ce479a4 6c5123e4bc 6c54265f16 6c56848429 6c623fac5f 6c81b014e9 6c99ea7c31 6c9d29d509 6c9e3b7d1a 6ca006e283 6caeb928d6 6cb2ee722a 6cbfd32c5e 6cc791250b 6cccc985e0 6d12e30c48 6d4bf200ad 6d6d2b8843 6d6eea5682 6d7a3d0c21 6d7efa9b9e 6da21f5c91 6da6adabc0 6dd2827fbb 6dd36705b9 6df3637557 6dfe55e9e5 6e1a21ba55 6e2f834767 6e36e4929a 6e4f460caf 6e618d26b6 6ead4670f7 6eaff19b9f 6eb2e1cd9e 6eb30b3b5a 6eca26c202 6ecad29e52 6ef0b44654 6efcfe9275 6f4789045c 6f49f522ef 6f67d7c4c4 6f96e91d81 6fc6fce380 6fc9b44c00 6fce7f3226 6fdf1ca888 702fd8b729 70405185d2 7053e4f41e 707bf4ce41 7082544248 708535b72a 7094ac0f60 70a6b875fa 70c3e97e41 7106b020ab 711dce6fe2 7136a4453f 7143fb084f 714d902095 7151c53b32 715357be94 7163b8085f 716df1aa59 71caded286 71d2665f35 71d67b9e19 71e06dda39 720b398b9c 720e3fa04c 720e7a5f1e 721bb6f2cb 722803f4f2 72552a07c9 726243a205 72690ef572 728cda9b65 728e81c319 72a810a799 72acb8cdf6 72b01281f9 72cac683e4 72cadebbce 72cae058a5 72d8dba870 72e8d1c1ff 72edc08285 72f04f1a38 731b825695 7320b49b13 732626383b 732df1eb05 73329902ab 733798921e 733824d431 734ea0d7fb 735a7cf7b9 7367a42892 7368d5c053 73c6ae7711 73e1852735 73e4e5cc74 73eac9156b 73f8441a88 7419e2ab3f 74267f68b9 7435690c8c 747c44785c 747f1b1f2f 748b2d5c01 74d4cee0a4 74ec2b3073 74ef677020 750be4c4d8 75172d4ac8 75285a7eb1 75504539c3 7550949b1d 7551cbd537 75595b453d 7559b4b0ec 755bd1fbeb 756f76f74d 7570ca7f3c 757a69746e 757cac96c6 7584129dc3 75a058dbcd 75b09ce005 75cae39a8f 75cee6caf0 75cf58fb2c 75d5c2f32a 75eaf5669d 75f7937438 75f99bd3b3 75fa586876 7613df1f84 762e1b3487 76379a3e69 764271f0f3 764503c499 7660005554 7666351b84 76693db153 767856368b 768671f652 768802b80d 76962c7ed2 76a75f4eee 76b90809f7 770a441457 772a0fa402 772f2ffc3e 774f6c2175 77610860e0 777e58ff3d 77920f1708 7799df28e7 779e847a9a 77ba4edc72 77c834dc43 77d8aa8691 77e7f38f4d 77eea6845e 7806308f33 78254660ea 7828af8bff 
784398620a 784d201b12 78613981ed 78896c6baf 78aff3ebc0 78c7c03716 78d3676361 78e29dd4c3 78f1a1a54f 79208585cd 792218456c 7923bad550 794e6fc49f 796e6762ce 797cd21f71 79921b21c2 79a5778027 79bc006280 79bf95e624 79d9e00c55 79e20fc008 79e9db913e 79f014085e 79fcbb433a 7a13a5dfaa 7a14bc9a36 7a3c535f70 7a446a51e9 7a56e759c5 7a5f46198d 7a626ec98d 7a802264c4 7a8b5456ca 7abdff3086 7aecf9f7ac 7b0fd09c28 7b18b3db87 7b39fe7371 7b49e03d4c 7b5388c9f1 7b5cf7837f 7b733d31d8 7b74fd7b98 7b918ccb8a 7ba3ce3485 7bb0abc031 7bb5bb25cd 7bb7dac673 7bc7761b8c 7bf3820566 7c03a18ec1 7c078f211b 7c37d7991a 7c4ec17eff 7c649c2aaf 7c73340ab7 7c78a2266d 7c88ce3c5b 7ca6843a72 7cc9258dee 7cec7296ae 7d0ffa68a4 7d11b4450f 7d1333fcbe 7d18074fef 7d18c8c716 7d508fb027 7d55f791f0 7d74e3c2f6 7d783f67a9 7d83a5d854 7dd409947e 7de45f75e5 7e0cd25696 7e1922575c 7e1e3bbcc1 7e24023274 7e2f212fd3 7e6d1cc1f4 7e7cdcb284 7e9b6bef69 7ea5b49283 7eb2605d96 7eb26b8485 7ecd1f0c69 7f02b3cfe2 7f1723f0d5 7f21063c3a 7f3658460e 7f54132e48 7f559f9d4a 7f5faedf8b 7f838baf2b 7fa5f527e3 7ff84d66dd 802b45c8c4 804382b1ad 804c558adb 804f6338a4 8056117b89 806b6223ab 8088bda461 80b790703b 80c4a94706 80ce2e351b 80db581acd 80e12193df 80e41b608f 80f16b016d 81541b3725 8175486e6a 8179095000 8193671178 81a58d2c6b 81aa1286fb 81dffd30fb 8200245704 823e7a86e8 824973babb 824ca5538f 827171a845 8273a03530 827cf4f886 82b865c7dd 82c1517708 82d15514d6 82e117b900 82fec06574 832b5ef379 83424c9fbf 8345358fb8 834b50b31b 835e3b67d7 836ea92b15 837c618777 838eb3bd89 839381063f 839bc71489 83a8151377 83ae88d217 83ca8bcad0 83ce590d7f 83d3130ba0 83d40bcba5 83daba503a 83de906ec0 84044f37f3 84696b5a5e 84752191a3 847eeeb2e0 848e7835a0 84a4b29286 84a4bf147d 84be115c09 84d95c4350 84e0922cf7 84f0cfc665 8515f6db22 851f2f32c1 852a4d6067 854c48b02a 857a387c86 859633d56a 85a4f4a639 85ab85510c 85b1eda0d9 85dc1041c6 85e081f3c7 85f75187ad 8604bb2b75 860745b042 863b4049d7 8643de22d0 8647d06439 864ffce4fe 8662d9441a 8666521b13 868d6a0685 869fa45998 86a40b655d 86a8ae4223 86b2180703 86c85d27df 86d3755680 86e61829a1 871015806c 871e409c5c 8744b861ce 8749369ba0 878a299541 8792c193a0 8799ab0118 87d1f7d741 882b9e4500 885673ea17 8859dedf41 8873ab2806 887a93b198 8883e991a9 8891aa6dfa 8899d8cbcd 88b8274d67 88d3b80af6 88ede83da2 88f345941b 890976d6da 8909bde9ab 8929c7d5d9 89363acf76 89379487e0 8939db6354 893f658345 8953138465 895c96d671 895cbf96f9 895e8b29a7 898fa256c8 89986c60be 89b874547b 89bdb021d5 89c802ff9c 89d6336c2b 89ebb27334 8a27e2407c 8a31f7bca5 8a4a2fc105 8a5d6c619c 8a75ad7924 8aa817e4ed 8aad0591eb 8aca214360 8ae168c71b 8b0cfbab97 8b3645d826 8b3805dbd4 8b473f0f5d 8b4f6d1186 8b4fb018b7 8b518ee936 8b523bdfd6 8b52fb5fba 8b91036e5c 8b99a77ac5 8ba04b1e7b 8ba782192f 8bbeaad78b 8bd1b45776 8bd7a2dda6 8bdb091ccf 8be56f165d 8be950d00f 8bf84e7d45 8bffc4374b 8bfff50747 8c09867481 8c0a3251c3 8c3015cccb 8c469815cf 8c9ccfedc7 8ca1af9f3c 8ca3f6e6c1 8ca6a4f60f 8cac6900fe 8cba221a1e 8cbbe62ccd 8d064b29e2 8d167e7c08 8d4ab94e1c 8d81f6f899 8d87897d66 8dcccd2bd2 8dcfb878a8 8dd3ab71b9 8dda6bf10f 8ddd51ca94 8dea22c533 8def5bd3bf 8e1848197c 8e3a83cf2d 8e478e73f3 8e98ae3c84 8ea6687ab0 8eb0d315c1 8ec10891f9 8ec3065ec2 8ecf51a971 8eddbab9f7 8ee198467a 8ee2368f40 8ef595ce82 8f0a653ad7 8f1204a732 8f1600f7f6 8f16366707 8f1ce0a411 8f2e05e814 8f320d0e09 8f3b4a84ad 8f3fdad3da 8f5d3622d8 8f62a2c633 8f81c9405a 8f8c974d53 8f918598b6 8ff61619f6 9002761b41 90107941f3 90118a42ee 902bc16b37 903e87e0d6 9041a0f489 9047bf3222 9057bfa502 90617b0954 9076f4b6db 9077e69b08 909655b4a6 909c2eca88 909dbd1b76 90bc4a319a 90c7a87887 90cc785ddd 90d300f09b 
9101ea9b1b 9108130458 911ac9979b 9151cad9b5 9153762797 91634ee0c9 916942666f 9198cfb4ea 919ac864d6 91b67d58d4 91bb8df281 91be106477 91c33b4290 91ca7dd9f3 91d095f869 91f107082e 920329dd5e 920c959958 92128fbf4b 9223dacb40 923137bb7f 9268e1f88a 927647fe08 9276f5ba47 92a28cd233 92b5c1fc6d 92c46be756 92dabbe3a0 92e3159361 92ebab216a 934bdc2893 9359174efc 935d97dd2f 935feaba1b 93901858ee 939378f6d6 939bdf742e 93a22bee7e 93da9aeddf 93e2feacce 93e6f1fdf9 93e811e393 93e85d8fd3 93f623d716 93ff35e801 94031f12f2 94091a4873 94125907e3 9418653742 941c870569 94209c86f0 9437c715eb 9445c3eca2 9467c8617c 946d71fb5d 948f3ae6fb 9498baa359 94a33abeab 94bf1af5e3 94cf3a8025 94db712ac8 94e4b66cff 94e76cbaf6 950be91db1 952058e2d0 952633c37f 952ec313fe 9533fc037c 9574b81269 9579b73761 957f7bc48b 958073d2b0 9582e0eb33 9584092d0b 95b58b8004 95bd88da55 95f74a9959 962781c601 962f045bf5 964ad23b44 967b90590e 967bffe201 96825c4714 968492136a 9684ef9d64 968c41829e 96a856ef9a 96dfc49961 96e1a5b4f8 96e6ff0917 96fb88e9d7 96fbe5fc23 96fc924050 9715cc83dc 9720eff40f 972c187c0d 97476eb38d 97659ed431 9773492949 97756b264f 977bff0d10 97ab569ff3 97ba838008 97d9d008c7 97e59f09fa 97eb642e56 98043e2d14 981ff580cf 983e66cbfc 984f0f1c36 98595f2bb4 985c3be474 9869a12362 986b5a5e18 9877af5063 98911292da 9893a3cf77 9893d9202d 98a8b06e7f 98ac6f93d9 98b6974d12 98ba3c9417 98c7c00a19 98d044f206 98e909f9d1 98fe7f0410 990f2742c7 992bd0779a 994b9b47ba 9955b76bf5 9966f3adac 997117a654 999d53d841 99c04108d3 99c4277aee 99c6b1acf2 99dc8bb20b 99fcba71e5 99fecd4efb 9a02c70ba2 9a08e7a6f8 9a2f2c0f86 9a3254a76e 9a3570a020 9a39112493 9a4e9fd399 9a50af4bfb 9a68631d24 9a72318dbf 9a767493b7 9a7fc1548b 9a84ccf6a7 9a9c0e15b7 9adf06d89b 9b22b54ee4 9b473fc8fe 9b4f081782 9b997664ba 9bc454e109 9bccfd04de 9bce4583a2 9bebf1b87f 9bfc50d261 9c166c86ff 9c293ef4d7 9c29c047b0 9c3bc2e2a7 9c3ce23bd1 9c404cac0c 9c5180d23a 9c7feca6e4 9caa49d3ff 9cb2f1b646 9ce6f765c3 9cfee34031 9d01f08ec6 9d04c280b8 9d12ceaddc 9d15f8cb3c 9d2101e9bf 9d407c3aeb 9ddefc6165 9df0b1e298 9e16f115d8 9e249b4982 9e29b1982c 9e493e4773 9e4c752cd0 9e4de40671 9e6319faeb 9e6ddbb52d 9eadcea74f 9ecec5f8ea 9efb47b595 9f30bfe61e 9f3734c3a4 9f5b858101 9f66640cda 9f913803e9 9f97bc74c8 9fbad86e20 9fc2bad316 9fc5c3af78 9fcb310255 9fcc256871 9fd2fd4d47 a0071ae316 a023141022 a046399a74 a066e739c1 a06722ba82 a07a15dd64 a07b47f694 a09c39472e a0b208fe2e a0b61c959e a0bc6c611d a0e6da5ba2 a1193d6490 a14ef483ff a14f709908 a15ccc5658 a16062456f a174e8d989 a177c2733c a17c62e764 a18ad065fc a1aaf63216 a1bb65fb91 a1bd8e5349 a1dfdd0cac a2052e4f6c a20fd34693 a21ffe4d81 a22349e647 a235d01ec1 a24f63e8a2 a2554c9f6d a263ce8a87 a29bfc29ec a2a80072d4 a2a800ab63 a2bcd10a33 a2bdaff3b0 a2c146ab0d a2c996e429 a2dc51ebe8 a2e6608bfa a2f2a55f01 a301869dea a31fccd2cc a34f440f33 a35e0206da a36bdc4cab a36e8c79d8 a378053b20 a37db3a2b3 a38950ebc2 a39a0eb433 a39c9bca52 a3a945dc8c a3b40a0c1e a3b8588550 a3c502bec3 a3f2878017 a3f4d58010 a3f51855c3 a402dc0dfe a4065a7eda a412bb2fef a416b56b53 a41ec95906 a43299e362 a4757bd7af a48c53c454 a49dcf9ad5 a4a506521f a4ba7753d9 a4bac06849 a4f05d681c a50c10060f a50eb5a0ea a5122c6ec6 a522b1aa79 a590915345 a5b5b59139 a5b77abe43 a5c2b2c3e1 a5cd17bb11 a5da03aef1 a5dd11de0d a5ea2b93b6 a5eaeac80b a5ec5b0265 a5f350a87e a5f472caf4 a6027a53cf a61715bb1b a61cf4389d a61d9bbd9b a6470dbbf5 a64a40f3eb a653d5c23b a65bd23cb5 a66e0b7ad4 a66fc5053c a68259572b a6a810a92c a6bc36937f a6c3a374e9 a6d8a4228d a6f4e0817f a71e0481f5 a7203deb2d a7392d4438 a73d3c3902 a7491f1578 a74b9ca19c a77b7a91df a78195a5f5 a78758d4ce a7e6d6c29a a800d85e88 
a832fa8790 a83d06410d a8999af004 a8f78125b9 a907b18df1 a919392446 a965504e88 a96b84b8d2 a973f239cd a977126596 a9804f2a08 a984e56893 a99738f24c a99bdd0079 a9c9c1517e a9cbf9c41b a9e42e3c0c aa07b7c1c0 aa175e5ec7 aa1a338630 aa27d7b868 aa45f1caaf aa49e46432 aa51934e1b aa6287bb6c aa6d999971 aa85278334 aab33f0e2a aaba004362 aade4cf385 aae78feda4 aaed233bf3 aaff16c2db ab199e8dfb ab23b78715 ab2e1b5577 ab33a18ded ab45078265 ab56201494 ab90f0d24b abab2e6c20 abb50c8697 abbe2d15a0 abbe73cd21 abe61a11bb abeae8ce21 ac2b431d5f ac2cb1b9eb ac31fcd6d0 ac3d3a126d ac46bd8087 ac783ef388 acb73e4297 acbf581760 accafc3531 acf2c4b745 acf44293a2 acf736a27b acff336758 ad1fe56886 ad28f9b9d9 ad2de9f80e ad397527b2 ad3d1cfbcb ad3fada9d9 ad4108ee8e ad54468654 ad573f7d31 ad6255bc29 ad65ebaa07 ad97cc064a adabbd1cc4 adb0b5a270 adc648f890 add21ee467 adfd15ceef adfdd52eac ae01cdab63 ae0b50ff4f ae13ee3d70 ae1bcbd423 ae20d09dea ae2cecf5f6 ae3bc4a0ef ae499c7514 ae628f2cd4 ae8545d581 ae93214fe6 ae9cd16dbf aeba9ac967 aebb242b5c aed4e0b4c4 aedd71f125 aef3e2cb0e af0b54cee3 af3de54c7a af5fd24a36 af8826d084 af8ad72057 afb71e22c5 afcb331e1f afe1a35c1e b01080b5d3 b05ad0d345 b0623a6232 b064dbd4b7 b06ed37831 b06f5888e6 b08dcc490e b0a68228dc b0aece727f b0b0731606 b0c7f11f9f b0cca8b830 b0dd580a89 b0de66ca08 b0df7c5c5c b0f5295608 b11099eb09 b132a53086 b1399fac64 b13abc0c69 b1457e3b5e b15bf4453b b179c4a82d b17ee70e8c b190b1aa65 b19b3e22c0 b19c561fab b1d1cd2e6e b1d7c03927 b1d7fe2753 b1f540a4bd b1fc9c64e1 b1fcbb3ced b220939e93 b22099b419 b241e95235 b2432ae86d b2456267df b247940d01 b24af1c35c b24f600420 b24fe36b2a b258fb0b7d b26b219919 b26d9904de b274456ce1 b27b28d581 b2a26bc912 b2a9c51e1b b2b0baf470 b2b2756fe7 b2ce7699e3 b2edc76bd2 b2f6b52100 b30bf47bcd b34105a4e9 b372a82edf b3779a1962 b379ab4ff5 b37a1d69e3 b37c01396e b382b09e25 b3996e4ba5 b3d9ca2aee b3dde1e1e9 b3eb7f05eb b40b25055c b41e0f1f19 b44e32a42b b4805ae9cd b4807569a5 b48efceb3e b493c25c7f b4b565aba1 b4b715a15b b4d0c90bf4 b4d84bc371 b4e5ad97aa b4eaea9e6b b50f4b90d5 b53f675641 b54278cd43 b554843889 b573c0677a b58d853734 b5943b18ab b5a09a83f3 b5aae1fe25 b5b9da5364 b5eb64d419 b5ebb1d000 b5f1c0c96a b5f7fece90 b6070de1bb b60a76fe73 b61f998772 b62c943664 b63094ba0c b64fca8100 b673e7dcfb b678b7db00 b68fc1b217 b69926d9fa b6a1df3764 b6a4859528 b6b4738b78 b6b4f847b7 b6b8d502d4 b6bb00e366 b6d65a9eef b6d79a0845 b6e9ec577f b6ec609f7b b6f92a308d b70a2c0ab1 b70a5a0d50 b70c052f2f b70d231781 b72ac6e10b b7302d8226 b73867d769 b751e767f2 b76df6e059 b77e5eddef b7a2c2c83c b7bcbe6466 b7c2a469c4 b7d69da8f0 b7f31b7c36 b7f675fb98 b7fb871660 b82e5ad1c9 b841cfb932 b84b8ae665 b85b78ac2b b86c17caa6 b86e50d82d b871db031a b87d56925a b8aaa59b75 b8c03d1091 b8c3210036 b8e16df00b b8f34cf72e b8fb75864e b9004db86c b9166cbae9 b920b256a6 b938d79dff b93963f214 b941aef1a0 b94d34d14e b964c57da4 b96a95bc7a b96c57d2c7 b9b6bdde0c b9bcb3e0f2 b9d3b92169 b9dd4b306c b9f43ef41e ba1f03c811 ba3a775d7b ba3c7f2a31 ba3fcd417d ba5e1f4faa ba795f3089 ba8a291e6a ba98512f97 bac9db04f5 baedae3442 baff40d29d bb04e28695 bb1b0ee89f bb1c770fe7 bb1fc34f99 bb2d220506 bb334e5cdb bb337f9830 bb721eb9aa bb87ff58bd bb89a6b18a bbaa9a036a bbb4302dda bbd31510cf bbe0256a75 bc141b9ad5 bc17ab8a99 bc318160de bc3b9ee033 bc4240b43c bc4ce49105 bc4f71372d bc6b8d6371 bcaad44ad7 bcc241b081 bcc5d8095e bcd1d39afb bd0d849da4 bd0e9ed437 bd2c94730f bd321d2be6 bd3ec46511 bd5b2e2848 bd7e02b139 bd96f9943a bda224cb25 bda4a82837 bdb74e333f bdccd69dde bddcc15521 be116aab29 be15e18f1e be1a284edb be2a367a7b be376082d0 be3e3cffbd be5d1d89a0 be8b72fe37 be9b29e08e bea1f6e62c bea83281b5 
beb921a4c9 bec5e9edcd beeb8a3f92 bf2232b58d bf28751739 bf443804e8 bf461df850 bf5374f122 bf551a6f60 bf8d0f5ada bf961167a6 bfab1ad8f9 bfcb05d88d bfd8f6e6c9 bfd91d0742 bfe262322f c013f42ed7 c01878083f c01faff1ed c046fd0edb c053e35f97 c079a6482d c0847b521a c0a1e06710 c0e8d4635c c0e973ad85 c0f49c6579 c0f5b222d7 c10d07c90d c1268d998c c130c3fc0c c14826ad5e c15b922281 c16f09cb63 c18e19d922 c1c830a735 c1e8aeea45 c20a5ccc99 c20fd5e597 c219d6f8dc c2406ae462 c26f7b5824 c279e641ee c27adaeac5 c2a35c1cda c2a9903b8b c2b62567c1 c2b974ec8c c2baaff7bf c2be6900f2 c304dd44d5 c307f33da2 c30a7b62c9 c3128733ee c31fa6c598 c325c8201e c32d4aa5d1 c33f28249a c34365e2d7 c3457af795 c34d120a88 c3509e728d c35e4fa6c4 c36240d96f c3641dfc5a c37b17a4a9 c39559ddf6 c3b0c6e180 c3b3d82e6c c3be369fdb c3bf1e40c2 c3c760b015 c3dd38bf98 c3e4274614 c3edc48cbd c41e6587f5 c4272227b0 c42917fe82 c438858117 c44676563f c44beb7472 c45411dacb c4571bedc8 c46deb2956 c479ee052e c47d551843 c49f07d46d c4cc40c1fc c4f256f5d5 c4f5b1ddcc c4ff9b4885 c52bce43db c544da6854 c55784c766 c557b69fbf c593a3f7ab c598faa682 c5ab1f09c8 c5b6da8602 c5b9128d94 c5e845c6b7 c5fba7b341 c60897f093 c61fe6ed7c c62188c536 c64035b2e2 c69689f177 c6a12c131f c6bb6d2d5c c6c18e860f c6d9526e0d c6e55c33f0 c7030b28bd c70682c7cc c70f9be8c5 c71f30d7b6 c73c8e747f c760eeb8b3 c7637cab0a c7a1a17308 c7bf937af5 c7c2860db3 c7cef4aee2 c7ebfc5d57 c813dcf13c c82235a49a c82a7619a1 c82ecb90cb c844f03dc7 c8557963f3 c89147e6e8 c8a46ff0c8 c8ab107dd5 c8b869a04a c8c7b306a6 c8c8b28781 c8d79e3163 c8edab0415 c8f494f416 c8f6cba9fd c909ceea97 c9188f4980 c922365dd4 c92c8c3c75 c937eb0b83 c94b31b5e5 c95cd17749 c96379c03c c96465ee65 c965afa713 c9734b451f c9862d82dc c98b6fe013 c9999b7c48 c99e92aaf0 c9b3a8fbda c9bf64e965 c9c3cb3797 c9d1c60cd0 c9de9c22c4 ca1828fa54 ca346f17eb ca3787d3d3 ca4b99cbac ca91c69e3b ca91e99105 caa8e97f81 caac5807f8 cabba242c2 cad5a656a9 cad673e375 cad8a85930 cae7b0a02b cae7ef3184 caeb6b6cbb caecf0a5db cb15312003 cb2e35d610 cb35a87504 cb3f22b0cf cbb410da64 cc8728052e cc892997b8 cce03c2a9b cd47a23e31 cd4dc03dc0 cd5ae611da cd603bb9d1 cd8f49734c cdc6b1c032 cdcfe008ad cdd57027c2 ce1af99b4b ce1bc5743a ce25872021 ce2776f78f ce49b1f474 ce4f0a266f ce5641b195 ce6866aa19 ce712ed3c9 ce7d1c8117 ce7dbeaa88 ce9b015a5e cea7697b25 cebbd826cf cec3415361 cec41ad4f4 ced49d26df ced7705ab2 cef824a1e1 cf13f5c95a cf4376a52d cf85ab28b5 cfc2e50b9d cfcd571fff cfd9d4ae47 cfda2dcce5 cff035928b cff8191891 d01608c2a5 d01a8f1f83 d021d68bca d04258ca14 d0483573dc d04a90aaff d05279c0bd d0696bd5fc d072fda75b d0a83bcd9f d0ab39112e d0acde820f d0b4442c71 d0c65e9e95 d0fb600c73 d107a1457c d123d674c1 d14d1e9289 d154e3388e d177e9878a d1802f69f8 d182c4483a d195d31128 d200838929 d205e3cff5 d247420c4c d2484bff33 d26f6ed9b0 d280fcd1cb d2857f0faa d292a50c7f d295ea2dc7 d2a58b4fa6 d2b026739a d2ebe0890f d2ede5d862 d301ca58cc d3069da8bb d343d4a77d d355e634ef d367fb5253 d36d16358e d38bc77e2c d38d1679e2 d3932ad4bd d3987b2930 d39934abe3 d3ae1c3f4c d3b088e593 d3e6e05e16 d3eefae7c5 d3f55f5ab8 d3f5c309cc d4034a7fdf d4193011f3 d429c67630 d42c0ff975 d44a764409 d44e6acd1d d45158c175 d454e8444f d45f62717e d48ebdcf74 d49ab52a25 d4a607ad81 d4b063c7db d4da13e9ba d4dd1a7d00 d4f4f7c9c3 d521aba02e d535bb1b97 d53b955f78 d55cb7a205 d55f247a45 d5695544d8 d5853d9b8b d5b6c6d94a d5cae12834 d5df027f0c d5ee40e5d0 d600046f73 d632fd3510 d6476cad55 d65a7bae86 d664c89912 d689658f06 d6917db4be d69967143e d699d3d798 d69f757a3f d6ac0e065c d6c02bfda5 d6c1b5749e d6e12ef6cc d6eed152c4 d6faaaf726 d704766646 d708e1350c d7135cf104 d7157a9f44 d719cf9316 d724134cfd d73a60a244 
d7411662da d74875ea7c d756f5a694 d7572b7d8a d763bd6d96 d7697c8b13 d7797196b4 d79c834768 d7b34e5d73 d7bb6b37a7 d7c7e064a6 d7fbf545b3 d82a0aa15b d847e24abd d8596701b7 d86101499c d87069ba86 d87160957b d874654b52 d88a403092 d8aee40f3f d8e77a222d d8eb07c381 d9010348a1 d90e3cf281 d92532c7b2 d927fae122 d95707bca8 d973b31c00 d991cb471d d992c69d37 d99d770820 d9b63abc11 d9db6f1983 d9e52be2d2 d9edc82650 da01070697 da070ea4b7 da080507b9 da0e944cc4 da28d94ff4 da5d78b9d1 da6003fc72 da690fee9f da6c68708f da7a816676 dac361e828 dac71659b8 dad980385d daebc12b77 db0968cdd3 db231a7100 db59282ace db7f267c3f dba35b87fd dbba735a50 dbca076acd dbd66dc3ac dbdc3c292b dbf4a5b32b dbfc417d28 dc1745e0a2 dc32a44804 dc34b35e30 dc504a4f79 dc704dd647 dc71bc6918 dc7771b3be dcf8c93617 dd0f4c9fb9 dd415df125 dd601f9a3f dd61d903df dd77583736 dd8636bd8b dd9fe6c6ac ddb2da4c14 ddcd450d47 dde8e67fb4 ddfc3f04d3 de2ab79dfa de2f35b2fd de30990a51 de36b216da de37403340 de46e4943b de4ddbccb1 de5e480f05 de6a9382ca de74a601d3 de827c510d ded6069f7b defb71c741 df01f277f1 df05214b82 df0638b0a0 df11931ffe df1b0e4620 df20a8650d df2bc56d7c df365282c6 df39a0d9df df3c430c24 df5536cfb9 df59cfd91d df5e2152b3 df741313c9 df7626172f df8ad5deb9 df96aa609a df9705605c df9c91c4da dfc0d3d27a dfdbf91a99 e00baaae9b e0a938c6e7 e0b2ceee6f e0bdb5dfae e0be1f6e17 e0c478f775 e0de82caa7 e0f217dd59 e0f7208874 e0fb58395e e1194c2e9d e11adcd05d e128124b9d e1495354e4 e1561d6d4b e158805399 e16945b951 e19edcd34b e1a1544285 e1ab7957f4 e1d26d35be e1e957085b e1f14510fa e214b160f4 e2167379b8 e21acb20ab e221105579 e22ddf8a1b e22de45950 e22ffc469b e23cca5244 e252f46f0b e25fa6cf39 e26e486026 e275760245 e27bbedbfe e29e9868a8 e2b37ff8af e2b608d309 e2bef4da9a e2c87a6421 e2ea25542c e2fb1d6497 e2fcc99117 e33c18412a e348377191 e352cb59c8 e36ac982f0 e391bc981e e39e3e0a06 e3bf38265f e3d5b2cd21 e3d60e82d5 e3e3245492 e3e4134877 e3f4635e03 e4004ee048 e402d1afa5 e415093d27 e41ceb5d81 e424653b78 e42b6d3dbb e42d60f0d4 e436d0ff1e e43d7ae2c5 e4428801bc e44e0b4917 e470345ede e48e8b4263 e4922e3726 e4936852bb e495f32c60 e499228f26 e4af66e163 e4b2095f58 e4d19c8283 e4d4872dab e4e2983570 e4eaa63aab e4ef0a3a34 e4f8e5f46e e4ffb6d0dd e53e21aa02 e57f4f668b e588433c1e e597442c99 e5abc0e96b e5be628030 e5ce96a55d e5d6b70a9f e5fde1574c e625e1d27b e6261d2348 e6267d46bc e6295f223f e63463d8c6 e6387bd1e0 e653883384 e65f134e0b e668ef5664 e672ccd250 e674510b20 e676107765 e699da0cdf e6be243065 e6deab5e0b e6f065f2b9 e71629e7b5 e72a7d7b0b e72f6104e1 e75a466eea e76c55933f e7784ec8ad e78922e5e6 e78d450a9c e7c6354e77 e7c8de1fce e7ea10db28 e803918710 e8073a140b e828dd02db e845994987 e8485a2615 e85c5118a7 e88b6736e4 e8962324e3 e8b3018d36 e8cee8bf0b e8d97ebece e8da49ea6a e8ed1a3ccf e8f7904326 e8f8341dec e8fa21eb13 e90c10fc4c e914b8cac8 e92b6bfea4 e92e1b7623 e93f83e512 e9422ad240 e9460b55f9 e9502628f6 e950befd5f e9582bdd1b e95e5afe0f e97cfac475 e98d57d99c e98eda8978 e99706b555 e9bc0760ba e9d3c78bf3 e9ec1b7ea8 ea065cc205 ea138b6617 ea16d3fd48 ea2545d64b ea286a581c ea320da917 ea345f3627 ea3b94a591 ea444a37eb ea4a01216b ea5672ffa8 eaa99191cb eaab4d746c eac7a59bc1 ead5d3835a eaec65cfa7 eaed1a87be eb2f821c6f eb383cb82e eb6992fe02 eb6ac20a01 eb6d7ab39e eb7921facd eb8fce51a6 ebbb90e9f9 ebbf5c9ee1 ebc4ec32e6 ebe56e5ef8 ec1299aee4 ec139ff675 ec193e1a01 ec28252938 ec387be051 ec3d4fac00 ec4186ce12 ec579c2f96 ecae59b782 ecb33a0448 ece6bc9e92 ecfedd4035 ecfff22fd6 ed3291c3d6 ed3cd5308d ed3e6fc1a5 ed72ae8825 ed7455da68 ed844e879f ed8f814b2b ed911a1f63 ed9ff4f649 eda8ab984b edb8878849 edbfdfe1b4 edd22c46a2 edd663afa3 ede3552eae edeab61ee0 
ee07583fc0 ee316eaed6 ee3f509537 ee40a1e491 ee4bf100f1 ee6f9b01f9 ee947ed771 ee9706ac7f ee9a7840ae eeb90cb569 eebf45e5c5 eeed0c7d73 ef0061a309 ef07f1a655 ef0a8e8f35 ef232a2aed ef308ad2e9 ef44945428 ef45ce3035 ef5dde449d ef5e770988 ef6359cea3 ef65268834 ef6cb5eae0 ef78972bc2 ef8cfcfc4f ef96501dd0 ef9a2e976b efb24f950f efce0c1868 efe5ac6901 efe828affa efea4e0523 f0268aa627 f0483250c8 f04cf99ee6 f05b189097 f08928c6d3 f09d74856f f0a7607d63 f0ad38da27 f0c34e1213 f0c7f86c29 f0dfa18ba7 f0eb3179f7 f119bab27d f14409b6a3 f1489baff4 f14c18cf6a f15c607b92 f1af214222 f1b77bd309 f1ba9e1a3e f1d99239eb f1dc710cf4 f1ec5c08fa f22648fe12 f22d21f1f1 f233257395 f23e95dbe5 f2445b1572 f253b3486d f277c7a6a4 f2ab2b84d6 f2b7c9b1f3 f2b83d5ce5 f2c276018f f2cfd94d64 f2dd6e3add f2e7653f16 f2f333ad06 f2f55d6713 f2fdb6abec f305a56d9f f3085d6570 f3325c3338 f3400f1204 f34497c932 f34a56525e f36483c824 f3704d5663 f3734c4913 f38e5aa5b4 f3986fba44 f3a0ffc7d9 f3b24a7d28 f3e6c35ec3 f3fc0ea80b f40a683fbe f4207ca554 f4377499c2 f46184f393 f46c2d0a6d f46c364dca f46f7a0b63 f46fe141b0 f470b9aeb0 f47eb7437f f48b535719 f49e4866ac f4aa882cfd f4daa3dbd5 f4dd51ac35 f507a1b9dc f51c5ac84b f52104164b f54c67b9bb f5966cadd2 f5bddf5598 f5d85cfd17 f5e2e7d6a0 f5f051e9b4 f5f8a93a76 f6283e8af5 f635e9568b f6474735be f659251be2 f66981af4e f6708fa398 f697fe8e8f f6adb12c42 f6c7906ca4 f6cd0a8016 f6d6f15ae7 f6e501892c f6f59d986f f6fe8c90a5 f714160545 f74c3888d7 f7782c430e f7783ae5f2 f77ab47923 f788a98327 f7961ac1f0 f7a71e7574 f7a8521432 f7afbf4947 f7b7cd5f44 f7cf4b4a39 f7d49799ad f7e0c9bb83 f7e5b84928 f7e6bd58be f7f2a38ac6 f7f6cb2d6d f83f19e796 f85796a921 f8603c26b2 f8819b42ec f891f8eaa1 f89288d10c f895ae8cc1 f8b4ac12f1 f8c3fb2b01 f8c8de2764 f8db369b40 f8fcb6a78c f94aafdeef f95d217b70 f9681d5103 f9750192a4 f9823a32c2 f991ddb4c2 f99d535567 f9ae3d98b7 f9b6217959 f9bd1fabf5 f9c68eaa64 f9d3e04c4f f9daf64494 f9e4cc5a0a f9ea6b7f31 f9f3852526 fa04c615cf fa08e00a56 fa4370d74d fa67744af3 fa88d48a92 fa8b904cc9 fa9526bdf1 fa9b9d2426 fad633fbe1 faf5222dc3 faff0e15f1 fb08c64e8c fb23455a7f fb2e19fa6e fb34dfbb77 fb47fcea1e fb49738155 fb4cbc514b fb4e6062f7 fb5ba7ad6e fb63cd1236 fb81157a07 fb92abdaeb fba22a6848 fbaca0c9df fbc645f602 fbd77444cd fbe53dc8e8 fbe541dd73 fbe8488798 fbfd25174f fc28cb305e fc33b1ffd6 fc6186f0bb fc918e3a40 fc96cda9d8 fc9832eea4 fcb10d0f81 fcd20a2509 fcf637e3ab fcfd81727f fd31890379 fd33551c28 fd542da05e fd6789b3fe fd77828200 fd7af75f4d fdb28d0fbb fdb3d1fb1e fdb8b04124 fdc6e3d581 fdfce7e6fc fe0f76d41b fe24b0677d fe3c02699d fe58b48235 fe6a5596b8 fe6c244f63 fe7afec086 fe985d510a fe9db35d15 fea8ffcd36 feb1080388 fed208bfca feda5ad1c2 feec95b386 ff15a5eff6 ff204daf4b ff25f55852 ff2ada194f ff2ce142e8 ff49d36d20 ff5a1ec4f3 ff66152b25 ff692fdc56 ff773b1a1e ff97129478 ffb904207d ffc43fc345 fffe5f8df6 ================================================ FILE: merge_lora_weights_and_save_hf_model.py ================================================ import argparse import glob import os import sys import cv2 import numpy as np import torch import torch.nn.functional as F import transformers from peft import LoraConfig, get_peft_model from transformers import AutoTokenizer from model.VISA import VISAForCausalLM from utils.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN """ python merge_lora_weights_and_save_hf_model.py \ --version /mnt/nlp-ali/usr/yancilin/clyan-data-2/video-llm/Chat-UniVi/Chat-UniVi \ --weight /mnt/public03/dataset/ovis/rgvos/visa7b/ckpt_model/pytorch_model15000.bin \ --save_path /mnt/public03/dataset/ovis/rgvos/visa7b/ckpt_model/hf_model 
""" def parse_args(args): parser = argparse.ArgumentParser( description="merge lora weights and save model with hf format" ) parser.add_argument( "--version", default="Chat-UniVi/Chat-UniVi", type=str, required=True ) parser.add_argument( "--weight", default="/path/to/visa/pytorch_model.bin", type=str, required=True ) parser.add_argument( "--save_path", default="/path/to/hf_model", type=str, required=True ) parser.add_argument("--precision", default="bf16", type=str, choices=["fp32", "bf16", "fp16"], help="precision for inference", ) parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str) parser.add_argument("--out_dim", default=256, type=int) parser.add_argument("--image_size", default=1024, type=int, help="image size") parser.add_argument("--model_max_length", default=512, type=int) parser.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str, ) parser.add_argument("--lora_r", default=8, type=int) parser.add_argument("--lora_alpha", default=16, type=int) parser.add_argument("--lora_dropout", default=0.05, type=float) parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str) parser.add_argument("--local-rank", default=0, type=int, help="node rank") parser.add_argument("--train_mask_decoder", action="store_true", default=True) parser.add_argument("--use_mm_start_end", action="store_true", default=False) parser.add_argument("--conv_type", default="llava_v1", type=str, choices=["llava_v1", "llava_llama_2"], ) return parser.parse_args(args) def main(args): args = parse_args(args) # Create model tokenizer = transformers.AutoTokenizer.from_pretrained( args.version, cache_dir = None, model_max_length = args.model_max_length, padding_side = "right", use_fast = False, ) tokenizer.pad_token = tokenizer.unk_token num_added_tokens = tokenizer.add_tokens("[SEG]") args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0] if args.use_mm_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) model_args = { "train_mask_decoder": args.train_mask_decoder, "out_dim" : args.out_dim, "seg_token_idx" : args.seg_token_idx, "vision_tower" : args.vision_tower, } torch_dtype = torch.float32 if args.precision == "bf16": torch_dtype = torch.bfloat16 elif args.precision == "fp16": torch_dtype = torch.half model = VISAForCausalLM.from_pretrained(args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args, ) model.config.eos_token_id = tokenizer.eos_token_id model.config.bos_token_id = tokenizer.bos_token_id model.config.pad_token_id = tokenizer.pad_token_id model.get_model().initialize_vision_modules(model.get_model().config) vision_tower = model.get_model().get_vision_tower() vision_tower.to(dtype=torch_dtype) model.get_model().initialize_lisa_modules(model.get_model().config) lora_r = args.lora_r if lora_r > 0: def find_linear_layers(model, lora_target_modules): cls = torch.nn.Linear lora_module_names = set() for name, module in model.named_modules(): if ( isinstance(module, cls) and all([x not in name for x in ["visual_model", "vision_tower", "mm_projector", "text_hidden_fcs", ]]) and any([x in name for x in lora_target_modules]) ): lora_module_names.add(name) return sorted(list(lora_module_names)) lora_alpha = args.lora_alpha lora_dropout = args.lora_dropout lora_target_modules = find_linear_layers(model, args.lora_target_modules.split(","), ) lora_config = LoraConfig( r = lora_r, lora_alpha = lora_alpha, target_modules = lora_target_modules, lora_dropout = 
lora_dropout, bias = "none", task_type = "CAUSAL_LM", ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() model.resize_token_embeddings(len(tokenizer)) state_dict = torch.load(args.weight, map_location="cpu") model.load_state_dict(state_dict, strict=True) model = model.merge_and_unload() state_dict = {} for k, v in model.state_dict().items(): if "vision_tower" not in k: state_dict[k] = v model.save_pretrained(args.save_path, state_dict=state_dict) tokenizer.save_pretrained(args.save_path) if __name__ == "__main__": main(sys.argv[1:]) ================================================ FILE: model/VISA.py ================================================ from typing import List import torch import torch.nn as nn import torch.nn.functional as F from transformers import BitsAndBytesConfig, CLIPVisionModel from utils.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN from .univi.model.language_model.llama import ChatUniViLlamaForCausalLM, ChatUniViLlamaModel from .segment_anything import build_sam_vit_h from model.univi.constants import IMAGE_TOKEN_INDEX import time def dice_loss( inputs : torch.Tensor, targets : torch.Tensor, num_masks: float, scale : float =1000, eps : float =1e-6, ): """ Compute the DICE loss, similar to generalized IOU for masks Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). """ inputs = inputs.sigmoid() inputs = inputs.flatten(1, 2) targets = targets.flatten(1, 2) numerator = 2 * (inputs / scale * targets).sum(-1) denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1) loss = 1 - (numerator + eps) / (denominator + eps) loss = loss.sum() / (num_masks + 1e-8) return loss def sigmoid_ce_loss( inputs: torch.Tensor, targets: torch.Tensor, num_masks: float, ): """ Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). 
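(Editor's illustration, not a file in the repository.) A minimal sketch of loading the merged checkpoint afterwards, mirroring how this script passes extra kwargs through `from_pretrained`; the path is a placeholder:

# Sketch only: save_path is the --save_path used above; seg_token_idx is
# recovered from the saved tokenizer, which already contains "[SEG]".
import torch
import transformers
from model.VISA import VISAForCausalLM

save_path = "/path/to/hf_model"  # placeholder
tokenizer = transformers.AutoTokenizer.from_pretrained(save_path, use_fast=False)
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
model = VISAForCausalLM.from_pretrained(
    save_path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    seg_token_idx=seg_token_idx,
)
model.eval()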
================================================
FILE: model/VISA.py
================================================
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BitsAndBytesConfig, CLIPVisionModel

from utils.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from .univi.model.language_model.llama import ChatUniViLlamaForCausalLM, ChatUniViLlamaModel
from .segment_anything import build_sam_vit_h
from model.univi.constants import IMAGE_TOKEN_INDEX
import time


def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale: float = 1000,
    eps: float = 1e-6,
):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1, 2)
    targets = targets.flatten(1, 2)
    numerator = 2 * (inputs / scale * targets).sum(-1)
    denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
    loss = 1 - (numerator + eps) / (denominator + eps)
    loss = loss.sum() / (num_masks + 1e-8)
    return loss


def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    Returns:
        Loss tensor
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8)
    return loss
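(Editor's illustration, not part of VISA.py.) A toy sanity check of the two helpers above; both consume (num_masks, H, W) logits and binary targets and return a scalar normalized by num_masks:

# Toy check of dice_loss / sigmoid_ce_loss; shapes and values are illustrative.
import torch

pred = torch.randn(2, 4, 4)                 # (num_masks, H, W) mask logits
gt = (torch.rand(2, 4, 4) > 0.5).float()    # binary ground-truth masks

bce = sigmoid_ce_loss(pred, gt, num_masks=pred.shape[0])
dice = dice_loss(pred, gt, num_masks=pred.shape[0])
print(bce.item(), dice.item())              # both scalars, averaged over masks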
class VisaMetaModel:
    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(VisaMetaModel, self).__init__(config)
        self.config = config
        if not hasattr(self.config, "train_mask_decoder"):
            self.config.train_mask_decoder = kwargs["train_mask_decoder"]
            self.config.out_dim = kwargs["out_dim"]
            self.vision_pretrained = kwargs.get("vision_pretrained", None)
        else:
            self.vision_pretrained = kwargs.get("vision_pretrained", None)
            self.initialize_lisa_modules(self.config)

    def initialize_lisa_modules(self, config):
        # SAM
        self.visual_model = build_sam_vit_h(self.vision_pretrained)
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if config.train_mask_decoder:
            self.visual_model.mask_decoder.train()
            for param in self.visual_model.mask_decoder.parameters():
                param.requires_grad = True

        # Projection layer
        in_dim = config.hidden_size
        out_dim = config.out_dim
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True


class VisaModel(VisaMetaModel, ChatUniViLlamaModel):
    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(VisaModel, self).__init__(config, **kwargs)
        self.config.use_cache = False
        self.config.vision_tower = self.config.mm_vision_tower
        self.config.mm_vision_select_feature = "patch"
        self.config.image_aspect_ratio = "square"
        self.config.image_grid_pinpoints = None
        self.config.tune_mm_mlp_adapter = False
        self.config.freeze_mm_mlp_adapter = True
        self.config.pretrain_mm_mlp_adapter = None
        self.config.mm_use_im_patch_token = False


class VISAForCausalLM(ChatUniViLlamaForCausalLM):
    def __init__(
        self,
        config,
        **kwargs,
    ):
        if not hasattr(config, "train_mask_decoder"):
            config.mm_use_im_start_end = kwargs.pop("use_mm_start_end", True)
            config.mm_vision_tower = kwargs.get("vision_tower", "openai/clip-vit-large-patch14")
            self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
            self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
            self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
        else:
            config.mm_vision_tower = config.vision_tower

        self.seg_token_idx = kwargs.pop("seg_token_idx")
        super().__init__(config)
        self.model = VisaModel(config, **kwargs)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_visual_embs(self, pixel_values: torch.FloatTensor):
        with torch.no_grad():
            image_embeddings = self.model.visual_model.image_encoder(pixel_values)
        return image_embeddings

    def forward(self, **kwargs):
        if "past_key_values" in kwargs:
            return super().forward(**kwargs)
        return self.model_forward(**kwargs)

    def model_forward(
        self,
        images: torch.FloatTensor,
        images_clip: torch.FloatTensor,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        attention_masks: torch.LongTensor,
        offset: torch.LongTensor,
        masks_list: List[torch.FloatTensor],
        label_list: List[torch.Tensor],
        resize_list: List[tuple],
        conversation_list: List[str],
        num_frame_list: List[int],
        num_conv_list: List[int],
        inference: bool = False,
        **kwargs,
    ):
        batch_size = len(images)
        image_embeddings = self.get_visual_embs(torch.cat(images, dim=0))
        assert image_embeddings.shape[0] == batch_size
        assert batch_size == len(offset) - 1
        for batch_idx in range(batch_size):
            assert num_conv_list[batch_idx] == offset[batch_idx + 1] - offset[batch_idx]

        if inference:
            # Inference branch: a single video, replicated once per conversation.
            length = input_ids.shape[0]
            assert len(images_clip) == 1, f"Inference only supports one video, but got {len(images_clip)} videos."
            images_clip = [
                images_clip[0].unsqueeze(0).expand(length, -1, -1, -1, -1).contiguous().flatten(0, 1)
            ]
            output_i = super().forward(
                images=images_clip,
                attention_mask=attention_masks,
                input_ids=input_ids,
                output_hidden_states=True,
            )
            torch.cuda.empty_cache()
            output_hidden_states = output_i.hidden_states
            output = None

            num_image_ori_token = (input_ids[0] == IMAGE_TOKEN_INDEX).sum()
            assert all(
                [
                    (input_ids[i] == IMAGE_TOKEN_INDEX).sum() == num_image_ori_token
                    for i in range(length)
                ]
            )
            # Each image placeholder is expanded into extra visual tokens by the
            # Chat-UniVi encoder (net +111 positions per placeholder), so the
            # mask is padded on the left accordingly.
            token_add = 111 * num_image_ori_token

            seg_token_mask = input_ids[:, 1:] == self.seg_token_idx
            seg_token_mask = torch.cat(
                [seg_token_mask, torch.zeros((seg_token_mask.shape[0], 1)).bool().cuda()],
                dim=1,
            )
            seg_token_mask = torch.cat(
                [torch.zeros((seg_token_mask.shape[0], token_add)).bool().cuda(), seg_token_mask],
                dim=1,
            )
            all_conv_seg_token_num = seg_token_mask.sum(dim=1).tolist()
        else:
            # Training branch: replicate each video once per conversation in its batch entry.
            images_clip_list = []
            for batch_idx in range(batch_size):
                bs_conv_num = num_conv_list[batch_idx]
                images_clip_i = (
                    images_clip[batch_idx].unsqueeze(0).expand(bs_conv_num, -1, -1, -1, -1).contiguous()
                )
                images_clip_list.append(images_clip_i)
            images_clip_list = [i.flatten(0, 1) for i in images_clip_list]

            output = super().forward(
                images=images_clip_list,
                attention_mask=attention_masks,
                input_ids=input_ids,
                labels=labels,
                output_hidden_states=True,
            )
            output_hidden_states = output.hidden_states

            seg_token_mask = output.labels[..., 1:] == self.seg_token_idx
            seg_token_mask = torch.cat(
                [
                    seg_token_mask,
                    torch.zeros((seg_token_mask.shape[0], 1), device=output.labels.device).bool(),
                ],
                dim=1,
            )
            all_conv_seg_token_num = seg_token_mask.sum(dim=1).tolist()

        # The shift-by-one above means each [SEG] prediction is read from the
        # hidden state at the position *preceding* the [SEG] token.
        assert len(self.model.text_hidden_fcs) == 1
        pred_embeddings = self.model.text_hidden_fcs[0](output_hidden_states[-1][seg_token_mask])

        seg_token_counts = seg_token_mask.int().sum(-1)  # [bs, ]
        seg_token_offset = seg_token_counts.cumsum(-1)
        seg_token_offset = torch.cat([torch.zeros(1).long().cuda(), seg_token_offset], dim=0)
        seg_token_offset = seg_token_offset[offset]

        pred_embeddings_ = []
        for i in range(len(seg_token_offset) - 1):
            start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
            pred_embeddings_.append(pred_embeddings[start_i:end_i])
        pred_embeddings = pred_embeddings_
        assert len(pred_embeddings) == batch_size

        multimask_output = False
        pred_masks = []
        for i in range(batch_size):
            (
                sparse_embeddings,
                dense_embeddings,
            ) = self.model.visual_model.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
                text_embeds=pred_embeddings[i].unsqueeze(1),
            )
            sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
            low_res_masks, iou_predictions = self.model.visual_model.mask_decoder(
                image_embeddings=image_embeddings[i].unsqueeze(0),
                image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            pred_mask = self.model.visual_model.postprocess_masks(
                low_res_masks,
                input_size=resize_list[i],
                original_size=label_list[i].shape,
            )
            pred_masks.append(pred_mask[:, 0])

        model_output = output
        gt_masks = [mm.flatten(0, 1) for mm in masks_list]

        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
            }

        output = model_output.logits
        ce_loss = model_output.loss
        ce_loss = ce_loss * self.ce_loss_weight
        mask_bce_loss = 0
        mask_dice_loss = 0
        num_masks = 0
        for batch_idx in range(batch_size):
            gt_mask = gt_masks[batch_idx]
            pred_mask = pred_masks[batch_idx]
            assert (
                gt_mask.shape[0] == pred_mask.shape[0]
            ), "gt_mask.shape: {}, pred_mask.shape: {}".format(gt_mask.shape, pred_mask.shape)
            mask_bce_loss += (
                sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0]) * gt_mask.shape[0]
            )
            mask_dice_loss += (
                dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0]) * gt_mask.shape[0]
            )
            num_masks += gt_mask.shape[0]
        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss
        loss = ce_loss + mask_loss

        return {
            "loss": loss,
            "ce_loss": ce_loss,
            "mask_bce_loss": mask_bce_loss,
            "mask_dice_loss": mask_dice_loss,
            "mask_loss": mask_loss,
        }

    def evaluate(self, *args, **kwargs):
        raise NotImplementedError("This method is not implemented.")
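(Editor's illustration, not part of VISA.py.) The [SEG]-token bookkeeping in model_forward can be made concrete on a toy sequence; the token ids here are made up, with 32000 standing in for the [SEG] id:

# Illustration of the seg_token_mask construction: shifting by one selects the
# hidden state at the position *before* each [SEG] token.
import torch

seg_token_idx = 32000  # made-up id for [SEG]
input_ids = torch.tensor([[1, 5, 7, seg_token_idx, 9, seg_token_idx]])

seg_token_mask = input_ids[:, 1:] == seg_token_idx
seg_token_mask = torch.cat(
    [seg_token_mask, torch.zeros((1, 1), dtype=torch.bool)], dim=1
)
print(seg_token_mask)  # True at positions 2 and 4, i.e. one step before each [SEG]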
================================================
FILE: model/llava/__init__.py
================================================
from .model import LlavaLlamaForCausalLM


================================================
FILE: model/llava/constants.py
================================================
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


================================================
FILE: model/llava/conversation.py
================================================
import dataclasses
from enum import Enum, auto
from typing import List, Tuple


class SeparatorStyle(Enum):
    """Different separator style."""

    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if "mmtag" in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""
            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    from PIL import Image

                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":

                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result

                        image = expand2square(image)
                    elif image_process_mode == "Crop":
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    ret.append([img_str, None])
                    msg = msg.replace("<image>", "").strip()
                    if len(msg) > 0:
                        ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(("Human", "Hi!"), ("Assistant", "Hi there! How can I help you today?")),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)

default_conversation = conv_vicuna_v0
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "llama_2": conv_llama_2,
    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,
    "mpt": conv_mpt,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
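(Editor's illustration, not a file in the repository.) A quick usage sketch of the templates above; the question string is arbitrary:

# Building a llava_v1 prompt from the conv_templates registry defined above.
conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], "<image>\nWhat is in this video?")
conv.append_message(conv.roles[1], None)  # None leaves the assistant turn open
prompt = conv.get_prompt()
print(prompt)  # "...USER: <image>\nWhat is in this video? ASSISTANT:"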
== "pt": return torch.tensor(input_ids, dtype=torch.long) raise ValueError(f"Unsupported tensor type: {return_tensors}") return input_ids def get_model_name_from_path(model_path): model_path = model_path.strip("/") model_paths = model_path.split("/") if model_paths[-1].startswith("checkpoint-"): return model_paths[-2] + "_" + model_paths[-1] else: return model_paths[-1] class KeywordsStoppingCriteria(StoppingCriteria): def __init__(self, keywords, tokenizer, input_ids): self.keywords = keywords self.keyword_ids = [] for keyword in keywords: cur_keyword_ids = tokenizer(keyword).input_ids if ( len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id ): cur_keyword_ids = cur_keyword_ids[1:] self.keyword_ids.append(torch.tensor(cur_keyword_ids)) self.tokenizer = tokenizer self.start_len = input_ids.shape[1] def __call__( self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs ) -> bool: assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO offset = min(output_ids.shape[1] - self.start_len, 3) self.keyword_ids = [ keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids ] for keyword_id in self.keyword_ids: if output_ids[0, -keyword_id.shape[0] :] == keyword_id: return True outputs = self.tokenizer.batch_decode( output_ids[:, -offset:], skip_special_tokens=True )[0] for keyword in self.keywords: if keyword in outputs: return True return False ================================================ FILE: model/llava/model/__init__.py ================================================ from .language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM from .language_model.llava_mpt import LlavaMPTConfig, LlavaMPTForCausalLM ================================================ FILE: model/llava/model/apply_delta.py ================================================ """ Usage: python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta """ import argparse import torch from llava import LlavaLlamaForCausalLM from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer def apply_delta(base_model_path, target_model_path, delta_path): print("Loading base model") base = AutoModelForCausalLM.from_pretrained( base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True ) print("Loading delta") delta = LlavaLlamaForCausalLM.from_pretrained( delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True ) delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) print("Applying delta") for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): if name not in base.state_dict(): assert name in [ "model.mm_projector.weight", "model.mm_projector.bias", ], f"{name} not in base model" continue if param.data.shape == base.state_dict()[name].shape: param.data += base.state_dict()[name] else: assert name in [ "model.embed_tokens.weight", "lm_head.weight", ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" bparam = base.state_dict()[name] param.data[: bparam.shape[0], : bparam.shape[1]] += bparam print("Saving target model") delta.save_pretrained(target_model_path) delta_tokenizer.save_pretrained(target_model_path) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--base-model-path", type=str, required=True) parser.add_argument("--target-model-path", type=str, required=True) parser.add_argument("--delta-path", type=str, required=True) args = parser.parse_args() 
================================================
FILE: model/llava/model/__init__.py
================================================
from .language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM
from .language_model.llava_mpt import LlavaMPTConfig, LlavaMPTForCausalLM


================================================
FILE: model/llava/model/apply_delta.py
================================================
"""
Usage:
python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
"""
import argparse

import torch
from llava import LlavaLlamaForCausalLM
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def apply_delta(base_model_path, target_model_path, delta_path):
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )

    print("Loading delta")
    delta = LlavaLlamaForCausalLM.from_pretrained(
        delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

    print("Applying delta")
    for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
        if name not in base.state_dict():
            assert name in [
                "model.mm_projector.weight",
                "model.mm_projector.bias",
            ], f"{name} not in base model"
            continue
        if param.data.shape == base.state_dict()[name].shape:
            param.data += base.state_dict()[name]
        else:
            assert name in [
                "model.embed_tokens.weight",
                "lm_head.weight",
            ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
            bparam = base.state_dict()[name]
            param.data[: bparam.shape[0], : bparam.shape[1]] += bparam

    print("Saving target model")
    delta.save_pretrained(target_model_path)
    delta_tokenizer.save_pretrained(target_model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)
    args = parser.parse_args()
    apply_delta(args.base_model_path, args.target_model_path, args.delta_path)


================================================
FILE: model/llava/model/builder.py
================================================
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil

import torch
from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                             DEFAULT_IMAGE_PATCH_TOKEN)
from llava.model import *
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)


def load_pretrained_model(
    model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"
):
    kwargs = {"device_map": device_map}

    if load_8bit:
        kwargs["load_in_8bit"] = True
    elif load_4bit:
        kwargs["load_in_4bit"] = True
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        kwargs["torch_dtype"] = torch.float16

    if "llava" in model_name.lower():
        # Load LLaVA model
        if "lora" in model_name.lower() and model_base is not None:
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            print("Loading LLaVA from base model...")
            model = LlavaLlamaForCausalLM.from_pretrained(
                model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs
            )
            token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
            if model.lm_head.weight.shape[0] != token_num:
                model.lm_head.weight = torch.nn.Parameter(
                    torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype)
                )
                model.model.embed_tokens.weight = torch.nn.Parameter(
                    torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype)
                )

            print("Loading additional LLaVA weights...")
            if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
                non_lora_trainables = torch.load(
                    os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu"
                )
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download

                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id, filename=filename, subfolder=subfolder
                    )
                    return torch.load(cache_file, map_location="cpu")

                non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
            non_lora_trainables = {
                (k[11:] if k.startswith("base_model.") else k): v
                for k, v in non_lora_trainables.items()
            }
            if any(k.startswith("model.model.") for k in non_lora_trainables):
                non_lora_trainables = {
                    (k[6:] if k.startswith("model.") else k): v
                    for k, v in non_lora_trainables.items()
                }
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel

            print("Loading LoRA weights...")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging LoRA weights...")
            model = model.merge_and_unload()
            print("Model is loaded...")
        elif model_base is not None:
            # this may be mm projector only
            print("Loading LLaVA from base model...")
            if "mpt" in model_name.lower():
                if not os.path.isfile(os.path.join(model_path, "configuration_mpt.py")):
                    shutil.copyfile(
                        os.path.join(model_base, "configuration_mpt.py"),
                        os.path.join(model_path, "configuration_mpt.py"),
                    )
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                model = LlavaMPTForCausalLM.from_pretrained(
                    model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                cfg_pretrained = AutoConfig.from_pretrained(model_path)
                model = LlavaLlamaForCausalLM.from_pretrained(
                    model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
                )
            mm_projector_weights = torch.load(
                os.path.join(model_path, "mm_projector.bin"), map_location="cpu"
            )
            mm_projector_weights = {
                k: v.to(torch.float16) for k, v in mm_projector_weights.items()
            }
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            if "mpt" in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = LlavaMPTForCausalLM.from_pretrained(
                    model_path, low_cpu_mem_usage=True, **kwargs
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = LlavaLlamaForCausalLM.from_pretrained(
                    model_path, low_cpu_mem_usage=True, **kwargs
                )
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel

            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(
                model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto"
            )
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging weights")
            model = model.merge_and_unload()
            print("Convert to FP16...")
            model.to(torch.float16)
        else:
            use_fast = False
            if "mpt" in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = AutoModelForCausalLM.from_pretrained(
                    model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = AutoModelForCausalLM.from_pretrained(
                    model_path, low_cpu_mem_usage=True, **kwargs
                )

    image_processor = None

    if "llava" in model_name.lower():
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(device="cuda", dtype=torch.float16)
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
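(Editor's sketch, not repo code.) A typical call into load_pretrained_model as defined above; the model path is a placeholder, the import path is an assumption about how this package is laid out on the caller's side, and a CUDA device is assumed since the function moves the vision tower to "cuda":

# Sketch only: placeholder path; get_model_name_from_path comes from mm_utils.py above.
model_path = "/path/to/llava-7b"  # placeholder
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name=model_name
)
print(model_name, context_len)  # context_len is 2048 unless max_sequence_length is set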
print("Loading model") auto_upgrade(src_path) src_model = AutoModelForCausalLM.from_pretrained( src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True ) src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) src_model.save_pretrained(dst_path) src_tokenizer.save_pretrained(dst_path) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--src", type=str, required=True) parser.add_argument("--dst", type=str, required=True) args = parser.parse_args() consolidate_ckpt(args.src, args.dst) ================================================ FILE: model/llava/model/language_model/llava_llama.py ================================================ # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.nn import CrossEntropyLoss from transformers import (AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM, LlamaModel) from transformers.modeling_outputs import CausalLMOutputWithPast from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel class LlavaConfig(LlamaConfig): model_type = "llava" class LlavaLlamaModel(LlavaMetaModel, LlamaModel): config_class = LlavaConfig def __init__(self, config: LlamaConfig): super(LlavaLlamaModel, self).__init__(config) class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): config_class = LlavaConfig def __init__(self, config): super(LlamaForCausalLM, self).__init__(config) self.model = LlavaLlamaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return self.model def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) ( input_ids, attention_mask, past_key_values, inputs_embeds, labels, ) = self.prepare_inputs_labels_for_multimodal( input_ids, attention_mask, past_key_values, labels, images ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, 
return_dict=return_dict, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output if self.training: output_hidden_states = outputs.hidden_states else: output_hidden_states = hidden_states return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=output_hidden_states, # outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, images=None, **kwargs ): if past_key_values: input_ids = input_ids[:, -1:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "images": images, } ) return model_inputs AutoConfig.register("llava", LlavaConfig) AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) ================================================ FILE: model/llava/model/language_model/llava_mpt.py ================================================ # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
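# Note: this module mirrors llava_llama.py for MPT backbones. LlavaMPTModel
# mixes LlavaMetaModel into MPTModel, LlavaMPTForCausalLM routes inputs through
# prepare_inputs_labels_for_multimodal before the transformer, and the
# AutoConfig/AutoModelForCausalLM registrations at the bottom of this file make
# the "llava_mpt" model_type loadable through the transformers Auto classes.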
import math import warnings from typing import List, Optional, Tuple import torch import torch.nn.functional as F from transformers import AutoConfig, AutoModelForCausalLM from transformers.modeling_outputs import CausalLMOutputWithPast from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel class LlavaMPTConfig(MPTConfig): model_type = "llava_mpt" class LlavaMPTModel(LlavaMetaModel, MPTModel): config_class = LlavaMPTConfig def __init__(self, config: MPTConfig): config.hidden_size = config.d_model super(LlavaMPTModel, self).__init__(config) def embed_tokens(self, x): return self.wte(x) class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): config_class = LlavaMPTConfig supports_gradient_checkpointing = True def __init__(self, config): super(MPTForCausalLM, self).__init__(config) if not config.tie_word_embeddings: raise ValueError("MPTForCausalLM only supports tied word embeddings") self.transformer = LlavaMPTModel(config) self.logit_scale = None if config.logit_scale is not None: logit_scale = config.logit_scale if isinstance(logit_scale, str): if logit_scale == "inv_sqrt_d_model": logit_scale = 1 / math.sqrt(config.d_model) else: raise ValueError( f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." ) self.logit_scale = logit_scale def get_model(self): return self.transformer def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, LlavaMPTModel): module.gradient_checkpointing = value def forward( self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, sequence_id: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None, images=None, ): return_dict = ( return_dict if return_dict is not None else self.config.return_dict ) use_cache = use_cache if use_cache is not None else self.config.use_cache ( input_ids, attention_mask, past_key_values, inputs_embeds, labels, ) = self.prepare_inputs_labels_for_multimodal( input_ids, attention_mask, past_key_values, labels, images ) outputs = self.transformer( input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, ) # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338 logits = F.linear( outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight, ) if self.logit_scale is not None: if self.logit_scale == 0: warnings.warn( f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." 
                )
            logits *= self.logit_scale
        loss = None
        if labels is not None:
            labels = torch.roll(labels, shifts=-1)
            labels[:, -1] = -100
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
            )
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        if inputs_embeds is not None:
            raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
        attention_mask = kwargs["attention_mask"].bool()
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError(
                "MPT does not support generation with right padding."
            )
        if self.transformer.attn_uses_sequence_id and self.training:
            sequence_id = torch.zeros_like(input_ids[:1])
        else:
            sequence_id = None
        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
        if self.transformer.prefix_lm:
            prefix_mask = torch.ones_like(attention_mask)
            if kwargs.get("use_cache") == False:
                raise NotImplementedError(
                    "MPT with prefix_lm=True does not support use_cache=False."
                )
        else:
            prefix_mask = None
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prefix_mask": prefix_mask,
            "sequence_id": sequence_id,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
            "images": kwargs.get("images", None),
        }


AutoConfig.register("llava_mpt", LlavaMPTConfig)
AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)


================================================
FILE: model/llava/model/language_model/mpt/adapt_tokenizer.py
================================================
from typing import Union

from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
NUM_SENTINEL_TOKENS: int = 100


def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
    """Adds sentinel tokens and padding token (if missing).

    Expands the tokenizer vocabulary to include sentinel tokens
    used in mixture-of-denoiser tasks as well as a padding token.

    All added tokens are added as special tokens. No tokens are
    added if sentinel tokens and padding token already exist.
    """
    sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
    tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
    if tokenizer.pad_token is None:
        tokenizer.add_tokens("<pad>", special_tokens=True)
        tokenizer.pad_token = "<pad>"
        assert tokenizer.pad_token_id is not None
    sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
    _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
    tokenizer.sentinel_token_ids = _sentinel_token_ids


class AutoTokenizerForMOD(AutoTokenizer):
    """AutoTokenizer + Adaptation for MOD.

    A simple wrapper around AutoTokenizer to make instantiating
    an MOD-adapted tokenizer a bit easier.

    MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
    a padding token, and a property to get the token ids of the
    sentinel tokens.
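
    Example (illustrative sketch; the checkpoint name below is a placeholder,
    not something this file prescribes):

        tokenizer = AutoTokenizerForMOD.from_pretrained("mosaicml/mpt-7b")
        assert tokenizer.pad_token_id is not None
        print(tokenizer.sentinel_token_ids[:3])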
""" @classmethod def from_pretrained(cls, *args, **kwargs): """See `AutoTokenizer.from_pretrained` docstring.""" tokenizer = super().from_pretrained(*args, **kwargs) adapt_tokenizer_for_denoising(tokenizer) return tokenizer ================================================ FILE: model/llava/model/language_model/mpt/attention.py ================================================ """Attention layers.""" import math import warnings from typing import Optional import torch import torch.nn as nn from einops import rearrange from packaging import version from torch import nn from .norm import LPLayerNorm def _reset_is_causal( num_query_tokens: int, num_key_tokens: int, original_is_causal: bool ): if original_is_causal and num_query_tokens != num_key_tokens: if num_query_tokens != 1: raise NotImplementedError( "MPT does not support query and key with different number of tokens, unless number of query tokens is 1." ) else: return False return original_is_causal def scaled_multihead_dot_product_attention( query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False, ): q = rearrange(query, "b s (h d) -> b h s d", h=n_heads) kv_n_heads = 1 if multiquery else n_heads k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads) v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads) if past_key_value is not None: if len(past_key_value) != 0: k = torch.cat([past_key_value[0], k], dim=3) v = torch.cat([past_key_value[1], v], dim=2) past_key_value = (k, v) (b, _, s_q, d) = q.shape s_k = k.size(-1) if softmax_scale is None: softmax_scale = 1 / math.sqrt(d) attn_weight = q.matmul(k) * softmax_scale if attn_bias is not None: _s_q = max(0, attn_bias.size(2) - s_q) _s_k = max(0, attn_bias.size(3) - s_k) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if ( attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q) ): raise RuntimeError( f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}." ) attn_weight = attn_weight + attn_bias min_val = torch.finfo(q.dtype).min if key_padding_mask is not None: if attn_bias is not None: warnings.warn( "Propogating key_padding_mask to the attention module " + "and applying it within the attention module can cause " + "unneccessary computation/memory usage. Consider integrating " + "into attn_bias once and passing that to each attention " + "module instead." ) attn_weight = attn_weight.masked_fill( ~key_padding_mask.view((b, 1, 1, s_k)), min_val ) if is_causal and (not q.size(2) == 1): s = max(s_q, s_k) causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) causal_mask = causal_mask.tril() causal_mask = causal_mask.to(torch.bool) causal_mask = ~causal_mask causal_mask = causal_mask[-s_q:, -s_k:] attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) attn_weight = torch.softmax(attn_weight, dim=-1) if dropout_p: attn_weight = torch.nn.functional.dropout( attn_weight, p=dropout_p, training=training, inplace=True ) out = attn_weight.to(v.dtype).matmul(v) out = rearrange(out, "b h s d -> b s (h d)") if needs_weights: return (out, attn_weight, past_key_value) return (out, None, past_key_value) def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): for tensor in tensors: if tensor.dtype not in valid_dtypes: raise TypeError( f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}." 
) if not tensor.is_cuda: raise TypeError( f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})." ) def flash_attn_fn( query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False, ): try: from flash_attn import bert_padding, flash_attn_interface except: raise RuntimeError("Please install flash-attn==1.0.3.post0") check_valid_inputs(query, key, value) if past_key_value is not None: if len(past_key_value) != 0: key = torch.cat([past_key_value[0], key], dim=1) value = torch.cat([past_key_value[1], value], dim=1) past_key_value = (key, value) if attn_bias is not None: _s_q = max(0, attn_bias.size(2) - query.size(1)) _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if attn_bias is not None: raise NotImplementedError(f"attn_bias not implemented for flash attn.") (batch_size, seqlen) = query.shape[:2] if key_padding_mask is None: key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) query_padding_mask = key_padding_mask[:, -query.size(1) :] (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( query, query_padding_mask ) query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads) (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( key, key_padding_mask ) key_unpad = rearrange( key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads ) (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) value_unpad = rearrange( value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads ) if multiquery: key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) value_unpad = value_unpad.expand( value_unpad.size(0), n_heads, value_unpad.size(-1) ) dropout_p = dropout_p if training else 0.0 reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) output_unpad = flash_attn_interface.flash_attn_unpadded_func( query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights, ) output = bert_padding.pad_input( rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen ) return (output, None, past_key_value) def triton_flash_attn_fn( query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False, ): try: from .flash_attn_triton import flash_attn_func except: _installed = False if version.parse(torch.__version__) < version.parse("2.0.0"): _installed = True try: from flash_attn.flash_attn_triton import flash_attn_func except: _installed = False if not _installed: raise RuntimeError( "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed." 
            )
    check_valid_inputs(query, key, value)
    if past_key_value is not None:
        if len(past_key_value) != 0:
            key = torch.cat([past_key_value[0], key], dim=1)
            value = torch.cat([past_key_value[1], value], dim=1)
        past_key_value = (key, value)
    if attn_bias is not None:
        _s_q = max(0, attn_bias.size(2) - query.size(1))
        _s_k = max(0, attn_bias.size(3) - key.size(1))
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
    if dropout_p:
        raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
    if needs_weights:
        raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
    if key_padding_mask is not None:
        warnings.warn(
            "Propagating key_padding_mask to the attention module "
            + "and applying it within the attention module can cause "
            + "unnecessary computation/memory usage. Consider integrating "
            + "into attn_bias once and passing that to each attention "
            + "module instead."
        )
        (b_size, s_k) = key_padding_mask.shape[:2]
        if attn_bias is None:
            attn_bias = query.new_zeros(b_size, 1, 1, s_k)
        attn_bias = attn_bias.masked_fill(
            ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
        )
    query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
    key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
    value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
    if multiquery:
        key = key.expand(*key.shape[:2], n_heads, key.size(-1))
        value = value.expand(*value.shape[:2], n_heads, value.size(-1))
    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    attn_output = flash_attn_func(
        query, key, value, attn_bias, reset_is_causal, softmax_scale
    )
    output = attn_output.view(*attn_output.shape[:2], -1)
    return (output, None, past_key_value)


class MultiheadAttention(nn.Module):
    """Multi-head self attention.

    Using torch or triton attention implementation enables the user to also use
    additive bias.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_impl: str = "triton",
        clip_qkv: Optional[float] = None,
        qk_ln: bool = False,
        softmax_scale: Optional[float] = None,
        attn_pdrop: float = 0.0,
        low_precision_layernorm: bool = False,
        verbose: int = 0,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln
        self.d_model = d_model
        self.n_heads = n_heads
        self.softmax_scale = softmax_scale
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
        self.attn_dropout_p = attn_pdrop
        self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
        fuse_splits = (d_model, 2 * d_model)
        self.Wqkv._fused = (0, fuse_splits)
        if self.qk_ln:
            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
            self.q_ln = layernorm_class(self.d_model, device=device)
            self.k_ln = layernorm_class(self.d_model, device=device)
        if self.attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    "While `attn_impl: triton` can be faster than `attn_impl: flash` "
                    + "it uses more memory. When training larger models this can trigger "
                    + "alloc retries which hurts performance. If encountered, we recommend "
                    + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
                )
        elif self.attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    "Using `attn_impl: torch`. If your model does not use `alibi` or "
                    + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
                    + "we recommend using `attn_impl: triton`."
                )
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        (query, key, value) = qkv.chunk(3, dim=2)
        key_padding_mask = attention_mask
        if self.qk_ln:
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)
        (context, attn_weights, past_key_value) = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
        )
        return (self.out_proj(context), attn_weights, past_key_value)


class MultiQueryAttention(nn.Module):
    """Multi-Query self attention.

    Using torch or triton attention implementation enables the user to also use
    additive bias.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_impl: str = "triton",
        clip_qkv: Optional[float] = None,
        qk_ln: bool = False,
        softmax_scale: Optional[float] = None,
        attn_pdrop: float = 0.0,
        low_precision_layernorm: bool = False,
        verbose: int = 0,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.softmax_scale = softmax_scale
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.head_dim)
        self.attn_dropout_p = attn_pdrop
        self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
        fuse_splits = (d_model, d_model + self.head_dim)
        self.Wqkv._fused = (0, fuse_splits)
        if self.qk_ln:
            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
            self.q_ln = layernorm_class(d_model, device=device)
            self.k_ln = layernorm_class(self.head_dim, device=device)
        if self.attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    "While `attn_impl: triton` can be faster than `attn_impl: flash` "
                    + "it uses more memory. When training larger models this can trigger "
                    + "alloc retries which hurts performance. If encountered, we recommend "
                    + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
                )
        elif self.attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    "Using `attn_impl: torch`. If your model does not use `alibi` or "
                    + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
                    + "we recommend using `attn_impl: triton`."
) else: raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") self.out_proj = nn.Linear(self.d_model, self.d_model, device=device) self.out_proj._is_residual = True def forward( self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False, ): qkv = self.Wqkv(x) if self.clip_qkv: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) (query, key, value) = qkv.split( [self.d_model, self.head_dim, self.head_dim], dim=2 ) key_padding_mask = attention_mask if self.qk_ln: dtype = query.dtype query = self.q_ln(query).to(dtype) key = self.k_ln(key).to(dtype) (context, attn_weights, past_key_value) = self.attn_fn( query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True, ) return (self.out_proj(context), attn_weights, past_key_value) def attn_bias_shape( attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id ): if attn_impl == "flash": return None elif attn_impl in ["torch", "triton"]: if alibi: if (prefix_lm or not causal) or use_sequence_id: return (1, n_heads, seq_len, seq_len) return (1, n_heads, 1, seq_len) elif prefix_lm or use_sequence_id: return (1, 1, seq_len, seq_len) return None else: raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") def build_attn_bias( attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8 ): if attn_impl == "flash": return None elif attn_impl in ["torch", "triton"]: if alibi: (device, dtype) = (attn_bias.device, attn_bias.dtype) attn_bias = attn_bias.add( build_alibi_bias( n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype, ) ) return attn_bias else: raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") def gen_slopes(n_heads, alibi_bias_max=8, device=None): _n_heads = 2 ** math.ceil(math.log2(n_heads)) m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) m = m.mul(alibi_bias_max / _n_heads) slopes = 1.0 / torch.pow(2, m) if _n_heads != n_heads: slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] return slopes.view(1, n_heads, 1, 1) def build_alibi_bias( n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None ): alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view( 1, 1, 1, seq_len ) if full: alibi_bias = alibi_bias - torch.arange( 1 - seq_len, 1, dtype=torch.int32, device=device ).view(1, 1, seq_len, 1) alibi_bias = alibi_bias.abs().mul(-1) slopes = gen_slopes(n_heads, alibi_bias_max, device=device) alibi_bias = alibi_bias * slopes return alibi_bias.to(dtype=dtype) ATTN_CLASS_REGISTRY = { "multihead_attention": MultiheadAttention, "multiquery_attention": MultiQueryAttention, } ================================================ FILE: model/llava/model/language_model/mpt/blocks.py ================================================ """GPT Blocks used for the GPT Model.""" from typing import Dict, Optional, Tuple import torch import torch.nn as nn from .attention import ATTN_CLASS_REGISTRY from .norm import NORM_CLASS_REGISTRY class MPTMLP(nn.Module): def __init__( self, d_model: int, expansion_ratio: int, device: Optional[str] = None ): super().__init__() self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) self.act = nn.GELU(approximate="none") self.down_proj = nn.Linear(expansion_ratio * d_model, 
d_model, device=device) self.down_proj._is_residual = True def forward(self, x): return self.down_proj(self.act(self.up_proj(x))) class MPTBlock(nn.Module): def __init__( self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict = { "attn_type": "multihead_attention", "attn_pdrop": 0.0, "attn_impl": "triton", "qk_ln": False, "clip_qkv": None, "softmax_scale": None, "prefix_lm": False, "attn_uses_sequence_id": False, "alibi": False, "alibi_bias_max": 8, }, resid_pdrop: float = 0.0, norm_type: str = "low_precision_layernorm", verbose: int = 0, device: Optional[str] = None, **kwargs ): del kwargs super().__init__() norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]] self.norm_1 = norm_class(d_model, device=device) self.attn = attn_class( attn_impl=attn_config["attn_impl"], clip_qkv=attn_config["clip_qkv"], qk_ln=attn_config["qk_ln"], softmax_scale=attn_config["softmax_scale"], attn_pdrop=attn_config["attn_pdrop"], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device, ) self.norm_2 = norm_class(d_model, device=device) self.ffn = MPTMLP( d_model=d_model, expansion_ratio=expansion_ratio, device=device ) self.resid_attn_dropout = nn.Dropout(resid_pdrop) self.resid_ffn_dropout = nn.Dropout(resid_pdrop) def forward( self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: a = self.norm_1(x) (b, attn_weights, past_key_value) = self.attn( a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal, ) x = x + self.resid_attn_dropout(b) m = self.norm_2(x) n = self.ffn(m) x = x + self.resid_ffn_dropout(n) return (x, attn_weights, past_key_value) ================================================ FILE: model/llava/model/language_model/mpt/configuration_mpt.py ================================================ """A HuggingFace-style model configuration.""" from typing import Dict, Optional, Union from transformers import PretrainedConfig attn_config_defaults: Dict = { "attn_type": "multihead_attention", "attn_pdrop": 0.0, "attn_impl": "triton", "qk_ln": False, "clip_qkv": None, "softmax_scale": None, "prefix_lm": False, "attn_uses_sequence_id": False, "alibi": False, "alibi_bias_max": 8, } init_config_defaults: Dict = { "name": "kaiming_normal_", "fan_mode": "fan_in", "init_nonlinearity": "relu", "init_div_is_residual": True, "emb_init_std": None, "emb_init_uniform_lim": None, "init_std": None, "init_gain": 0.0, } class MPTConfig(PretrainedConfig): model_type = "mpt" def __init__( self, d_model: int = 2048, n_heads: int = 16, n_layers: int = 24, expansion_ratio: int = 4, max_seq_len: int = 2048, vocab_size: int = 50368, resid_pdrop: float = 0.0, emb_pdrop: float = 0.0, learned_pos_emb: bool = True, attn_config: Dict = attn_config_defaults, init_device: str = "cpu", logit_scale: Optional[Union[float, str]] = None, no_bias: bool = False, verbose: int = 0, embedding_fraction: float = 1.0, norm_type: str = "low_precision_layernorm", use_cache: bool = False, init_config: Dict = init_config_defaults, **kwargs, ): """The MPT configuration class. Args: d_model (int): The size of the embedding dimension of the model. n_heads (int): The number of attention heads. n_layers (int): The number of layers in the model. expansion_ratio (int): The ratio of the up/down scale in the MLP. 
            max_seq_len (int): The maximum sequence length of the model.
            vocab_size (int): The size of the vocabulary.
            resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
            emb_pdrop (float): The dropout probability for the embedding layer.
            learned_pos_emb (bool): Whether to use learned positional embeddings.
            attn_config (Dict): A dictionary used to configure the model's attention module:
                attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
                attn_pdrop (float): The dropout probability for the attention layers.
                attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
                qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
                clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
                    this value.
                softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
                    use the default scale of ``1/sqrt(d_keys)``.
                prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
                    extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
                    can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
                attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
                    When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
                    which sub-sequence each token belongs to.
                    Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
                alibi (bool): Whether to use the alibi bias instead of position embeddings.
                alibi_bias_max (int): The maximum value of the alibi bias.
            init_device (str): The device to use for parameter initialization.
            logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
            no_bias (bool): Whether to remove bias parameters from all layers.
            verbose (int): The verbosity level. 0 is silent.
            embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
            norm_type (str): The type of norm layer to use.
            multiquery_attention (bool): Whether to use multiquery attention implementation.
            use_cache (bool): Whether or not the model should return the last key/values attentions.
            init_config (Dict): A dictionary used to configure the model initialization:
                init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
                    'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
                    'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
                init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
                emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
                emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
                    used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
                init_std (float): The standard deviation of the normal distribution used to initialize the model,
                    if using the baseline_ parameter initialization scheme.
                init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
                fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. --- See llmfoundry.models.utils.param_init_fns.py for info on other param init config options """ self.d_model = d_model self.n_heads = n_heads self.n_layers = n_layers self.expansion_ratio = expansion_ratio self.max_seq_len = max_seq_len self.vocab_size = vocab_size self.resid_pdrop = resid_pdrop self.emb_pdrop = emb_pdrop self.learned_pos_emb = learned_pos_emb self.attn_config = attn_config self.init_device = init_device self.logit_scale = logit_scale self.no_bias = no_bias self.verbose = verbose self.embedding_fraction = embedding_fraction self.norm_type = norm_type self.use_cache = use_cache self.init_config = init_config if "name" in kwargs: del kwargs["name"] if "loss_fn" in kwargs: del kwargs["loss_fn"] super().__init__(**kwargs) self._validate_config() def _set_config_defaults(self, config, config_defaults): for k, v in config_defaults.items(): if k not in config: config[k] = v return config def _validate_config(self): self.attn_config = self._set_config_defaults( self.attn_config, attn_config_defaults ) self.init_config = self._set_config_defaults( self.init_config, init_config_defaults ) if self.d_model % self.n_heads != 0: raise ValueError("d_model must be divisible by n_heads") if any( ( prob < 0 or prob > 1 for prob in [ self.attn_config["attn_pdrop"], self.resid_pdrop, self.emb_pdrop, ] ) ): raise ValueError( "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1" ) if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]: raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in [ "torch", "triton", ]: raise NotImplementedError( "prefix_lm only implemented with torch and triton attention." ) if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in [ "torch", "triton", ]: raise NotImplementedError( "alibi only implemented with torch and triton attention." ) if self.attn_config["attn_uses_sequence_id"] and self.attn_config[ "attn_impl" ] not in ["torch", "triton"]: raise NotImplementedError( "attn_uses_sequence_id only implemented with torch and triton attention." ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( "model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!" ) if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model": raise ValueError( f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." ) if self.init_config.get("name", None) is None: raise ValueError( f"self.init_config={self.init_config!r} 'name' needs to be set." ) if not self.learned_pos_emb and (not self.attn_config["alibi"]): raise ValueError( f"Positional information must be provided to the model using either learned_pos_emb or alibi." 
            )


================================================
FILE: model/llava/model/language_model/mpt/custom_embedding.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class SharedEmbedding(nn.Embedding):
    def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
        if unembed:
            return F.linear(input, self.weight)
        return super().forward(input)


================================================
FILE: model/llava/model/language_model/mpt/flash_attn_triton.py
================================================
"""
Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
update imports to use 'triton_pre_mlir'

*Experimental* implementation of FlashAttention in Triton.
Tested with triton==2.0.0.dev20221202.
Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
other than 64:
https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
We'll update this implementation with the new Triton backend once this is fixed.

We use the FlashAttention implementation from Phil Tillet as a starting point.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

Changes:
- Implement both causal and non-causal attention.
- Implement both self-attention and cross-attention.
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
- Support attention bias.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
small batch size * nheads.

Caution:
- This is an *experimental* implementation. The forward pass should be quite robust but
I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton
compiler).
- This implementation has only been tested on A100.
- If you plan to use headdim other than 64 and 128, you should test for race conditions
(due to the Triton compiler), as done in tests/test_flash_attn.py
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
that there are none left for other head dimensions.

Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
- Triton forward is generally faster than CUDA forward, while Triton backward is
generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
than CUDA forward + backward.
- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
- Triton version supports attention bias, while CUDA version doesn't.
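
Example (an illustrative sketch, not part of the original file; assumes a CUDA
device with triton_pre_mlir installed; `flash_attn_func` is defined at the
bottom of this file as `FlashAttnFunc.apply`, taking positional arguments
(q, k, v, bias, causal, softmax_scale), with tensors of shape
(batch, seqlen, nheads, headdim)):

    import torch
    q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
    out = flash_attn_func(q, k, v, None, True, None)  # causal, no bias, default scale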
""" import math import torch import triton_pre_mlir as triton import triton_pre_mlir.language as tl @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], } ) @triton.jit def _fwd_kernel( Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): start_m = tl.program_id(0) off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_HEADDIM) q_ptrs = ( Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) ) k_ptrs = ( K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) ) v_ptrs = ( V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) ) if BIAS_TYPE == "vector": b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n elif BIAS_TYPE == "matrix": b_ptrs = ( Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) ) t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) if EVEN_M & EVEN_N: if EVEN_HEADDIM: q = tl.load(q_ptrs) else: q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) elif EVEN_HEADDIM: q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) else: q = tl.load( q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, ) end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) for start_n in range(0, end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) if EVEN_N & EVEN_M: if EVEN_HEADDIM: k = tl.load(k_ptrs + start_n * stride_kn) else: k = tl.load( k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0, ) elif EVEN_HEADDIM: k = tl.load( k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0, ) else: k = tl.load( k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0, ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k, trans_b=True) if not EVEN_N: qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) if IS_CAUSAL: qk += tl.where( offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf") ) if BIAS_TYPE != "none": if BIAS_TYPE == "vector": if EVEN_N: bias = tl.load(b_ptrs + start_n).to(tl.float32) else: bias = tl.load( b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0 ).to(tl.float32) bias = bias[None, :] elif BIAS_TYPE == "matrix": if EVEN_M & EVEN_N: bias = tl.load(b_ptrs + start_n).to(tl.float32) else: bias = tl.load( b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0, ).to(tl.float32) qk = qk * softmax_scale + bias m_ij = 
tl.maximum(tl.max(qk, 1), lse_i) p = tl.exp(qk - m_ij[:, None]) else: m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) p = tl.exp(qk * softmax_scale - m_ij[:, None]) l_ij = tl.sum(p, 1) acc_o_scale = tl.exp(m_i - m_ij) tl.store(t_ptrs, acc_o_scale) acc_o_scale = tl.load(t_ptrs) acc_o = acc_o * acc_o_scale[:, None] if EVEN_N & EVEN_M: if EVEN_HEADDIM: v = tl.load(v_ptrs + start_n * stride_vn) else: v = tl.load( v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0, ) elif EVEN_HEADDIM: v = tl.load( v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0, ) else: v = tl.load( v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0, ) p = p.to(v.dtype) acc_o += tl.dot(p, v) m_i = m_ij l_i_new = tl.exp(lse_i - m_ij) + l_ij lse_i = m_ij + tl.log(l_i_new) o_scale = tl.exp(m_i - lse_i) tl.store(t_ptrs, o_scale) o_scale = tl.load(t_ptrs) acc_o = acc_o * o_scale[:, None] start_m = tl.program_id(0) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m tl.store(lse_ptrs, lse_i) offs_d = tl.arange(0, BLOCK_HEADDIM) out_ptrs = ( Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) ) if EVEN_M: if EVEN_HEADDIM: tl.store(out_ptrs, acc_o) else: tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) elif EVEN_HEADDIM: tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) else: tl.store( out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), ) @triton.jit def _bwd_preprocess_do_o_dot( Out, DO, Delta, stride_ob, stride_oh, stride_om, stride_dob, stride_doh, stride_dom, nheads, seqlen_q, seqlen_q_rounded, headdim, BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, ): start_m = tl.program_id(0) off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) o = tl.load( Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, ).to(tl.float32) do = tl.load( DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, ).to(tl.float32) delta = tl.sum(o * do, axis=1) tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) @triton.jit def _bwd_store_dk_dv( dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, ): if EVEN_N & EVEN_M: if EVEN_HEADDIM: tl.store(dv_ptrs, dv) tl.store(dk_ptrs, dk) else: tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) elif EVEN_HEADDIM: tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) else: tl.store( dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim) ) tl.store( dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim) ) @triton.jit def _bwd_kernel_one_col_block( start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD: tl.constexpr, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, 
EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M offs_qm = begin_m + tl.arange(0, BLOCK_M) offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) if BIAS_TYPE == "vector": b_ptrs = Bias + offs_n elif BIAS_TYPE == "matrix": b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) if begin_m >= seqlen_q: dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) _bwd_store_dk_dv( dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, ) return if EVEN_N & EVEN_M: if EVEN_HEADDIM: k = tl.load(k_ptrs) v = tl.load(v_ptrs) else: k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) elif EVEN_HEADDIM: k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) else: k = tl.load( k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0, ) v = tl.load( v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0, ) num_block_m = tl.cdiv(seqlen_q, BLOCK_M) for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): start_m = tl.multiple_of(start_m, BLOCK_M) offs_m_curr = start_m + offs_m if EVEN_M & EVEN_HEADDIM: q = tl.load(q_ptrs) elif EVEN_HEADDIM: q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) else: q = tl.load( q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, ) qk = tl.dot(q, k, trans_b=True) if not EVEN_N: qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) if IS_CAUSAL: qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float("-inf")) if BIAS_TYPE != "none": tl.debug_barrier() if BIAS_TYPE == "vector": if EVEN_N: bias = tl.load(b_ptrs).to(tl.float32) else: bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to( tl.float32 ) bias = bias[None, :] elif BIAS_TYPE == "matrix": if EVEN_M & EVEN_N: bias = tl.load(b_ptrs).to(tl.float32) else: bias = tl.load( b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0, ).to(tl.float32) qk = qk * softmax_scale + bias if not EVEN_M & EVEN_HEADDIM: tl.debug_barrier() lse_i = tl.load(LSE + offs_m_curr) if BIAS_TYPE == "none": p = tl.exp(qk * softmax_scale - lse_i[:, None]) else: p = tl.exp(qk - lse_i[:, None]) if EVEN_M & EVEN_HEADDIM: do = tl.load(do_ptrs) else: do = tl.load( do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, ) dv += tl.dot(p.to(do.dtype), do, trans_a=True) if not EVEN_M & EVEN_HEADDIM: tl.debug_barrier() dp = tl.dot(do, v, trans_b=True) if not EVEN_HEADDIM: tl.debug_barrier() Di = tl.load(D + offs_m_curr) ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) dk += tl.dot(ds, q, trans_a=True) if not EVEN_M & EVEN_HEADDIM: tl.debug_barrier() if not ATOMIC_ADD: if EVEN_M & 
EVEN_HEADDIM: dq = tl.load(dq_ptrs, eviction_policy="evict_last") dq += tl.dot(ds, k) tl.store(dq_ptrs, dq, eviction_policy="evict_last") elif EVEN_HEADDIM: dq = tl.load( dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy="evict_last", ) dq += tl.dot(ds, k) tl.store( dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy="evict_last", ) else: dq = tl.load( dq_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, eviction_policy="evict_last", ) dq += tl.dot(ds, k) tl.store( dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), eviction_policy="evict_last", ) else: dq = tl.dot(ds, k) if EVEN_M & EVEN_HEADDIM: tl.atomic_add(dq_ptrs, dq) elif EVEN_HEADDIM: tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) else: tl.atomic_add( dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), ) dq_ptrs += BLOCK_M * stride_dqm q_ptrs += BLOCK_M * stride_qm do_ptrs += BLOCK_M * stride_dom if BIAS_TYPE == "matrix": b_ptrs += BLOCK_M * stride_bm dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) _bwd_store_dk_dv( dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, ) def init_to_zero(name): return lambda nargs: nargs[name].zero_() @triton.autotune( configs=[ triton.Config( {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ"), ), triton.Config( {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ"), ), ], key=[ "CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM", ], ) @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], } ) @triton.jit def _bwd_kernel( Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm, stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads Q += off_b * stride_qb + off_h * stride_qh K += off_b * stride_kb + off_h * stride_kh V += off_b * stride_vb + off_h * stride_vh DO += off_b * stride_dob + off_h * stride_doh DQ += off_b * stride_dqb + off_h * stride_dqh DK += off_b * stride_dkb + off_h * stride_dkh DV += off_b * stride_dvb + off_h * stride_dvh if BIAS_TYPE != "none": Bias += off_b * stride_bb + off_h * stride_bh D += off_hb * seqlen_q_rounded LSE += off_hb * seqlen_q_rounded if not SEQUENCE_PARALLEL: num_block_n = tl.cdiv(seqlen_k, BLOCK_N) for start_n in range(0, num_block_n): _bwd_kernel_one_col_block( start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, 
headdim, ATOMIC_ADD=False, BIAS_TYPE=BIAS_TYPE, IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) else: start_n = tl.program_id(0) _bwd_kernel_one_col_block( start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD=True, BIAS_TYPE=BIAS_TYPE, IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None): (batch, seqlen_q, nheads, d) = q.shape (_, seqlen_k, _, _) = k.shape assert k.shape == (batch, seqlen_k, nheads, d) assert v.shape == (batch, seqlen_k, nheads, d) assert d <= 128, "FlashAttention only support head dimensions up to 128" assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" assert q.is_cuda and k.is_cuda and v.is_cuda softmax_scale = softmax_scale or 1.0 / math.sqrt(d) has_bias = bias is not None bias_type = "none" if has_bias: assert bias.dtype in [q.dtype, torch.float] assert bias.is_cuda assert bias.dim() == 4 if bias.stride(-1) != 1: bias = bias.contiguous() if bias.shape[2:] == (1, seqlen_k): bias_type = "vector" elif bias.shape[2:] == (seqlen_q, seqlen_k): bias_type = "matrix" else: raise RuntimeError( "Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)" ) bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) bias_strides = ( (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) ) seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 lse = torch.empty( (batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32 ) tmp = torch.empty( (batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32 ) o = torch.empty_like(q) BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) BLOCK = 128 num_warps = 4 if d <= 64 else 8 grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) _fwd_kernel[grid]( q, k, v, bias, o, lse, tmp, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, o.stride(0), o.stride(2), o.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, bias_type, causal, BLOCK_HEADDIM, BLOCK_M=BLOCK, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1 ) return (o, lse, softmax_scale) def _flash_attn_backward( do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None ): if do.stride(-1) != 1: do = do.contiguous() (batch, seqlen_q, nheads, d) = q.shape (_, seqlen_k, _, _) = k.shape assert d <= 128 seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 assert lse.shape == (batch, nheads, seqlen_q_rounded) assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 softmax_scale = softmax_scale or 1.0 / math.sqrt(d) dq_accum = torch.empty_like(q, dtype=torch.float32) delta = torch.empty_like(lse) BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) _bwd_preprocess_do_o_dot[grid]( o, do, delta, o.stride(0), o.stride(2), o.stride(1), do.stride(0), do.stride(2), do.stride(1), nheads, seqlen_q, seqlen_q_rounded, d, BLOCK_M=128, 
BLOCK_HEADDIM=BLOCK_HEADDIM, ) has_bias = bias is not None bias_type = "none" if has_bias: assert bias.dtype in [q.dtype, torch.float] assert bias.is_cuda assert bias.dim() == 4 assert bias.stride(-1) == 1 if bias.shape[2:] == (1, seqlen_k): bias_type = "vector" elif bias.shape[2:] == (seqlen_q, seqlen_k): bias_type = "matrix" else: raise RuntimeError( "Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)" ) bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) bias_strides = ( (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) ) grid = lambda META: ( triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, batch * nheads, ) _bwd_kernel[grid]( q, k, v, bias, do, dq_accum, dk, dv, lse, delta, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, do.stride(0), do.stride(2), do.stride(1), dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), dk.stride(0), dk.stride(2), dk.stride(1), dv.stride(0), dv.stride(2), dv.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, bias_type, causal, BLOCK_HEADDIM ) dq.copy_(dq_accum) class FlashAttnQKVPackedFunc(torch.autograd.Function): @staticmethod def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None): """ qkv: (batch, seqlen, 3, nheads, headdim) bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen). For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen). ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen) """ if qkv.stride(-1) != 1: qkv = qkv.contiguous() (o, lse, ctx.softmax_scale) = _flash_attn_forward( qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale, ) ctx.save_for_backward(qkv, o, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): (qkv, o, lse, bias) = ctx.saved_tensors assert not ctx.needs_input_grad[ 1 ], "FlashAttention does not support bias gradient yet" with torch.inference_mode(): dqkv = torch.empty_like(qkv) _flash_attn_backward( do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale, ) return (dqkv, None, None, None) flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply class FlashAttnKVPackedFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None): """ q: (batch, seqlen_q, nheads, headdim) kv: (batch, seqlen_k, 2, nheads, headdim) bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). 
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) """ (q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]] (o, lse, ctx.softmax_scale) = _flash_attn_forward( q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale, ) ctx.save_for_backward(q, kv, o, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): (q, kv, o, lse, bias) = ctx.saved_tensors if len(ctx.needs_input_grad) >= 3: assert not ctx.needs_input_grad[ 2 ], "FlashAttention does not support bias gradient yet" with torch.inference_mode(): dq = torch.empty_like(q) dkv = torch.empty_like(kv) _flash_attn_backward( do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale, ) return (dq, dkv, None, None, None) flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply class FlashAttnFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): """ q: (batch_size, seqlen_q, nheads, headdim) k, v: (batch_size, seqlen_k, nheads, headdim) bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k). For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) """ (q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] (o, lse, ctx.softmax_scale) = _flash_attn_forward( q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale ) ctx.save_for_backward(q, k, v, o, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): (q, k, v, o, lse, bias) = ctx.saved_tensors assert not ctx.needs_input_grad[ 3 ], "FlashAttention does not support bias gradient yet" with torch.inference_mode(): dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) _flash_attn_backward( do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale, ) return (dq, dk, dv, None, None, None) flash_attn_func = FlashAttnFunc.apply ================================================ FILE: model/llava/model/language_model/mpt/hf_prefixlm_converter.py ================================================ """Converts a HuggingFace Causal LM to a Prefix LM. Conversion does lightweight surgery on a HuggingFace Causal LM to convert it to a Prefix LM. Prefix LMs accept a `bidirectional_mask` input in `forward` and treat the input prompt as the prefix in `generate`.
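Example (a minimal sketch of the intended call pattern; the `gpt2` checkpoint, the token ids, and the mask values are illustrative assumptions, not part of this module):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # loads a supported class (GPT2LMHeadModel)
model = convert_hf_causal_lm_to_prefix_lm(model)
input_ids = torch.tensor([[10, 11, 12, 13]])
# 1 marks prefix tokens (attended bidirectionally), 0 marks target tokens.
bidirectional_mask = torch.tensor([[1, 1, 0, 0]])
outputs = model(input_ids=input_ids, bidirectional_mask=bidirectional_mask)
```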
""" import math import warnings from types import MethodType from typing import Any, Dict, List, Optional, Tuple, Union import torch from transformers.models.bloom.modeling_bloom import ( BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss) from transformers.models.bloom.modeling_bloom import \ _expand_mask as _expand_mask_bloom from transformers.models.bloom.modeling_bloom import \ _make_causal_mask as _make_causal_mask_bloom from transformers.models.bloom.modeling_bloom import logging from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM from transformers.models.gptj.modeling_gptj import GPTJForCausalLM from transformers.models.opt.modeling_opt import OPTForCausalLM from transformers.models.opt.modeling_opt import \ _expand_mask as _expand_mask_opt from transformers.models.opt.modeling_opt import \ _make_causal_mask as _make_causal_mask_opt logger = logging.get_logger(__name__) _SUPPORTED_GPT_MODELS = ( GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, ) CAUSAL_GPT_TYPES = Union[ GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM ] def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES: """Converts a GPT-style Causal LM to a Prefix LM. Supported HuggingFace model classes: - `GPT2LMHeadModel` - `GPTNeoForCausalLM` - `GPTNeoXForCausalLM` - `GPTJForCausalLM` See `convert_hf_causal_lm_to_prefix_lm` for more details. """ if hasattr(model, "_prefix_lm_converted"): return model assert isinstance(model, _SUPPORTED_GPT_MODELS) assert ( model.config.add_cross_attention == False ), "Only supports GPT-style decoder-only models" def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]: """Helper that gets a list of the model's attention modules. Each module has a `bias` buffer used for causal masking. The Prefix LM conversion adds logic to dynamically manipulate these biases to support Prefix LM attention masking. 
""" attn_modules = [] if isinstance(model, GPTNeoXForCausalLM): blocks = model.gpt_neox.layers else: blocks = model.transformer.h for block in blocks: if isinstance(model, GPTNeoForCausalLM): if block.attn.attention_type != "global": continue attn_module = block.attn.attention elif isinstance(model, GPTNeoXForCausalLM): attn_module = block.attention else: attn_module = block.attn attn_modules.append(attn_module) return attn_modules setattr(model, "_original_forward", getattr(model, "forward")) setattr(model, "_original_generate", getattr(model, "generate")) def forward( self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, attention_mask: Optional[torch.FloatTensor] = None, bidirectional_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): """Wraps original forward to enable PrefixLM attention.""" def call_og_forward(): if isinstance(self, GPTNeoXForCausalLM): return self._original_forward( input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) else: return self._original_forward( input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if bidirectional_mask is None: return call_og_forward() assert isinstance(bidirectional_mask, torch.Tensor) attn_modules = _get_attn_modules(model) (b, s) = bidirectional_mask.shape max_length = attn_modules[0].bias.shape[-1] if s > max_length: raise ValueError( f"bidirectional_mask sequence length (={s}) exceeds the " + f"max length allowed by the model ({max_length})." 
) assert s <= max_length if s < max_length: pad = torch.zeros( (int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device, ) bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1) bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1) for attn_module in attn_modules: attn_module.bias.data = torch.logical_or( attn_module.bias.data, bidirectional ) output = call_og_forward() for attn_module in attn_modules: attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None] return output def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]): """Wraps original generate to enable PrefixLM attention.""" attn_modules = _get_attn_modules(model) for attn_module in attn_modules: attn_module.bias.data[:] = 1 output = self._original_generate(*args, **kwargs) for attn_module in attn_modules: attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None] return output setattr(model, "forward", MethodType(forward, model)) setattr(model, "generate", MethodType(generate, model)) setattr(model, "_prefix_lm_converted", True) return model def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM: """Converts a BLOOM Causal LM to a Prefix LM. Supported HuggingFace model classes: - `BloomForCausalLM` See `convert_hf_causal_lm_to_prefix_lm` for more details. """ if hasattr(model, "_prefix_lm_converted"): return model assert isinstance(model, BloomForCausalLM) assert ( model.config.add_cross_attention == False ), "Only supports BLOOM decoder-only models" def _prepare_attn_mask( self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int, ) -> torch.BoolTensor: combined_attention_mask = None device = attention_mask.device (_, src_length) = input_shape if src_length > 1: combined_attention_mask = _make_causal_mask_bloom( input_shape, device=device, past_key_values_length=past_key_values_length, ) if bidirectional_mask is not None: assert attention_mask.shape == bidirectional_mask.shape expanded_bidirectional_mask = _expand_mask_bloom( bidirectional_mask, tgt_length=src_length ) combined_attention_mask = torch.logical_and( combined_attention_mask, expanded_bidirectional_mask ) expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask ) return combined_attention_mask def _build_alibi_tensor( self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device, ) -> torch.Tensor: num_heads = self.config.n_head closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) base = torch.tensor( 2 ** (-(2 ** (-(math.log2(closest_power_of_2) - 3)))), device=device, dtype=torch.float32, ) powers = torch.arange( 1, 1 + closest_power_of_2, device=device, dtype=torch.int32 ) slopes = torch.pow(base, powers) if closest_power_of_2 != num_heads: extra_base = torch.tensor( 2 ** (-(2 ** (-(math.log2(2 * closest_power_of_2) - 3)))), device=device, dtype=torch.float32, ) num_remaining_heads = min( closest_power_of_2, num_heads - closest_power_of_2 ) extra_powers = torch.arange( 1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32 ) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1) ka = torch.arange(key_length, 
device=device, dtype=torch.int32).view(1, -1) diffs = qa - ka + key_length - query_length diffs = -diffs.abs() alibi = slopes.view(1, num_heads, 1, 1) * diffs.view( 1, 1, query_length, key_length ) alibi = alibi.expand(batch_size, -1, -1, -1).reshape( -1, query_length, key_length ) return alibi.to(dtype) KeyValueT = Tuple[torch.Tensor, torch.Tensor] def forward( self: BloomModel, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[KeyValueT, ...]] = None, attention_mask: Optional[torch.Tensor] = None, bidirectional_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop("position_ids", False) is not False: warnings.warn( "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. " + "You can safely ignore passing `position_ids`.", FutureWarning, ) if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time" ) elif input_ids is not None: (batch_size, seq_length) = input_ids.shape elif inputs_embeds is not None: (batch_size, seq_length, _) = inputs_embeds.shape else: raise ValueError("You have to specify either input_ids or inputs_embeds") if past_key_values is None: past_key_values = tuple([None] * len(self.h)) head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) hidden_states = self.word_embeddings_layernorm(inputs_embeds) presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values[0] is not None: tmp = past_key_values[0][0] past_key_values_length = tmp.shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), device=hidden_states.device ) else: attention_mask = attention_mask.to(hidden_states.device) alibi = self._build_alibi_tensor( batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device, ) causal_mask = self._prepare_attn_mask( attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length, ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: hst = (hidden_states,) all_hidden_states = all_hidden_states + hst if self.gradient_checkpointing and self.training: if use_cache: logger.warning( "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`..." ) use_cache = False def create_custom_forward(module): def custom_forward(*inputs): return module( *inputs, use_cache=use_cache, output_attentions=output_attentions, ) return custom_forward outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i], ) else: outputs = block( hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, ) hidden_states = outputs[0] if use_cache is True: presents = presents + (outputs[1],) if output_attentions: oa = (outputs[2 if use_cache else 1],) all_self_attentions = all_self_attentions + oa hidden_states = self.ln_f(hidden_states) if output_hidden_states: hst = (hidden_states,) all_hidden_states = all_hidden_states + hst if not return_dict: return tuple( ( v for v in [ hidden_states, presents, all_hidden_states, all_self_attentions, ] if v is not None ) ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions, ) setattr( model.transformer, "_prepare_attn_mask", MethodType(_prepare_attn_mask, model.transformer), ) setattr( model.transformer, "_build_alibi_tensor", MethodType(_build_alibi_tensor, model.transformer), ) setattr(model.transformer, "forward", MethodType(forward, model.transformer)) KeyValueT = Tuple[torch.Tensor, torch.Tensor] def forward( self: BloomForCausalLM, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[KeyValueT, ...]] = None, attention_mask: Optional[torch.Tensor] = None, bidirectional_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: """Replacement forward method for BloomCausalLM.""" if deprecated_arguments.pop("position_ids", False) is not False: warnings.warn( "`position_ids` have no functionality in BLOOM and will be removed " + "in v5.0.0. 
You can safely ignore passing `position_ids`.", FutureWarning, ) if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) loss = None if labels is not None: shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() (batch_size, seq_length, vocab_size) = shift_logits.shape loss_fct = CrossEntropyLoss() loss = loss_fct( shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length), ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithCrossAttentions( loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) def prepare_inputs_for_generation( self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> dict: if past: input_ids = input_ids[:, -1].unsqueeze(-1) bidirectional_mask = None if past[0][0].shape[0] == input_ids.shape[0]: past = self._convert_to_bloom_cache(past) else: bidirectional_mask = torch.ones_like(input_ids) return { "input_ids": input_ids, "past_key_values": past, "use_cache": True, "attention_mask": attention_mask, "bidirectional_mask": bidirectional_mask, } setattr(model, "forward", MethodType(forward, model)) setattr( model, "prepare_inputs_for_generation", MethodType(prepare_inputs_for_generation, model), ) setattr(model, "_prefix_lm_converted", True) return model def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM: """Converts an OPT Causal LM to a Prefix LM. Supported HuggingFace model classes: - `OPTForCausalLM` See `convert_hf_causal_lm_to_prefix_lm` for more details. 
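Because OPT builds its causal mask on the fly rather than storing it in registered buffers, this conversion instead threads the mask through a `bidirectional_mask` attribute on `model.model.decoder`: a patched `_prepare_decoder_attention_mask` folds it into the additive attention mask, and the sentinel value `"g"` (set by the wrapped `generate`) produces a fully bidirectional mask for the prompt.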
""" if hasattr(model, "_prefix_lm_converted"): return model assert isinstance(model, OPTForCausalLM) assert ( model.config.add_cross_attention == False ), "Only supports OPT decoder-only models" setattr(model, "_original_forward", getattr(model, "forward")) setattr(model, "_original_generate", getattr(model, "generate")) model.model.decoder.bidirectional_mask = None def _prepare_decoder_attention_mask( self, attention_mask, input_shape, inputs_embeds, past_key_values_length ): combined_attention_mask = None if input_shape[-1] > 1: if self.bidirectional_mask == "g": (bsz, src_length) = input_shape combined_attention_mask = torch.zeros( (bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device, ) else: combined_attention_mask = _make_causal_mask_opt( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length, ).to(inputs_embeds.device) if self.bidirectional_mask is not None: assert attention_mask.shape == self.bidirectional_mask.shape expanded_bidirectional_mask = _expand_mask_opt( self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1], ).to(inputs_embeds.device) combined_attention_mask = torch.maximum( expanded_bidirectional_mask, combined_attention_mask ) if attention_mask is not None: expanded_attn_mask = _expand_mask_opt( attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ).to(inputs_embeds.device) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask setattr( model.model.decoder, "_prepare_decoder_attention_mask", MethodType(_prepare_decoder_attention_mask, model.model.decoder), ) def forward( self: OPTForCausalLM, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, bidirectional_mask: Optional[torch.ByteTensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): def call_og_forward(): return self._original_forward( input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if bidirectional_mask is None: return call_og_forward() self.model.decoder.bidirectional_mask = bidirectional_mask try: outputs = call_og_forward() except: self.model.decoder.bidirectional_mask = None raise self.model.decoder.bidirectional_mask = None return outputs def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]): """Wraps original generate to enable PrefixLM-style attention.""" self.model.decoder.bidirectional_mask = "g" try: output = self._original_generate(*args, **kwargs) except: self.model.decoder.bidirectional_mask = None raise self.model.decoder.bidirectional_mask = None return output setattr(model, "forward", MethodType(forward, model)) setattr(model, "generate", MethodType(generate, model)) setattr(model, "_prefix_lm_converted", True) return model _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM) CAUSAL_LM_TYPES = Union[ GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, 
GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM, ] def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES: """Converts a HuggingFace Causal LM to a Prefix LM. Supported HuggingFace model classes: - `GPT2LMHeadModel` - `GPTNeoForCausalLM` - `GPTNeoXForCausalLM` - `GPTJForCausalLM` - `BloomForCausalLM` - `OPTForCausalLM` Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the `generate` method and/or select underlying methods depending on the model class. These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask". Notes on training: To actually train the converted model as a Prefix LM, training batches will need to indicate the prefix/target structure by including `bidirectional_mask` as part of the batch inputs. **This is not a standard input and requires custom layers either within or after your dataloader.** In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels` such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`. That is, the prefix portion of the sequence should not generate any loss. Loss should only be generated by the target portion of the sequence. Notes on `GPTNeoForCausalLM`: To simplify the implementation, "global" and "local" attention layers are handled differently. For "global" layers, we handle conversion as described above. For "local" layers, which use a causal attention mask within a restricted local window, we do not alter the masking. Notes on `forward` method conversion: After conversion, the `forward` method will handle a new input, `bidirectional_mask`, which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions belonging to the prefix (prefix tokens can attend to one another bidirectionally), and 0 indicates token positions belonging to the target. The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset the causal masks before returning the result. Notes on `generate` method conversion: After conversion, the `generate` method will have the same signature but will internally convert all causal masks to be purely bidirectional, call the original `generate` method, and (where appropriate) reset the causal masks before returning the result. This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and previously-generated tokens (also as expected in a Prefix LM). To preserve the API, the original methods are renamed to `_original_forward` and `_original_generate`, and replaced with new `forward` and `generate` methods that wrap them, respectively, although implementation details vary by model class. """ if isinstance(model, _SUPPORTED_GPT_MODELS): return _convert_gpt_causal_lm_to_prefix_lm(model) elif isinstance(model, BloomForCausalLM): return _convert_bloom_causal_lm_to_prefix_lm(model) elif isinstance(model, OPTForCausalLM): return _convert_opt_causal_lm_to_prefix_lm(model) else: raise TypeError( f"Cannot convert model to Prefix LM.
" + f"Model does not belong to set of supported HF models:" + f"\n{_SUPPORTED_HF_MODELS}" ) def add_bidirectional_mask_if_missing(batch: Dict[str, Any]): """Attempts to add bidirectional_mask to batch if missing. Raises: KeyError if bidirectional_mask is missing and can't be inferred """ if "bidirectional_mask" not in batch: if batch.get("mode", None) == "icl_task": batch["bidirectional_mask"] = batch["attention_mask"].clone() for i, continuation_indices in enumerate(batch["continuation_indices"]): batch["bidirectional_mask"][i, continuation_indices] = 0 elif "labels" in batch and "attention_mask" in batch: batch["bidirectional_mask"] = torch.logical_and( torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100) ).type_as(batch["attention_mask"]) else: raise KeyError( "No bidirectional_mask in batch and not sure how to construct one." ) ================================================ FILE: model/llava/model/language_model/mpt/meta_init_context.py ================================================ from contextlib import contextmanager import torch import torch.nn as nn @contextmanager def init_empty_weights(include_buffers: bool = False): """Meta initialization context manager. A context manager under which models are initialized with all parameters on the meta device, therefore creating an empty model. Useful when just initializing the model would blow the available RAM. Args: include_buffers (`bool`, *optional*, defaults to `False`): Whether or not to also put all buffers on the meta device while initializing. Example: ```python import torch.nn as nn # Initialize a model with 100 billions parameters in no time and without using any RAM. with init_empty_weights(): tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) ``` Any model created under this context manager has no weights. As such you can't do something like `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. """ with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f: yield f @contextmanager def init_on_device(device: torch.device, include_buffers: bool = False): """Device initialization context manager. A context manager under which models are initialized with all parameters on the specified device. Args: device (`torch.device`): Device to initialize all parameters on. include_buffers (`bool`, *optional*, defaults to `False`): Whether or not to also put all buffers on the meta device while initializing. 
Example: ```python import torch.nn as nn with init_on_device(device=torch.device("cuda")): tst = nn.Linear(100, 100) # on `cuda` device ``` """ old_register_parameter = nn.Module.register_parameter if include_buffers: old_register_buffer = nn.Module.register_buffer def register_empty_parameter(module, name, param): old_register_parameter(module, name, param) if param is not None: param_cls = type(module._parameters[name]) kwargs = module._parameters[name].__dict__ module._parameters[name] = param_cls( module._parameters[name].to(device), **kwargs ) def register_empty_buffer(module, name, buffer): old_register_buffer(module, name, buffer) if buffer is not None: module._buffers[name] = module._buffers[name].to(device) if include_buffers: tensor_constructors_to_patch = { torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ["empty", "zeros", "ones", "full"] } else: tensor_constructors_to_patch = {} def patch_tensor_constructor(fn): def wrapper(*args, **kwargs): kwargs["device"] = device return fn(*args, **kwargs) return wrapper try: nn.Module.register_parameter = register_empty_parameter if include_buffers: nn.Module.register_buffer = register_empty_buffer for torch_function_name in tensor_constructors_to_patch.keys(): setattr( torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)), ) yield finally: nn.Module.register_parameter = old_register_parameter if include_buffers: nn.Module.register_buffer = old_register_buffer for ( torch_function_name, old_torch_function, ) in tensor_constructors_to_patch.items(): setattr(torch, torch_function_name, old_torch_function) ================================================ FILE: model/llava/model/language_model/mpt/modeling_mpt.py ================================================ """A simple, flexible implementation of a GPT model.
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py """ import math import warnings from typing import List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import (PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast) from transformers.modeling_outputs import (BaseModelOutputWithPast, CausalLMOutputWithPast) from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising from .attention import attn_bias_shape, build_attn_bias from .blocks import MPTBlock from .configuration_mpt import MPTConfig from .custom_embedding import SharedEmbedding from .hf_prefixlm_converter import (add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm) from .meta_init_context import init_empty_weights from .norm import NORM_CLASS_REGISTRY from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_ try: from .flash_attn_triton import flash_attn_func except: pass Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig base_model_prefix = "model" _no_split_modules = ["MPTBlock"] class MPTModel(MPTPreTrainedModel): def __init__(self, config: MPTConfig): config._validate_config() super().__init__(config) self.attn_impl = config.attn_config["attn_impl"] self.prefix_lm = config.attn_config["prefix_lm"] self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"] self.alibi = config.attn_config["alibi"] self.alibi_bias_max = config.attn_config["alibi_bias_max"] if config.init_device == "mixed": if dist.get_local_rank() == 0: config.init_device = "cpu" else: config.init_device = "meta" if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): norm_options = " | ".join(NORM_CLASS_REGISTRY.keys()) raise NotImplementedError( f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})." ) norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()] self.embedding_fraction = config.embedding_fraction self.wte = SharedEmbedding( config.vocab_size, config.d_model, device=config.init_device ) if not self.alibi: self.wpe = torch.nn.Embedding( config.max_seq_len, config.d_model, device=config.init_device ) self.emb_drop = nn.Dropout(config.emb_pdrop) self.blocks = nn.ModuleList( [ MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers) ] ) self.norm_f = norm_class(config.d_model, device=config.init_device) if config.init_device != "meta": print( f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.' 
) self.apply(self.param_init_fn) self.is_causal = not self.prefix_lm self._attn_bias_initialized = False self.attn_bias = None self.attn_bias_shape = attn_bias_shape( self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id, ) if config.no_bias: for module in self.modules(): if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): if config.verbose: warnings.warn(f"Removing bias ({module.bias}) from {module}.") module.register_parameter("bias", None) if config.verbose and config.verbose > 2: print(self) if "verbose" not in self.config.init_config: self.config.init_config["verbose"] = self.config.verbose if self.config.init_config["verbose"] > 1: init_fn_name = self.config.init_config["name"] warnings.warn(f"Using {init_fn_name} initialization.") self.gradient_checkpointing = False def get_input_embeddings(self): return self.wte def set_input_embeddings(self, value): self.wte = value @torch.no_grad() def _attn_bias( self, device, dtype, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, sequence_id: Optional[torch.LongTensor] = None, ): if not self._attn_bias_initialized: if self.attn_bias_shape: self.attn_bias = torch.zeros( self.attn_bias_shape, device=device, dtype=dtype ) self.attn_bias = build_attn_bias( self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max, ) self._attn_bias_initialized = True if self.attn_impl == "flash": return (self.attn_bias, attention_mask) if self.attn_bias is not None: self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) attn_bias = self.attn_bias if self.prefix_lm: assert isinstance(attn_bias, torch.Tensor) assert isinstance(prefix_mask, torch.Tensor) attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) if self.attn_uses_sequence_id and sequence_id is not None: assert isinstance(attn_bias, torch.Tensor) attn_bias = self._apply_sequence_id(attn_bias, sequence_id) if attention_mask is not None: s_k = attention_mask.shape[-1] if attn_bias is None: attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) else: _s_k = max(0, attn_bias.size(-1) - s_k) attn_bias = attn_bias[:, :, :, _s_k:] if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: raise ValueError( f"attention_mask shape={attention_mask.shape} " + f"and prefix_mask shape={prefix_mask.shape} are not equal." ) min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill( ~attention_mask.view(-1, 1, 1, s_k), min_val ) return (attn_bias, None) def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): (s_k, s_q) = attn_bias.shape[-2:] if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: raise ValueError( "attn_bias does not match the expected shape. " + f"The last two dimensions should both be {self.config.max_seq_len} " + f"but are {s_k} and {s_q}."
) seq_len = prefix_mask.shape[-1] if seq_len > self.config.max_seq_len: raise ValueError( f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}" ) attn_bias = attn_bias[..., :seq_len, :seq_len] causal = torch.tril( torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device) ).view(1, 1, seq_len, seq_len) prefix = prefix_mask.view(-1, 1, 1, seq_len) cannot_attend = ~torch.logical_or(causal, prefix.bool()) min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill(cannot_attend, min_val) return attn_bias def _apply_sequence_id( self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor ): seq_len = sequence_id.shape[-1] if seq_len > self.config.max_seq_len: raise ValueError( f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}" ) attn_bias = attn_bias[..., :seq_len, :seq_len] cannot_attend = torch.logical_not( torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len)) ).unsqueeze(1) min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill(cannot_attend, min_val) return attn_bias def forward( self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, sequence_id: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None, inputs_embeds: Optional[torch.Tensor] = None, ): return_dict = ( return_dict if return_dict is not None else self.config.return_dict ) use_cache = use_cache if use_cache is not None else self.config.use_cache if attention_mask is not None: attention_mask = attention_mask.bool() if prefix_mask is not None: prefix_mask = prefix_mask.bool() if not return_dict: raise NotImplementedError( "return_dict False is not implemented yet for MPT" ) if output_attentions: if self.attn_impl != "torch": raise NotImplementedError( "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`." ) if ( attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training ): raise NotImplementedError( "MPT does not support training with left padding." ) if self.prefix_lm and prefix_mask is None: raise ValueError( "prefix_mask is a required argument when MPT is configured with prefix_lm=True." ) if self.training: if self.attn_uses_sequence_id and sequence_id is None: raise ValueError( "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " + "and the model is in train mode." ) elif self.attn_uses_sequence_id is False and sequence_id is not None: warnings.warn( "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. " + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True." ) if input_ids is not None: S = input_ids.size(1) assert ( S <= self.config.max_seq_len ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}" tok_emb = self.wte(input_ids) else: assert inputs_embeds is not None assert ( self.alibi ), "inputs_embeds is not implemented for MPT unless for alibi." 
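# alibi encodes positions as attention biases rather than learned position embeddings, so the provided inputs_embeds can be used directly; the wpe lookup below only runs in the non-alibi branch.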
S = inputs_embeds.size(1) tok_emb = inputs_embeds if self.alibi: x = tok_emb else: past_position = 0 if past_key_values is not None: if len(past_key_values) != self.config.n_layers: raise ValueError( f"past_key_values must provide a past_key_value for each attention " + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." ) past_position = past_key_values[0][0].size(1) if self.attn_impl == "torch": past_position = past_key_values[0][0].size(3) if S + past_position > self.config.max_seq_len: raise ValueError( f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}." ) pos = torch.arange( past_position, S + past_position, dtype=torch.long, device=input_ids.device, ).unsqueeze(0) if attention_mask is not None: pos = torch.clamp( pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[ :, past_position: ], min=0, ) pos_emb = self.wpe(pos) x = tok_emb + pos_emb if self.embedding_fraction == 1: x = self.emb_drop(x) else: x_shrunk = x * self.embedding_fraction + x.detach() * ( 1 - self.embedding_fraction ) assert isinstance(self.emb_drop, nn.Module) x = self.emb_drop(x_shrunk) (attn_bias, attention_mask) = self._attn_bias( device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, ) if use_cache and past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers)] all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for b_idx, block in enumerate(self.blocks): if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) past_key_value = ( past_key_values[b_idx] if past_key_values is not None else None ) if self.gradient_checkpointing and self.training: (x, attn_weights, past_key_value) = torch.utils.checkpoint.checkpoint( block, x, past_key_value, attn_bias, attention_mask, self.is_causal ) else: (x, attn_weights, past_key_value) = block( x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, ) if past_key_values is not None: past_key_values[b_idx] = past_key_value if output_attentions: assert all_self_attns is not None all_self_attns = all_self_attns + (attn_weights,) x = self.norm_f(x) if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) return BaseModelOutputWithPast( last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) def param_init_fn(self, module): init_fn_name = self.config.init_config["name"] MODEL_INIT_REGISTRY[init_fn_name]( module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config, ) def fsdp_wrap_fn(self, module): return isinstance(module, MPTBlock) def activation_checkpointing_fn(self, module): return isinstance(module, MPTBlock) class MPTForCausalLM(MPTPreTrainedModel): def __init__(self, config: MPTConfig): super().__init__(config) if not config.tie_word_embeddings: raise ValueError("MPTForCausalLM only supports tied word embeddings") print(f"Instantiating an MPTForCausalLM model from {__file__}") self.transformer = MPTModel(config) for child in self.transformer.children(): if isinstance(child, torch.nn.ModuleList): continue if isinstance(child, torch.nn.Module): child._fsdp_wrap = True 
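# config.logit_scale may be a numeric multiplier or the string "inv_sqrt_d_model" (resolved to 1/sqrt(d_model) below); forward() later multiplies the LM logits by this value.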
self.logit_scale = None if config.logit_scale is not None: logit_scale = config.logit_scale if isinstance(logit_scale, str): if logit_scale == "inv_sqrt_d_model": logit_scale = 1 / math.sqrt(config.d_model) else: raise ValueError( f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." ) self.logit_scale = logit_scale def get_input_embeddings(self): return self.transformer.wte def set_input_embeddings(self, value): self.transformer.wte = value def get_output_embeddings(self): return self.transformer.wte def set_output_embeddings(self, new_embeddings): self.transformer.wte = new_embeddings def set_decoder(self, decoder): self.transformer = decoder def get_decoder(self): return self.transformer def forward( self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, sequence_id: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None, inputs_embeds: Optional[torch.FloatTensor] = None, ): return_dict = ( return_dict if return_dict is not None else self.config.return_dict ) use_cache = use_cache if use_cache is not None else self.config.use_cache if inputs_embeds is not None: raise NotImplementedError( "inputs_embeds has to be None (for hf/peft support)." ) outputs = self.transformer( input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, ) logits = self.transformer.wte( outputs.last_hidden_state.to(self.transformer.wte.weight.device), True ) if self.logit_scale is not None: if self.logit_scale == 0: warnings.warn( f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." ) logits *= self.logit_scale loss = None if labels is not None: labels = torch.roll(labels, shifts=-1) labels[:, -1] = -100 loss = F.cross_entropy( logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) ) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def param_init_fn(self, module): init_fn_name = self.config.init_config["name"] MODEL_INIT_REGISTRY[init_fn_name]( module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config, ) def fsdp_wrap_fn(self, module): return isinstance(module, MPTBlock) def activation_checkpointing_fn(self, module): return isinstance(module, MPTBlock) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs ): if inputs_embeds is not None: raise NotImplementedError("inputs_embeds is not implemented for MPT yet") attention_mask = kwargs["attention_mask"].bool() if attention_mask[:, -1].sum() != attention_mask.shape[0]: raise NotImplementedError( "MPT does not support generation with right padding." 
) if self.transformer.attn_uses_sequence_id and self.training: sequence_id = torch.zeros_like(input_ids[:1]) else: sequence_id = None if past_key_values is not None: input_ids = input_ids[:, -1].unsqueeze(-1) if self.transformer.prefix_lm: prefix_mask = torch.ones_like(attention_mask) if kwargs.get("use_cache") == False: raise NotImplementedError( "MPT with prefix_lm=True does not support use_cache=False." ) else: prefix_mask = None return { "input_ids": input_ids, "attention_mask": attention_mask, "prefix_mask": prefix_mask, "sequence_id": sequence_id, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache", True), } @staticmethod def _reorder_cache(past_key_values, beam_idx): """Used by HuggingFace generate when using beam search with kv-caching. See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 for an example in transformers. """ reordered_past = [] for layer_past in past_key_values: reordered_past += [ tuple( (past_state.index_select(0, beam_idx) for past_state in layer_past) ) ] return reordered_past ================================================ FILE: model/llava/model/language_model/mpt/norm.py ================================================ import torch def _cast_if_autocast_enabled(tensor): if torch.is_autocast_enabled(): if tensor.device.type == "cuda": dtype = torch.get_autocast_gpu_dtype() elif tensor.device.type == "cpu": dtype = torch.get_autocast_cpu_dtype() else: raise NotImplementedError() return tensor.to(dtype=dtype) return tensor class LPLayerNorm(torch.nn.LayerNorm): def __init__( self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None, ): super().__init__( normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype, ) def forward(self, x): module_device = x.device downcast_x = _cast_if_autocast_enabled(x) downcast_weight = ( _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight ) downcast_bias = ( _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias ) with torch.autocast(enabled=False, device_type=module_device.type): return torch.nn.functional.layer_norm( downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps, ) def rms_norm(x, weight=None, eps=1e-05): output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) if weight is not None: return output * weight return output class RMSNorm(torch.nn.Module): def __init__( self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None ): super().__init__() self.eps = eps if weight: self.weight = torch.nn.Parameter( torch.ones(normalized_shape, dtype=dtype, device=device) ) else: self.register_parameter("weight", None) def forward(self, x): return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) class LPRMSNorm(RMSNorm): def __init__( self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None ): super().__init__( normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device, ) def forward(self, x): downcast_x = _cast_if_autocast_enabled(x) downcast_weight = ( _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight ) with torch.autocast(enabled=False, device_type=x.device.type): return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) NORM_CLASS_REGISTRY = { "layernorm": torch.nn.LayerNorm, "low_precision_layernorm": LPLayerNorm, "rmsnorm": RMSNorm, 
"low_precision_rmsnorm": LPRMSNorm, } ================================================ FILE: model/llava/model/language_model/mpt/param_init_fns.py ================================================ import math import warnings from collections.abc import Sequence from functools import partial from typing import Optional, Tuple, Union import torch from torch import nn from .norm import NORM_CLASS_REGISTRY def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs): del kwargs if verbose > 1: warnings.warn(f"Initializing network using module's reset_parameters attribute") if hasattr(module, "reset_parameters"): module.reset_parameters() def fused_init_helper_(module: nn.Module, init_fn_): _fused = getattr(module, "_fused", None) if _fused is None: raise RuntimeError(f"Internal logic error") (dim, splits) = _fused splits = (0, *splits, module.weight.size(dim)) for s, e in zip(splits[:-1], splits[1:]): slice_indices = [slice(None)] * module.weight.ndim slice_indices[dim] = slice(s, e) init_fn_(module.weight[slice_indices]) def generic_param_init_fn_( module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int] = None, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, **kwargs, ): del kwargs if verbose > 1: warnings.warn(f"If model has bias parameters they are initialized to 0.") init_div_is_residual = init_div_is_residual if init_div_is_residual is False: div_is_residual = 1.0 elif init_div_is_residual is True: div_is_residual = math.sqrt(2 * n_layers) elif isinstance(init_div_is_residual, float) or isinstance( init_div_is_residual, int ): div_is_residual = init_div_is_residual elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric(): div_is_residual = float(init_div_is_residual) else: div_is_residual = 1.0 raise ValueError( f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}" ) if init_div_is_residual is not False: if verbose > 1: warnings.warn( f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. " + f"Set `init_div_is_residual: false` in init config to disable this." ) if isinstance(module, nn.Linear): if hasattr(module, "_fused"): fused_init_helper_(module, init_fn_) else: init_fn_(module.weight) if module.bias is not None: torch.nn.init.zeros_(module.bias) if init_div_is_residual is not False and getattr(module, "_is_residual", False): with torch.no_grad(): module.weight.div_(div_is_residual) elif isinstance(module, nn.Embedding): if emb_init_std is not None: std = emb_init_std if std == 0: warnings.warn(f"Embedding layer initialized to 0.") emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std) if verbose > 1: warnings.warn( f"Embedding layer initialized using normal distribution with mean=0 and std={std!r}." ) elif emb_init_uniform_lim is not None: lim = emb_init_uniform_lim if isinstance(lim, Sequence): if len(lim) > 2: raise ValueError( f"Uniform init requires a min and a max limit. User input: {lim}." ) if lim[0] == lim[1]: warnings.warn(f"Embedding layer initialized to {lim[0]}.") else: if lim == 0: warnings.warn(f"Embedding layer initialized to 0.") lim = [-lim, lim] (a, b) = lim emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b) if verbose > 1: warnings.warn( f"Embedding layer initialized using uniform distribution in range {lim}." 
) else: emb_init_fn_ = init_fn_ emb_init_fn_(module.weight) elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))): if verbose > 1: warnings.warn( f"Norm weights are set to 1. If norm layer has a bias it is initialized to 0." ) if hasattr(module, "weight") and module.weight is not None: torch.nn.init.ones_(module.weight) if hasattr(module, "bias") and module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.MultiheadAttention): if module._qkv_same_embed_dim: assert module.in_proj_weight is not None assert ( module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None) ) assert d_model is not None _d = d_model splits = (0, _d, 2 * _d, 3 * _d) for s, e in zip(splits[:-1], splits[1:]): init_fn_(module.in_proj_weight[s:e]) else: assert ( module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None) ) assert module.in_proj_weight is None init_fn_(module.q_proj_weight) init_fn_(module.k_proj_weight) init_fn_(module.v_proj_weight) if module.in_proj_bias is not None: torch.nn.init.zeros_(module.in_proj_bias) if module.bias_k is not None: torch.nn.init.zeros_(module.bias_k) if module.bias_v is not None: torch.nn.init.zeros_(module.bias_v) init_fn_(module.out_proj.weight) if init_div_is_residual is not False and getattr( module.out_proj, "_is_residual", False ): with torch.no_grad(): module.out_proj.weight.div_(div_is_residual) if module.out_proj.bias is not None: torch.nn.init.zeros_(module.out_proj.bias) else: for _ in module.parameters(recurse=False): raise NotImplementedError( f"{module.__class__.__name__} parameters are not initialized by param_init_fn." ) def _normal_init_(std, mean=0.0): return partial(torch.nn.init.normal_, mean=mean, std=std) def _normal_param_init_fn_( module: nn.Module, std: float, n_layers: int, d_model: Optional[int] = None, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, **kwargs, ): del kwargs init_fn_ = _normal_init_(std=std) if verbose > 1: warnings.warn(f"Using torch.nn.init.normal_ init fn mean=0.0, std={std}") generic_param_init_fn_( module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose, ) def baseline_param_init_fn_( module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int] = None, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, **kwargs, ): del kwargs if init_std is None: raise ValueError( "You must set model.init_config['init_std'] to a float value to use the default initialization scheme." 
)
    _normal_param_init_fn_(
        module=module,
        std=init_std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def small_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    del kwargs
    std = math.sqrt(2 / (5 * d_model))
    _normal_param_init_fn_(
        module=module,
        std=std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def neox_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """From section 2.3.1 of GPT-NeoX-20B: An Open-Source Autoregressive Language Model - Black et al. (2022)

    see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
    """
    del kwargs
    residual_div = n_layers / math.sqrt(10)
    if verbose > 1:
        warnings.warn(f"setting init_div_is_residual to {residual_div}")
    small_param_init_fn_(
        module=module,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=residual_div,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def kaiming_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = "fan_in",
    init_nonlinearity: str = "leaky_relu",
    verbose: int = 0,
    **kwargs,
):
    del kwargs
    if verbose > 1:
        warnings.warn(
            f"Using nn.init.kaiming_uniform_ init fn with parameters: "
            + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
        )
    kaiming_uniform_ = partial(
        nn.init.kaiming_uniform_,
        a=init_gain,
        mode=fan_mode,
        nonlinearity=init_nonlinearity,
    )
    generic_param_init_fn_(
        module=module,
        init_fn_=kaiming_uniform_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def kaiming_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = "fan_in",
    init_nonlinearity: str = "leaky_relu",
    verbose: int = 0,
    **kwargs,
):
    del kwargs
    if verbose > 1:
        warnings.warn(
            f"Using nn.init.kaiming_normal_ init fn with parameters: "
            + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
        )
    kaiming_normal_ = partial(
        torch.nn.init.kaiming_normal_,
        a=init_gain,
        mode=fan_mode,
        nonlinearity=init_nonlinearity,
    )
    generic_param_init_fn_(
        module=module,
        init_fn_=kaiming_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def xavier_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    verbose: int = 0,
    **kwargs,
):
    del kwargs
    xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
    if verbose > 1:
        warnings.warn(
            f"Using torch.nn.init.xavier_uniform_ init fn with parameters: "
            + f"gain={init_gain}"
        )
    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_uniform_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def xavier_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    verbose: int = 0,
    **kwargs,
):
    del kwargs  # added for consistency: every other init fn here discards **kwargs
    xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
    if verbose > 1:
        warnings.warn(
            f"Using torch.nn.init.xavier_normal_ init fn with parameters: "
            + f"gain={init_gain}"
        )
    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


MODEL_INIT_REGISTRY = {
    "default_": torch_default_param_init_fn_,
    "baseline_": baseline_param_init_fn_,
    "kaiming_uniform_": kaiming_uniform_param_init_fn_,
    "kaiming_normal_": kaiming_normal_param_init_fn_,
    "neox_init_": neox_param_init_fn_,
    "small_init_": small_param_init_fn_,
    "xavier_uniform_": xavier_uniform_param_init_fn_,
    "xavier_normal_": xavier_normal_param_init_fn_,
}
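A note on how this registry is consumed: every entry is an in-place initializer whose first argument is a module, so callers bind the shape hyperparameters with functools.partial and walk the network with nn.Module.apply. The sketch below is illustrative only: the toy block, the hyperparameter values, and the import path (which assumes the repository root is on sys.path) are made up for the example, and it relies on generic_param_init_fn_ (defined earlier in this file) recognizing the module types involved, plain nn.Linear here.

from functools import partial

import torch.nn as nn

from model.llava.model.language_model.mpt.param_init_fns import MODEL_INIT_REGISTRY

d_model, n_layers = 256, 4
block = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),
    nn.GELU(),
    nn.Linear(4 * d_model, d_model),
)

# Bind the shape hyperparameters, then let .apply() visit every submodule;
# each visit initializes that submodule's parameters in place.
init_fn = partial(
    MODEL_INIT_REGISTRY["kaiming_normal_"], n_layers=n_layers, d_model=d_model
)
block.apply(init_fn)

================================================
FILE: model/llava/model/llava_arch.py
================================================
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.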
from abc import ABC, abstractmethod import torch import torch.nn as nn # from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX) from .multimodal_encoder.builder import build_vision_tower class LlavaMetaModel: def __init__(self, config): super(LlavaMetaModel, self).__init__(config) if hasattr(config, "mm_vision_tower"): self.vision_tower = build_vision_tower(config, delay_load=True) self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size) def get_vision_tower(self): vision_tower = getattr(self, "vision_tower", None) if type(vision_tower) is list: vision_tower = vision_tower[0] return vision_tower def initialize_vision_modules(self, model_args, fsdp=None): vision_tower = model_args.vision_tower mm_vision_select_layer = model_args.mm_vision_select_layer mm_vision_select_feature = model_args.mm_vision_select_feature pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter self.config.mm_vision_tower = vision_tower vision_tower = build_vision_tower(model_args) if fsdp is not None and len(fsdp) > 0: self.vision_tower = [vision_tower] else: self.vision_tower = vision_tower self.config.use_mm_proj = True self.config.mm_hidden_size = vision_tower.hidden_size self.config.mm_vision_select_layer = mm_vision_select_layer self.config.mm_vision_select_feature = mm_vision_select_feature if not hasattr(self, "mm_projector"): self.mm_projector = nn.Linear( self.config.mm_hidden_size, self.config.hidden_size ) if pretrain_mm_mlp_adapter is not None: mm_projector_weights = torch.load( pretrain_mm_mlp_adapter, map_location="cpu" ) def get_w(weights, keyword): return { k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k } self.mm_projector.load_state_dict( get_w(mm_projector_weights, "mm_projector") ) class LlavaMetaForCausalLM(ABC): @abstractmethod def get_model(self): pass def get_vision_tower(self): return self.get_model().get_vision_tower() def encode_images(self, images): image_features = self.get_model().get_vision_tower()(images) image_features = self.get_model().mm_projector(image_features) return image_features def prepare_inputs_labels_for_multimodal( self, input_ids, attention_mask, past_key_values, labels, images ): vision_tower = self.get_vision_tower() if vision_tower is None or images is None or input_ids.shape[1] == 1: if ( past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1 ): attention_mask = torch.ones( (attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device, ) return input_ids, attention_mask, past_key_values, None, labels if type(images) is list or images.ndim == 5: concat_images = torch.cat([image for image in images], dim=0) image_features = self.encode_images(concat_images) split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) image_features = [x.flatten(0, 1) for x in image_features] else: image_features = self.encode_images(images) new_input_embeds = [] new_labels = [] if labels is not None else None cur_image_idx = 0 for batch_idx, cur_input_ids in enumerate(input_ids): if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # multimodal LLM, but the current sample is not multimodal cur_input_embeds = self.get_model().embed_tokens(cur_input_ids) cur_input_embeds 
= ( cur_input_embeds + ( 0.0 * self.get_model().mm_projector(vision_tower.dummy_feature) ).sum() ) new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(labels[batch_idx]) cur_image_idx += 1 continue image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] cur_new_input_embeds = [] if labels is not None: cur_labels = labels[batch_idx] cur_new_labels = [] assert cur_labels.shape == cur_input_ids.shape while image_token_indices.numel() > 0: cur_image_features = image_features[cur_image_idx] image_token_start = image_token_indices[0] if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr( self.config, "mm_use_im_start_end", False ): cur_new_input_embeds.append( self.get_model() .embed_tokens(cur_input_ids[: image_token_start - 1]) .detach() ) cur_new_input_embeds.append( self.get_model().embed_tokens( cur_input_ids[image_token_start - 1 : image_token_start] ) ) cur_new_input_embeds.append(cur_image_features) cur_new_input_embeds.append( self.get_model().embed_tokens( cur_input_ids[image_token_start + 1 : image_token_start + 2] ) ) if labels is not None: cur_new_labels.append(cur_labels[:image_token_start]) cur_new_labels.append( torch.full( (cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype, ) ) cur_new_labels.append( cur_labels[image_token_start : image_token_start + 1] ) cur_labels = cur_labels[image_token_start + 2 :] elif getattr(self.config, "mm_use_im_start_end", False): cur_new_input_embeds.append( self.get_model().embed_tokens(cur_input_ids[:image_token_start]) ) cur_new_input_embeds.append(cur_image_features) cur_new_input_embeds.append( self.get_model().embed_tokens( cur_input_ids[image_token_start + 1 : image_token_start + 2] ) ) if labels is not None: cur_new_labels.append(cur_labels[:image_token_start]) cur_new_labels.append( torch.full( (cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype, ) ) cur_new_labels.append( cur_labels[image_token_start + 1 : image_token_start + 2] ) cur_labels = cur_labels[image_token_start + 2 :] else: cur_new_input_embeds.append( self.get_model().embed_tokens(cur_input_ids[:image_token_start]) ) cur_new_input_embeds.append(cur_image_features) if labels is not None: cur_new_labels.append(cur_labels[:image_token_start]) cur_new_labels.append( torch.full( (cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype, ) ) cur_labels = cur_labels[image_token_start + 1 :] cur_image_idx += 1 if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr( self.config, "mm_use_im_start_end", False ): cur_input_ids = cur_input_ids[image_token_start + 2 :] elif getattr(self.config, "mm_use_im_start_end", False): cur_input_ids = cur_input_ids[image_token_start + 2 :] else: cur_input_ids = cur_input_ids[image_token_start + 1 :] image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] if cur_input_ids.numel() > 0: if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr( self.config, "mm_use_im_start_end", False ): cur_new_input_embeds.append( self.get_model().embed_tokens(cur_input_ids).detach() ) elif getattr(self.config, "mm_use_im_start_end", False): cur_new_input_embeds.append( self.get_model().embed_tokens(cur_input_ids) ) else: cur_new_input_embeds.append( self.get_model().embed_tokens(cur_input_ids) ) if labels is not None: cur_new_labels.append(cur_labels) cur_new_input_embeds = [ x.to(device=self.device) for x in cur_new_input_embeds ] cur_new_input_embeds = 
torch.cat(cur_new_input_embeds, dim=0) new_input_embeds.append(cur_new_input_embeds) if labels is not None: cur_new_labels = torch.cat(cur_new_labels, dim=0) new_labels.append(cur_new_labels) if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): max_len = max(x.shape[0] for x in new_input_embeds) new_input_embeds_align = [] for cur_new_embed in new_input_embeds: cur_new_embed = torch.cat( ( cur_new_embed, torch.zeros( (max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device, ), ), dim=0, ) new_input_embeds_align.append(cur_new_embed) new_input_embeds = torch.stack(new_input_embeds_align, dim=0) if labels is not None: new_labels_align = [] _new_labels = new_labels for cur_new_label in new_labels: cur_new_label = torch.cat( ( cur_new_label, torch.full( (max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device, ), ), dim=0, ) new_labels_align.append(cur_new_label) new_labels = torch.stack(new_labels_align, dim=0) if attention_mask is not None: new_attention_mask = [] for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip( attention_mask, _new_labels, new_labels ): new_attn_mask_pad_left = torch.full( (cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device, ) new_attn_mask_pad_right = torch.full( (cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device, ) cur_new_attention_mask = torch.cat( ( new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right, ), dim=0, ) new_attention_mask.append(cur_new_attention_mask) attention_mask = torch.stack(new_attention_mask, dim=0) assert attention_mask.shape == new_labels.shape else: new_input_embeds = torch.stack(new_input_embeds, dim=0) if labels is not None: new_labels = torch.stack(new_labels, dim=0) if attention_mask is not None: new_attn_mask_pad_left = torch.full( ( attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1], ), True, dtype=attention_mask.dtype, device=attention_mask.device, ) attention_mask = torch.cat( (new_attn_mask_pad_left, attention_mask), dim=1 ) assert attention_mask.shape == new_input_embeds.shape[:2] return None, attention_mask, past_key_values, new_input_embeds, new_labels # def initialize_vision_tokenizer(self, model_args, tokenizer): def initialize_vision_tokenizer(self, model_args, num_new_tokens): # if model_args.mm_use_im_patch_token: # tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) # self.resize_token_embeddings(len(tokenizer)) if model_args.mm_use_im_start_end: # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) # self.resize_token_embeddings(len(tokenizer)) # if num_new_tokens > 0: # input_embeddings = self.get_input_embeddings().weight.data # output_embeddings = self.get_output_embeddings().weight.data # input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( # dim=0, keepdim=True) # output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( # dim=0, keepdim=True) # input_embeddings[-num_new_tokens:] = input_embeddings_avg # output_embeddings[-num_new_tokens:] = output_embeddings_avg if model_args.tune_mm_mlp_adapter: for p in self.get_input_embeddings().parameters(): p.requires_grad = True for p in self.get_output_embeddings().parameters(): p.requires_grad = False if model_args.pretrain_mm_mlp_adapter: mm_projector_weights = torch.load( 
                    model_args.pretrain_mm_mlp_adapter, map_location="cpu"
                )
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                # Bug fix: `input_embeddings` was never defined on this path
                # (the block that computed it is commented out above), so bind
                # it here before copying the pretrained rows in.
                input_embeddings = self.get_input_embeddings().weight.data
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[
                        -num_new_tokens:
                    ]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(
                        f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}."
                    )
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False


================================================
FILE: model/llava/model/make_delta.py
================================================
"""
Usage:
python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
"""
import argparse

import torch
from llava.model.utils import auto_upgrade
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )

    print("Loading target model")
    auto_upgrade(target_model_path)
    target = AutoModelForCausalLM.from_pretrained(
        target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )

    print("Calculating delta")
    for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
        if name not in base.state_dict():
            assert name in [
                "model.mm_projector.weight",
                "model.mm_projector.bias",
            ], f"{name} not in base model"
            continue
        if param.data.shape == base.state_dict()[name].shape:
            param.data -= base.state_dict()[name]
        else:
            assert name in [
                "model.embed_tokens.weight",
                "lm_head.weight",
            ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
            bparam = base.state_dict()[name]
            param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam

    print("Saving delta")
    if hub_repo_id:
        kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
    else:
        kwargs = {}
    target.save_pretrained(delta_path, **kwargs)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
    target_tokenizer.save_pretrained(delta_path, **kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)
    parser.add_argument("--hub-repo-id", type=str, default=None)
    args = parser.parse_args()
    make_delta(
        args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id
    )
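make_delta stores target minus base for every shared tensor, keeping the mm_projector weights whole and delta-encoding only the overlapping slice of embeddings that grew new token rows. Recovering the target is the mirror image of that. The repo ships apply_delta.py for the official version; the function below is only a hedged sketch of the inverse, written to make the encoding concrete.

import torch
from transformers import AutoModelForCausalLM


def apply_delta_sketch(base_model_path: str, delta_path: str, target_path: str):
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    target = AutoModelForCausalLM.from_pretrained(
        delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    base_sd = base.state_dict()
    for name, param in target.state_dict().items():
        if name not in base_sd:
            continue  # e.g. model.mm_projector.*: stored whole in the delta
        bparam = base_sd[name]
        if param.shape == bparam.shape:
            param.data += bparam
        else:
            # embed_tokens / lm_head were extended with new tokens; only the
            # overlapping slice was delta-encoded, so only it gets the base back.
            param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
    target.save_pretrained(target_path)


================================================
FILE: model/llava/model/multimodal_encoder/builder.py
================================================
from .clip_encoder import CLIPVisionTower


def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(
        vision_tower_cfg,
        "mm_vision_tower",
        getattr(vision_tower_cfg, "vision_tower", None),
    )
    if (
        vision_tower.startswith("openai")
        or vision_tower.startswith("laion")
        or "clip" in vision_tower
    ):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    raise 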
ValueError(f"Unknown vision tower: {vision_tower}") ================================================ FILE: model/llava/model/multimodal_encoder/clip_encoder.py ================================================ import torch import torch.nn as nn from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel class CLIPVisionTower(nn.Module): def __init__(self, vision_tower, args, delay_load=False): super().__init__() self.is_loaded = False self.vision_tower_name = vision_tower self.select_layer = args.mm_vision_select_layer self.select_feature = getattr(args, "mm_vision_select_feature", "patch") if not delay_load: self.load_model() else: self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) def load_model(self): self.image_processor = CLIPImageProcessor.from_pretrained( self.vision_tower_name ) self.vision_tower = CLIPVisionModel.from_pretrained( self.vision_tower_name, low_cpu_mem_usage=True ) self.vision_tower.requires_grad_(False) self.is_loaded = True def feature_select(self, image_forward_outs): image_features = image_forward_outs.hidden_states[self.select_layer] if self.select_feature == "patch": image_features = image_features[:, 1:] elif self.select_feature == "cls_patch": image_features = image_features else: raise ValueError(f"Unexpected select feature: {self.select_feature}") return image_features @torch.no_grad() def forward(self, images): if type(images) is list: image_features = [] for image in images: image_forward_out = self.vision_tower( image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True, ) image_feature = self.feature_select(image_forward_out).to(image.dtype) image_features.append(image_feature) else: image_forward_outs = self.vision_tower( images.to(device=self.device, dtype=self.dtype), output_hidden_states=True, ) image_features = self.feature_select(image_forward_outs).to(images.dtype) torch.cuda.empty_cache() return image_features @property def dummy_feature(self): return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) @property def dtype(self): return self.vision_tower.dtype @property def device(self): return self.vision_tower.device @property def config(self): if self.is_loaded: return self.vision_tower.config else: return self.cfg_only @property def hidden_size(self): return self.config.hidden_size @property def num_patches(self): return (self.config.image_size // self.config.patch_size) ** 2 ================================================ FILE: model/llava/model/utils.py ================================================ from transformers import AutoConfig def auto_upgrade(config): cfg = AutoConfig.from_pretrained(config) if "llava" in config and "llava" not in cfg.model_type: assert cfg.model_type == "llama" print( "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base." ) print( "You must upgrade the checkpoint to the new code base (this can be done automatically)." ) confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") if confirm.lower() in ["y", "yes"]: print("Upgrading checkpoint...") assert len(cfg.architectures) == 1 setattr(cfg.__class__, "model_type", "llava") cfg.architectures[0] = "LlavaLlamaForCausalLM" cfg.save_pretrained(config) print("Checkpoint upgraded.") else: print("Checkpoint upgrade aborted.") exit(1) ================================================ FILE: model/llava/train/llama_flash_attn_monkey_patch.py ================================================ import logging from typing import List, Optional, Tuple import torch import transformers from einops import rearrange from torch import nn from transformers.models.llama.modeling_llama import apply_rotary_pos_emb try: from flash_attn.flash_attn_interface import \ flash_attn_unpadded_qkvpacked_func except ImportError: from flash_attn.flash_attn_interface import ( flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func, ) from flash_attn.bert_padding import pad_input, unpad_input def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel attention_mask: [bsz, q_len] """ bsz, q_len, _ = hidden_states.size() query_states = ( self.q_proj(hidden_states) .view(bsz, q_len, self.num_heads, self.head_dim) .transpose(1, 2) ) key_states = ( self.k_proj(hidden_states) .view(bsz, q_len, self.num_heads, self.head_dim) .transpose(1, 2) ) value_states = ( self.v_proj(hidden_states) .view(bsz, q_len, self.num_heads, self.head_dim) .transpose(1, 2) ) # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] kv_seq_len = key_states.shape[-2] assert past_key_value is None, "past_key_value is not supported" cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) # [bsz, nh, t, hd] assert not output_attentions, "output_attentions is not supported" assert not use_cache, "use_cache is not supported" # Flash attention codes from # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py # transform the data into the format required by flash attention qkv = torch.stack( [query_states, key_states, value_states], dim=2 ) # [bsz, nh, 3, q_len, hd] qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] # We have disabled _prepare_decoder_attention_mask in LlamaModel # the attention_mask should be the same as the key_padding_mask key_padding_mask = attention_mask if key_padding_mask is None: qkv = rearrange(qkv, "b s ... -> (b s) ...") max_s = q_len cu_q_lens = torch.arange( 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device ) output = flash_attn_unpadded_qkvpacked_func( qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True ) output = rearrange(output, "(b s) ... 
-> b s ...", b=bsz) else: nheads = qkv.shape[-2] x = rearrange(qkv, "b s three h d -> b s (three h d)") x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) x_unpad = rearrange( x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads ) output_unpad = flash_attn_unpadded_qkvpacked_func( x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True ) output = rearrange( pad_input( rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len ), "b s (h d) -> b s h d", h=nheads, ) return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask def _prepare_decoder_attention_mask( self, attention_mask, input_shape, inputs_embeds, past_key_values_length ): # [bsz, seq_len] return attention_mask def replace_llama_attn_with_flash_attn(): cuda_major, cuda_minor = torch.cuda.get_device_capability() if cuda_major < 8: logging.warning( "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" ) transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( _prepare_decoder_attention_mask ) transformers.models.llama.modeling_llama.LlamaAttention.forward = forward ================================================ FILE: model/llava/train/llava_trainer.py ================================================ import os from typing import Optional import torch from transformers import Trainer def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus if hasattr(param, "ds_id"): if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if not ignore_status: print(name, "no ignore status") with zero.GatheredParameters([param]): param = param.data.detach().cpu().clone() else: param = param.detach().cpu().clone() return param def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): to_return = { k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match) } to_return = { k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items() } return to_return class LLaVATrainer(Trainer): def _save_checkpoint(self, model, trial, metrics=None): if getattr(self.args, "tune_mm_mlp_adapter", False): from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) # Only save Adapter keys_to_match = ["mm_projector"] if getattr(self.args, "use_im_start_end", False): keys_to_match.extend(["embed_tokens", "embed_in"]) weight_to_save = get_mm_adapter_state_maybe_zero_3( self.model.named_parameters(), keys_to_match ) if self.args.local_rank == 0 or self.args.local_rank == -1: self.model.config.save_pretrained(output_dir) torch.save( weight_to_save, os.path.join(output_dir, f"mm_projector.bin") ) else: super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) def _save(self, output_dir: Optional[str] = None, state_dict=None): if getattr(self.args, "tune_mm_mlp_adapter", False): pass else: super(LLaVATrainer, self)._save(output_dir, state_dict) ================================================ FILE: model/llava/train/train.py 
================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import json import logging import os import pathlib from dataclasses import dataclass, field from typing import Dict, List, Optional, Sequence import torch import transformers from llava import conversation as conversation_lib from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX) from llava.mm_utils import tokenizer_image_token from llava.model import * from llava.train.llava_trainer import LLaVATrainer from PIL import Image from torch.utils.data import Dataset local_rank = None def rank0_print(*args): if local_rank == 0: print(*args) @dataclass class ModelArguments: model_name_or_path: Optional[str] = field(default="facebook/opt-125m") version: Optional[str] = field(default="v0") freeze_backbone: bool = field(default=False) tune_mm_mlp_adapter: bool = field(default=False) vision_tower: Optional[str] = field(default=None) mm_vision_select_layer: Optional[int] = field( default=-1 ) # default to the last layer pretrain_mm_mlp_adapter: Optional[str] = field(default=None) mm_use_im_start_end: bool = field(default=False) mm_use_im_patch_token: bool = field(default=True) mm_vision_select_feature: Optional[str] = field(default="patch") @dataclass class DataArguments: data_path: str = field( default=None, metadata={"help": "Path to the training data."} ) lazy_preprocess: bool = False is_multimodal: bool = False image_folder: Optional[str] = field(default=None) image_aspect_ratio: str = "square" image_grid_pinpoints: Optional[str] = field(default=None) @dataclass class TrainingArguments(transformers.TrainingArguments): cache_dir: Optional[str] = field(default=None) optim: str = field(default="adamw_torch") remove_unused_columns: bool = field(default=False) freeze_mm_mlp_adapter: bool = field(default=False) mpt_attn_impl: Optional[str] = field(default="triton") model_max_length: int = field( default=512, metadata={ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." }, ) double_quant: bool = field( default=True, metadata={ "help": "Compress the quantization statistics through double quantization." }, ) quant_type: str = field( default="nf4", metadata={ "help": "Quantization data type to use. Should be one of `fp4` or `nf4`." 
},
    )
    bits: int = field(default=16, metadata={"help": "How many bits to use."})
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"


def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(
                    f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
                )
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Bug fix: the original iterated `for k, t in maybe_lora_bias:` (keys
        # only, so the unpacking fails) and then tested a stale `bias_name`
        # left over from the loop above; match each bias tensor by its own name.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {
        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
    }
    return to_return


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {
        k: t
        for k, t in named_params
        if any(key_match in k for key_match in keys_to_match)
    }
    to_return = {
        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
    }
    return to_return


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)
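find_all_linear_names exists to feed peft: it collects the leaf name of every nn.Linear in the model (minus lm_head, which the comment flags for 16-bit training) so the list can be handed to LoraConfig(target_modules=...), exactly as train() does further down. A tiny, illustrative sketch with a made-up stand-in module:

import torch.nn as nn
from peft import LoraConfig

toy = nn.ModuleDict(
    {"q_proj": nn.Linear(8, 8), "v_proj": nn.Linear(8, 8), "lm_head": nn.Linear(8, 8)}
)
targets = find_all_linear_names(toy)  # ['q_proj', 'v_proj']; 'lm_head' is dropped
lora_config = LoraConfig(
    r=8, lora_alpha=16, target_modules=targets, task_type="CAUSAL_LM"
)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dumps it to disk."""
    if getattr(trainer.args, "tune_mm_mlp_adapter", False):
        # Only save Adapter
        keys_to_match = ["mm_projector"]
        if getattr(trainer.args, "use_im_start_end", False):
            keys_to_match.extend(["embed_tokens", "embed_in"])

        weight_to_save = get_mm_adapter_state_maybe_zero_3(
            trainer.model.named_parameters(), keys_to_match
        )
        trainer.model.config.save_pretrained(output_dir)

        current_folder = output_dir.split("/")[-1]
        parent_folder = os.path.dirname(output_dir)
        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
            if current_folder.startswith("checkpoint-"):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(
                    weight_to_save,
                    os.path.join(mm_projector_folder, f"{current_folder}.bin"),
                )
            else:
                torch.save(
                    weight_to_save, os.path.join(output_dir, "mm_projector.bin")
                )
        return

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = 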
trainer.model.state_dict() if trainer.args.should_save: cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa def smart_tokenizer_and_embedding_resize( special_tokens_dict: Dict, tokenizer: transformers.PreTrainedTokenizer, model: transformers.PreTrainedModel, ): """Resize tokenizer and embedding. Note: This is the unoptimized version that may make your embedding size not be divisible by 64. """ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True ) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True ) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg def _tokenize_fn( strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer ) -> Dict: """Tokenize a list of strings.""" tokenized_list = [ tokenizer( text, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ) for text in strings ] input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] input_ids_lens = labels_lens = [ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list ] return dict( input_ids=input_ids, labels=labels, input_ids_lens=input_ids_lens, labels_lens=labels_lens, ) def _mask_targets(target, tokenized_lens, speakers): # cur_idx = 0 cur_idx = tokenized_lens[0] tokenized_lens = tokenized_lens[1:] target[:cur_idx] = IGNORE_INDEX for tokenized_len, speaker in zip(tokenized_lens, speakers): if speaker == "human": target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX cur_idx += tokenized_len def _add_speaker_and_signal(header, source, get_conversation=True): """Add speaker and start/end signal on each round.""" BEGIN_SIGNAL = "### " END_SIGNAL = "\n" conversation = header for sentence in source: from_str = sentence["from"] if from_str.lower() == "human": from_str = conversation_lib.default_conversation.roles[0] elif from_str.lower() == "gpt": from_str = conversation_lib.default_conversation.roles[1] else: from_str = "unknown" sentence["value"] = ( BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL ) if get_conversation: conversation += sentence["value"] conversation += BEGIN_SIGNAL return conversation def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict: is_multimodal = data_args.is_multimodal if not is_multimodal: return sources for source in sources: for sentence in source: if DEFAULT_IMAGE_TOKEN in sentence["value"]: sentence["value"] = ( sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip() ) sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"] sentence["value"] = sentence["value"].strip() if "mmtag" in conversation_lib.default_conversation.version: sentence["value"] = sentence["value"].replace( DEFAULT_IMAGE_TOKEN, "" + DEFAULT_IMAGE_TOKEN + "", ) replace_token = DEFAULT_IMAGE_TOKEN if data_args.mm_use_im_start_end: replace_token = ( DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN ) sentence["value"] = sentence["value"].replace( DEFAULT_IMAGE_TOKEN, replace_token ) return sources def preprocess_llama_2( sources, tokenizer: transformers.PreTrainedTokenizer, 
has_image: bool = False ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt()) # Tokenize conversations if has_image: input_ids = torch.stack( [ tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations ], dim=0, ) else: input_ids = tokenizer( conversations, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ).input_ids targets = input_ids.clone() assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 # Mask targets sep = "[/INST] " for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep2) cur_len = 1 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep if has_image: round_len = len(tokenizer_image_token(rou, tokenizer)) instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 else: round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 2 target[cur_len : cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) def preprocess_v1( sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt()) # Tokenize conversations if has_image: input_ids = torch.stack( [ tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations ], dim=0, ) else: input_ids = tokenizer( conversations, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ).input_ids targets = input_ids.clone() assert conv.sep_style == conversation_lib.SeparatorStyle.TWO # Mask targets sep = conv.sep + conv.roles[1] + ": " for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep2) cur_len = 1 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep if has_image: round_len = len(tokenizer_image_token(rou, tokenizer)) instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 else: round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 2 target[cur_len : cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) def preprocess_mpt( sources, tokenizer: transformers.PreTrainedTokenizer, ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt()) # Tokenize conversations input_ids = torch.stack( [ tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations ], dim=0, ) targets = input_ids.clone() assert conv.sep_style == conversation_lib.SeparatorStyle.MPT # Mask targets sep = conv.sep + conv.roles[1] for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep) re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt for conv_idx in range(3, len(rounds), 2): re_rounds.append( conv.sep.join(rounds[conv_idx : conv_idx + 2]) ) # user + gpt cur_len = 0 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(re_rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep round_len = len(tokenizer_image_token(rou, tokenizer)) + len( tokenizer_image_token(conv.sep, tokenizer) ) instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) target[cur_len : cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) def preprocess_plain( sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, ) -> Dict: # add end signal and concatenate together conversations = [] for source in sources: assert len(source) == 2 assert DEFAULT_IMAGE_TOKEN in source[0]["value"] source[0]["value"] = DEFAULT_IMAGE_TOKEN conversation = ( source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep ) conversations.append(conversation) # tokenize conversations input_ids = [ tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations ] targets = copy.deepcopy(input_ids) for target, source in zip(targets, sources): tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer)) target[:tokenized_len] = IGNORE_INDEX return dict(input_ids=input_ids, labels=targets) def preprocess( sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, ) -> Dict: """ Given a list of sources, each is a conversation list. This transform: 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; 2. Concatenate conversations together; 3. Tokenize the concatenated conversation; 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 
""" if ( conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN ): return preprocess_plain(sources, tokenizer) if ( conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 ): return preprocess_llama_2(sources, tokenizer, has_image=has_image) if conversation_lib.default_conversation.version.startswith("v1"): return preprocess_v1(sources, tokenizer, has_image=has_image) if conversation_lib.default_conversation.version == "mpt": return preprocess_mpt(sources, tokenizer) # add end signal and concatenate together conversations = [] for source in sources: header = f"{conversation_lib.default_conversation.system}\n\n" conversation = _add_speaker_and_signal(header, source) conversations.append(conversation) # tokenize conversations def get_tokenize_len(prompts): return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] if has_image: input_ids = [ tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations ] else: conversations_tokenized = _tokenize_fn(conversations, tokenizer) input_ids = conversations_tokenized["input_ids"] targets = copy.deepcopy(input_ids) for target, source in zip(targets, sources): if has_image: tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) else: tokenized_lens = _tokenize_fn( [header] + [s["value"] for s in source], tokenizer )["input_ids_lens"] speakers = [sentence["from"] for sentence in source] _mask_targets(target, tokenized_lens, speakers) return dict(input_ids=input_ids, labels=targets) class LazySupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" def __init__( self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, ): super(LazySupervisedDataset, self).__init__() list_data_dict = json.load(open(data_path, "r")) rank0_print("Formatting inputs...Skip in lazy mode") self.tokenizer = tokenizer self.list_data_dict = list_data_dict self.data_args = data_args def __len__(self): return len(self.list_data_dict) def __getitem__(self, i) -> Dict[str, torch.Tensor]: sources = self.list_data_dict[i] if isinstance(i, int): sources = [sources] assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME if "image" in sources[0]: image_file = self.list_data_dict[i]["image"] image_folder = self.data_args.image_folder processor = self.data_args.image_processor image = Image.open(os.path.join(image_folder, image_file)).convert("RGB") if self.data_args.image_aspect_ratio == "pad": def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new( pil_img.mode, (width, width), background_color ) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new( pil_img.mode, (height, height), background_color ) result.paste(pil_img, ((height - width) // 2, 0)) return result image = expand2square( image, tuple(int(x * 255) for x in processor.image_mean) ) image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ][0] else: image = processor.preprocess(image, return_tensors="pt")[ "pixel_values" ][0] sources = preprocess_multimodal( copy.deepcopy([e["conversations"] for e in sources]), self.data_args ) else: sources = copy.deepcopy([e["conversations"] for e in sources]) data_dict = preprocess( sources, self.tokenizer, has_image=("image" in self.list_data_dict[i]) ) if isinstance(i, int): data_dict = dict( input_ids=data_dict["input_ids"][0], 
labels=data_dict["labels"][0] ) # image exist in the data if "image" in self.list_data_dict[i]: data_dict["image"] = image elif self.data_args.is_multimodal: # image does not exist in the data, but the model is multimodal crop_size = self.data_args.image_processor.crop_size data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"]) return data_dict @dataclass class DataCollatorForSupervisedDataset(object): """Collate examples for supervised fine-tuning.""" tokenizer: transformers.PreTrainedTokenizer def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: input_ids, labels = tuple( [instance[key] for instance in instances] for key in ("input_ids", "labels") ) input_ids = torch.nn.utils.rnn.pad_sequence( input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id ) labels = torch.nn.utils.rnn.pad_sequence( labels, batch_first=True, padding_value=IGNORE_INDEX ) input_ids = input_ids[:, : self.tokenizer.model_max_length] labels = labels[:, : self.tokenizer.model_max_length] batch = dict( input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ) if "image" in instances[0]: images = [instance["image"] for instance in instances] if all(x is not None and x.shape == images[0].shape for x in images): batch["images"] = torch.stack(images) else: batch["images"] = images return batch def make_supervised_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args ) -> Dict: """Make dataset and collator for supervised fine-tuning.""" train_dataset = LazySupervisedDataset( tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args ) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) return dict( train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator ) def train(): global local_rank parser = transformers.HfArgumentParser( (ModelArguments, DataArguments, TrainingArguments) ) model_args, data_args, training_args = parser.parse_args_into_dataclasses() local_rank = training_args.local_rank compute_dtype = ( torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32) ) bnb_model_from_pretrained_args = {} if training_args.bits in [4, 8]: from transformers import BitsAndBytesConfig bnb_model_from_pretrained_args.update( dict( device_map={"": training_args.device}, load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, quantization_config=BitsAndBytesConfig( load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_use_double_quant=training_args.double_quant, bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'} ), ) ) if model_args.vision_tower is not None: if "mpt" in model_args.model_name_or_path: config = transformers.AutoConfig.from_pretrained( model_args.model_name_or_path, trust_remote_code=True ) config.attn_config["attn_impl"] = training_args.mpt_attn_impl model = LlavaMPTForCausalLM.from_pretrained( model_args.model_name_or_path, config=config, cache_dir=training_args.cache_dir, **bnb_model_from_pretrained_args, ) else: model = LlavaLlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, **bnb_model_from_pretrained_args, ) else: model = transformers.LlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, **bnb_model_from_pretrained_args, ) model.config.use_cache = False if 
model_args.freeze_backbone: model.model.requires_grad_(False) if training_args.bits in [4, 8]: from peft import prepare_model_for_kbit_training model.config.torch_dtype = ( torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32) ) model = prepare_model_for_kbit_training( model, use_gradient_checkpointing=training_args.gradient_checkpointing ) if training_args.gradient_checkpointing: if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) if training_args.lora_enable: from peft import LoraConfig, get_peft_model lora_config = LoraConfig( r=training_args.lora_r, lora_alpha=training_args.lora_alpha, target_modules=find_all_linear_names(model), lora_dropout=training_args.lora_dropout, bias=training_args.lora_bias, task_type="CAUSAL_LM", ) if training_args.bits == 16: if training_args.bf16: model.to(torch.bfloat16) if training_args.fp16: model.to(torch.float16) rank0_print("Adding LoRA adapters...") model = get_peft_model(model, lora_config) if "mpt" in model_args.model_name_or_path: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right", ) else: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right", use_fast=False, ) if model_args.version == "v0": if tokenizer.pad_token is None: smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token="[PAD]"), tokenizer=tokenizer, model=model, ) elif model_args.version == "v0.5": tokenizer.pad_token = tokenizer.unk_token else: tokenizer.pad_token = tokenizer.unk_token if model_args.version in conversation_lib.conv_templates: conversation_lib.default_conversation = conversation_lib.conv_templates[ model_args.version ] else: conversation_lib.default_conversation = conversation_lib.conv_templates[ "vicuna_v1" ] if model_args.vision_tower is not None: model.get_model().initialize_vision_modules( model_args=model_args, fsdp=training_args.fsdp ) vision_tower = model.get_vision_tower() vision_tower.to(dtype=torch.float16, device=training_args.device) data_args.image_processor = vision_tower.image_processor data_args.is_multimodal = True model.config.image_aspect_ratio = data_args.image_aspect_ratio model.config.image_grid_pinpoints = data_args.image_grid_pinpoints model.config.tune_mm_mlp_adapter = ( training_args.tune_mm_mlp_adapter ) = model_args.tune_mm_mlp_adapter if model_args.tune_mm_mlp_adapter: model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): p.requires_grad = True model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter if training_args.freeze_mm_mlp_adapter: for p in model.get_model().mm_projector.parameters(): p.requires_grad = False if training_args.bits in [4, 8]: model.get_model().mm_projector.to( dtype=compute_dtype, device=training_args.device ) model.config.mm_use_im_start_end = ( data_args.mm_use_im_start_end ) = model_args.mm_use_im_start_end training_args.use_im_start_end = model_args.mm_use_im_start_end model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) if training_args.bits in [4, 8]: from 
peft.tuners.lora import LoraLayer for name, module in model.named_modules(): if isinstance(module, LoraLayer): if training_args.bf16: module = module.to(torch.bfloat16) if "norm" in name: module = module.to(torch.float32) if "lm_head" in name or "embed_tokens" in name: if hasattr(module, "weight"): if training_args.bf16 and module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) trainer = LLaVATrainer( model=model, tokenizer=tokenizer, args=training_args, **data_module ) if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) else: trainer.train() trainer.save_state() model.config.use_cache = True if training_args.lora_enable: state_dict = get_peft_state_maybe_zero_3( model.named_parameters(), training_args.lora_bias ) non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( model.named_parameters() ) if training_args.local_rank == 0 or training_args.local_rank == -1: model.config.save_pretrained(training_args.output_dir) model.save_pretrained(training_args.output_dir, state_dict=state_dict) torch.save( non_lora_state_dict, os.path.join(training_args.output_dir, "non_lora_trainables.bin"), ) else: safe_save_model_for_hf_trainer( trainer=trainer, output_dir=training_args.output_dir ) if __name__ == "__main__": train() ================================================ FILE: model/llava/train/train_mem.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. # Need to call this before importing transformers. from llava.train.llama_flash_attn_monkey_patch import \ replace_llama_attn_with_flash_attn replace_llama_attn_with_flash_attn() from llava.train.train import train if __name__ == "__main__": train() ================================================ FILE: model/llava/utils.py ================================================ import datetime import logging import logging.handlers import os import sys import requests from llava.constants import LOGDIR server_error_msg = ( "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" ) moderation_msg = ( "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
)

handler = None


def build_logger(logger_name, logger_filename):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when="D", utc=True
        )
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ""
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   On output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == "\n":
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != "":
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ""


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.
    """
    import json

    url = "https://api.openai.com/v1/moderations"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
    }
    text = text.replace("\n", "")
    # Bug fix: the original built the payload by string concatenation, which
    # produced invalid JSON whenever the text contained quotes or backslashes;
    # json.dumps handles the quoting and escaping.
    data = json.dumps({"input": text}).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except (requests.exceptions.RequestException, KeyError):
        flagged = False
    return flagged


def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
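The files that follow vendor Meta's Segment Anything (SAM). As a quick orientation before the sources, here is a hedged usage sketch: the "vit_h" registry key and the checkpoint filename follow the upstream SAM release and are assumptions here (no checkpoint ships with this repo), and the import path assumes the repository root is on sys.path.

import numpy as np
from PIL import Image

from model.segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Assumed: upstream-style builder signature and the public ViT-H checkpoint name.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(
    sam, points_per_side=32, pred_iou_thresh=0.88
)

# generate() expects an HWC uint8 array (see its docstring below) and returns
# one record per mask: 'segmentation', 'bbox', 'area', 'predicted_iou', ...
image = np.array(Image.open("frame.jpg").convert("RGB"))
masks = mask_generator.generate(image)
print(len(masks), masks[0]["bbox"], masks[0]["predicted_iou"])

================================================
FILE: model/segment_anything/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.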
from .automatic_mask_generator import SamAutomaticMaskGenerator
from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h,
                        build_sam_vit_l, sam_model_registry)
from .predictor import SamPredictor


================================================
FILE: model/segment_anything/automatic_mask_generator.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area  # type: ignore

from .modeling import Sam
from .predictor import SamPredictor
from .utils.amg import (MaskData, area_from_rle, batch_iterator,
                        batched_mask_to_box, box_xyxy_to_xywh,
                        build_all_layer_point_grids, calculate_stability_score,
                        coco_encode_rle, generate_crop_boxes,
                        is_box_near_crop_edge, mask_to_rle_pytorch,
                        remove_small_regions, rle_to_mask, uncrop_boxes_xyxy,
                        uncrop_masks, uncrop_points)


class SamAutomaticMaskGenerator:
    def __init__(
        self,
        model: Sam,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        """
        Using a SAM model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (Sam): The SAM model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to
            binarize the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this
            overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by
            crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in
            the list is used in the nth crop layer.
Exclusive with points_per_side. min_mask_region_area (int): If >0, postprocessing will be applied to remove disconnected regions and holes in masks with area smaller than min_mask_region_area. Requires opencv. output_mode (str): The form masks are returned in. Can be 'binary_mask', 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. For large resolutions, 'binary_mask' may consume large amounts of memory. """ assert (points_per_side is None) != ( point_grids is None ), "Exactly one of points_per_side or point_grid must be provided." if points_per_side is not None: self.point_grids = build_all_layer_point_grids( points_per_side, crop_n_layers, crop_n_points_downscale_factor, ) elif point_grids is not None: self.point_grids = point_grids else: raise ValueError("Can't have both points_per_side and point_grid be None.") assert output_mode in [ "binary_mask", "uncompressed_rle", "coco_rle", ], f"Unknown output_mode {output_mode}." if output_mode == "coco_rle": from pycocotools import \ mask as mask_utils # type: ignore # noqa: F401 if min_mask_region_area > 0: import cv2 # type: ignore # noqa: F401 self.predictor = SamPredictor(model) self.points_per_batch = points_per_batch self.pred_iou_thresh = pred_iou_thresh self.stability_score_thresh = stability_score_thresh self.stability_score_offset = stability_score_offset self.box_nms_thresh = box_nms_thresh self.crop_n_layers = crop_n_layers self.crop_nms_thresh = crop_nms_thresh self.crop_overlap_ratio = crop_overlap_ratio self.crop_n_points_downscale_factor = crop_n_points_downscale_factor self.min_mask_region_area = min_mask_region_area self.output_mode = output_mode @torch.no_grad() def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: """ Generates masks for the given image. Arguments: image (np.ndarray): The image to generate masks for, in HWC uint8 format. Returns: list(dict(str, any)): A list over records for masks. Each record is a dict containing the following keys: segmentation (dict(str, any) or np.ndarray): The mask. If output_mode='binary_mask', is an array of shape HW. Otherwise, is a dictionary containing the RLE. bbox (list(float)): The box around the mask, in XYWH format. area (int): The area in pixels of the mask. predicted_iou (float): The model's own prediction of the mask's quality. This is filtered by the pred_iou_thresh parameter. point_coords (list(list(float))): The point coordinates input to the model to generate this mask. stability_score (float): A measure of the mask's quality. This is filtered on using the stability_score_thresh parameter. crop_box (list(float)): The crop of the image used to generate the mask, given in XYWH format. 
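        Example (illustrative sketch, not part of the original docs; the
        checkpoint path and `image` array are assumptions):
            >>> from model.segment_anything import sam_model_registry
            >>> sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")
            >>> generator = SamAutomaticMaskGenerator(sam, points_per_side=16)
            >>> anns = generator.generate(image)  # image: HWC uint8 np.ndarray
            >>> anns[0]["bbox"]  # XYWH box for the first mask record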
""" # Generate masks mask_data = self._generate_masks(image) # Filter small disconnected regions and holes in masks if self.min_mask_region_area > 0: mask_data = self.postprocess_small_regions( mask_data, self.min_mask_region_area, max(self.box_nms_thresh, self.crop_nms_thresh), ) # Encode masks if self.output_mode == "coco_rle": mask_data["segmentations"] = [ coco_encode_rle(rle) for rle in mask_data["rles"] ] elif self.output_mode == "binary_mask": mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] else: mask_data["segmentations"] = mask_data["rles"] # Write mask records curr_anns = [] for idx in range(len(mask_data["segmentations"])): ann = { "segmentation": mask_data["segmentations"][idx], "area": area_from_rle(mask_data["rles"][idx]), "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), "predicted_iou": mask_data["iou_preds"][idx].item(), "point_coords": [mask_data["points"][idx].tolist()], "stability_score": mask_data["stability_score"][idx].item(), "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), } curr_anns.append(ann) return curr_anns def _generate_masks(self, image: np.ndarray) -> MaskData: orig_size = image.shape[:2] crop_boxes, layer_idxs = generate_crop_boxes( orig_size, self.crop_n_layers, self.crop_overlap_ratio ) # Iterate over image crops data = MaskData() for crop_box, layer_idx in zip(crop_boxes, layer_idxs): crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) data.cat(crop_data) # Remove duplicate masks between crops if len(crop_boxes) > 1: # Prefer masks from smaller crops scores = 1 / box_area(data["crop_boxes"]) scores = scores.to(data["boxes"].device) keep_by_nms = batched_nms( data["boxes"].float(), scores, torch.zeros_like(data["boxes"][:, 0]), # categories iou_threshold=self.crop_nms_thresh, ) data.filter(keep_by_nms) data.to_numpy() return data def _process_crop( self, image: np.ndarray, crop_box: List[int], crop_layer_idx: int, orig_size: Tuple[int, ...], ) -> MaskData: # Crop the image and calculate embeddings x0, y0, x1, y1 = crop_box cropped_im = image[y0:y1, x0:x1, :] cropped_im_size = cropped_im.shape[:2] self.predictor.set_image(cropped_im) # Get points for this crop points_scale = np.array(cropped_im_size)[None, ::-1] points_for_image = self.point_grids[crop_layer_idx] * points_scale # Generate masks for this crop in batches data = MaskData() for (points,) in batch_iterator(self.points_per_batch, points_for_image): batch_data = self._process_batch( points, cropped_im_size, crop_box, orig_size ) data.cat(batch_data) del batch_data self.predictor.reset_image() # Remove duplicates within this crop. 
keep_by_nms = batched_nms( data["boxes"].float(), data["iou_preds"], torch.zeros_like(data["boxes"][:, 0]), # categories iou_threshold=self.box_nms_thresh, ) data.filter(keep_by_nms) # Return to the original image frame data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) data["points"] = uncrop_points(data["points"], crop_box) data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) return data def _process_batch( self, points: np.ndarray, im_size: Tuple[int, ...], crop_box: List[int], orig_size: Tuple[int, ...], ) -> MaskData: orig_h, orig_w = orig_size # Run model on this batch transformed_points = self.predictor.transform.apply_coords(points, im_size) in_points = torch.as_tensor(transformed_points, device=self.predictor.device) in_labels = torch.ones( in_points.shape[0], dtype=torch.int, device=in_points.device ) masks, iou_preds, _ = self.predictor.predict_torch( in_points[:, None, :], in_labels[:, None], multimask_output=True, return_logits=True, ) # Serialize predictions and store in MaskData data = MaskData( masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1), points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), ) del masks # Filter by predicted IoU if self.pred_iou_thresh > 0.0: keep_mask = data["iou_preds"] > self.pred_iou_thresh data.filter(keep_mask) # Calculate stability score data["stability_score"] = calculate_stability_score( data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset, ) if self.stability_score_thresh > 0.0: keep_mask = data["stability_score"] >= self.stability_score_thresh data.filter(keep_mask) # Threshold masks and calculate boxes data["masks"] = data["masks"] > self.predictor.model.mask_threshold data["boxes"] = batched_mask_to_box(data["masks"]) # Filter boxes that touch crop boundaries keep_mask = ~is_box_near_crop_edge( data["boxes"], crop_box, [0, 0, orig_w, orig_h] ) if not torch.all(keep_mask): data.filter(keep_mask) # Compress to RLE data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) data["rles"] = mask_to_rle_pytorch(data["masks"]) del data["masks"] return data @staticmethod def postprocess_small_regions( mask_data: MaskData, min_area: int, nms_thresh: float ) -> MaskData: """ Removes small disconnected regions and holes in masks, then reruns box NMS to remove any new duplicates. Edits mask_data in place. Requires open-cv as a dependency. 
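        Example (illustrative call only; the thresholds are arbitrary here):
            >>> mask_data = SamAutomaticMaskGenerator.postprocess_small_regions(
            ...     mask_data, min_area=100, nms_thresh=0.7
            ... )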
""" if len(mask_data["rles"]) == 0: return mask_data # Filter small disconnected regions and holes new_masks = [] scores = [] for rle in mask_data["rles"]: mask = rle_to_mask(rle) mask, changed = remove_small_regions(mask, min_area, mode="holes") unchanged = not changed mask, changed = remove_small_regions(mask, min_area, mode="islands") unchanged = unchanged and not changed new_masks.append(torch.as_tensor(mask).unsqueeze(0)) # Give score=0 to changed masks and score=1 to unchanged masks # so NMS will prefer ones that didn't need postprocessing scores.append(float(unchanged)) # Recalculate boxes and remove any new duplicates masks = torch.cat(new_masks, dim=0) boxes = batched_mask_to_box(masks) keep_by_nms = batched_nms( boxes.float(), torch.as_tensor(scores), torch.zeros_like(boxes[:, 0]), # categories iou_threshold=nms_thresh, ) # Only recalculate RLEs for masks that have changed for i_mask in keep_by_nms: if scores[i_mask] == 0.0: mask_torch = masks[i_mask].unsqueeze(0) mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly mask_data.filter(keep_by_nms) return mask_data ================================================ FILE: model/segment_anything/build_sam.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from functools import partial import torch from .modeling import (ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer) def build_sam_vit_h(checkpoint=None): return _build_sam( encoder_embed_dim=1280, encoder_depth=32, encoder_num_heads=16, encoder_global_attn_indexes=[7, 15, 23, 31], checkpoint=checkpoint, ) build_sam = build_sam_vit_h def build_sam_vit_l(checkpoint=None): return _build_sam( encoder_embed_dim=1024, encoder_depth=24, encoder_num_heads=16, encoder_global_attn_indexes=[5, 11, 17, 23], checkpoint=checkpoint, ) def build_sam_vit_b(checkpoint=None): return _build_sam( encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12, encoder_global_attn_indexes=[2, 5, 8, 11], checkpoint=checkpoint, ) sam_model_registry = { "default": build_sam_vit_h, "vit_h": build_sam_vit_h, "vit_l": build_sam_vit_l, "vit_b": build_sam_vit_b, } def _build_sam( encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, ): prompt_embed_dim = 256 image_size = 1024 vit_patch_size = 16 image_embedding_size = image_size // vit_patch_size sam = Sam( image_encoder=ImageEncoderViT( depth=encoder_depth, embed_dim=encoder_embed_dim, img_size=image_size, mlp_ratio=4, norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), num_heads=encoder_num_heads, patch_size=vit_patch_size, qkv_bias=True, use_rel_pos=True, global_attn_indexes=encoder_global_attn_indexes, window_size=14, out_chans=prompt_embed_dim, ), prompt_encoder=PromptEncoder( embed_dim=prompt_embed_dim, image_embedding_size=(image_embedding_size, image_embedding_size), input_image_size=(image_size, image_size), mask_in_chans=16, ), mask_decoder=MaskDecoder( num_multimask_outputs=3, transformer=TwoWayTransformer( depth=2, embedding_dim=prompt_embed_dim, mlp_dim=2048, num_heads=8, ), transformer_dim=prompt_embed_dim, iou_head_depth=3, iou_head_hidden_dim=256, ), pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], ) sam.eval() if checkpoint is not None: with open(checkpoint, "rb") as f: state_dict = torch.load(f) 
sam.load_state_dict(state_dict, strict=False) return sam ================================================ FILE: model/segment_anything/modeling/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from .image_encoder import ImageEncoderViT from .mask_decoder import MaskDecoder from .prompt_encoder import PromptEncoder from .sam import Sam from .transformer import TwoWayTransformer ================================================ FILE: model/segment_anything/modeling/common.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import Type import torch import torch.nn as nn class MLPBlock(nn.Module): def __init__( self, embedding_dim: int, mlp_dim: int, act: Type[nn.Module] = nn.GELU, ) -> None: super().__init__() self.lin1 = nn.Linear(embedding_dim, mlp_dim) self.lin2 = nn.Linear(mlp_dim, embedding_dim) self.act = act() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.lin2(self.act(self.lin1(x))) # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa class LayerNorm2d(nn.Module): def __init__(self, num_channels: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x ================================================ FILE: model/segment_anything/modeling/image_encoder.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import Optional, Tuple, Type import torch import torch.nn as nn import torch.nn.functional as F from .common import LayerNorm2d, MLPBlock # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa class ImageEncoderViT(nn.Module): def __init__( self, img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, ...] = (), ) -> None: """ Args: img_size (int): Input image size. patch_size (int): Patch size. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. depth (int): Depth of ViT. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_abs_pos (bool): If True, use absolute positional embeddings. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. global_attn_indexes (list): Indexes for blocks using global attention. """ super().__init__() self.img_size = img_size self.embed_dim = embed_dim self.out_chans = out_chans self.patch_embed = PatchEmbed( kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), in_chans=in_chans, embed_dim=embed_dim, ) self.pos_embed: Optional[nn.Parameter] = None if use_abs_pos: # Initialize absolute positional embedding with pretrain image size. self.pos_embed = nn.Parameter( torch.zeros( 1, img_size // patch_size, img_size // patch_size, embed_dim ) ) self.blocks = nn.ModuleList() for i in range(depth): block = Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer, act_layer=act_layer, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, window_size=window_size if i not in global_attn_indexes else 0, input_size=(img_size // patch_size, img_size // patch_size), ) self.blocks.append(block) self.neck = nn.Sequential( nn.Conv2d( embed_dim, out_chans, kernel_size=1, bias=False, ), LayerNorm2d(out_chans), nn.Conv2d( out_chans, out_chans, kernel_size=3, padding=1, bias=False, ), LayerNorm2d(out_chans), ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) if self.pos_embed is not None: x = x + self.pos_embed for blk in self.blocks: x = blk(x) dtype = x.dtype if dtype == torch.float16: # prevent overflow with torch.autocast(device_type="cuda", dtype=torch.float32): x = self.neck(x.permute(0, 3, 1, 2)) x = x.to(dtype) else: x = self.neck(x.permute(0, 3, 1, 2)) return x class Block(nn.Module): """Transformer blocks with support of window attention and residual propagation blocks""" def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, input_size: Optional[Tuple[int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. If it equals 0, then use global attention. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. 
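        Example (illustrative sketch; the sizes are assumptions, not from the
        original docs):
            >>> blk = Block(dim=768, num_heads=12, window_size=14,
            ...             input_size=(64, 64))
            >>> y = blk(torch.zeros(1, 64, 64, 768))  # tokens stay [B, H, W, C]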
""" super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, input_size=input_size if window_size == 0 else (window_size, window_size), ) self.norm2 = norm_layer(dim) self.mlp = MLPBlock( embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer ) self.window_size = window_size def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.norm1(x) # Window partition if self.window_size > 0: H, W = x.shape[1], x.shape[2] x, pad_hw = window_partition(x, self.window_size) x = self.attn(x) # Reverse window partition if self.window_size > 0: x = window_unpartition(x, self.window_size, pad_hw, (H, W)) x = shortcut + x x = x + self.mlp(self.norm2(x)) return x class Attention(nn.Module): """Multi-head Attention block with relative position embeddings.""" def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, input_size: Optional[Tuple[int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. qkv_bias (bool): If True, add a learnable bias to query, key, value. rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. """ super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.use_rel_pos = use_rel_pos if self.use_rel_pos: assert ( input_size is not None ), "Input size must be provided if using relative positional encoding." # initialize relative positional embeddings self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: B, H, W, _ = x.shape # qkv with shape (3, B, nHead, H * W, C) qkv = ( self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) ) # q, k, v with shape (B * nHead, H * W, C) q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) attn = (q * self.scale) @ k.transpose(-2, -1) if self.use_rel_pos: attn = add_decomposed_rel_pos( attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W) ) attn = attn.softmax(dim=-1) x = ( (attn @ v) .view(B, self.num_heads, H, W, -1) .permute(0, 2, 3, 1, 4) .reshape(B, H, W, -1) ) x = self.proj(x) return x def window_partition( x: torch.Tensor, window_size: int ) -> Tuple[torch.Tensor, Tuple[int, int]]: """ Partition into non-overlapping windows with padding if needed. Args: x (tensor): input tokens with [B, H, W, C]. window_size (int): window size. Returns: windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
(Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) windows = ( x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) ) return windows, (Hp, Wp) def window_unpartition( windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int], ) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. Args: windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. window_size (int): window size. pad_hw (Tuple): padded height and width (Hp, Wp). hw (Tuple): original height and width (H, W) before padding. Returns: x: unpartitioned sequences with [B, H, W, C]. """ Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) x = windows.view( B, Hp // window_size, Wp // window_size, window_size, window_size, -1 ) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) if Hp > H or Wp > W: x = x[:, :H, :W, :].contiguous() return x def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: """ Get relative positional embeddings according to the relative positions of query and key sizes. Args: q_size (int): size of query q. k_size (int): size of key k. rel_pos (Tensor): relative position embeddings (L, C). Returns: Extracted positional embeddings according to relative positions. """ max_rel_dist = int(2 * max(q_size, k_size) - 1) # Interpolate rel pos if needed. if rel_pos.shape[0] != max_rel_dist: # Interpolate rel pos. rel_pos_resized = F.interpolate( rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear", ) rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) else: rel_pos_resized = rel_pos # Scale the coords with short length if shapes for q and k are different. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) return rel_pos_resized[relative_coords.long()] def add_decomposed_rel_pos( attn: torch.Tensor, q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: Tuple[int, int], k_size: Tuple[int, int], ) -> torch.Tensor: """ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 Args: attn (Tensor): attention map. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. q_size (Tuple): spatial sequence size of query q with (q_h, q_w). k_size (Tuple): spatial sequence size of key k with (k_h, k_w). Returns: attn (Tensor): attention map with added relative positional embeddings. 
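    Note (added commentary): the bias is decomposed into separate height and
    width terms, rel_h[b, h, w, k_h] = <q[b, h, w], Rh[h, k_h]> and
    rel_w[b, h, w, k_w] = <q[b, h, w], Rw[w, k_w]>, which are broadcast-added
    to the logits rather than materializing a full 2D relative-position table.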
""" q_h, q_w = q_size k_h, k_w = k_size Rh = get_rel_pos(q_h, k_h, rel_pos_h) Rw = get_rel_pos(q_w, k_w, rel_pos_w) B, _, dim = q.shape r_q = q.reshape(B, q_h, q_w, dim) rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) attn = ( attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] ).view(B, q_h * q_w, k_h * k_w) return attn class PatchEmbed(nn.Module): """ Image to Patch Embedding. """ def __init__( self, kernel_size: Tuple[int, int] = (16, 16), stride: Tuple[int, int] = (16, 16), padding: Tuple[int, int] = (0, 0), in_chans: int = 3, embed_dim: int = 768, ) -> None: """ Args: kernel_size (Tuple): kernel size of the projection layer. stride (Tuple): stride of the projection layer. padding (Tuple): padding size of the projection layer. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. """ super().__init__() self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) # B C H W -> B H W C x = x.permute(0, 2, 3, 1) return x ================================================ FILE: model/segment_anything/modeling/mask_decoder.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import List, Tuple, Type import torch from torch import nn from torch.nn import functional as F from .common import LayerNorm2d class MaskDecoder(nn.Module): def __init__( self, *, transformer_dim: int, transformer: nn.Module, num_multimask_outputs: int = 3, activation: Type[nn.Module] = nn.GELU, iou_head_depth: int = 3, iou_head_hidden_dim: int = 256, ) -> None: """ Predicts masks given an image and prompt embeddings, using a transformer architecture. 
Arguments: transformer_dim (int): the channel dimension of the transformer transformer (nn.Module): the transformer used to predict masks num_multimask_outputs (int): the number of masks to predict when disambiguating masks activation (nn.Module): the type of activation to use when upscaling masks iou_head_depth (int): the depth of the MLP used to predict mask quality iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality """ super().__init__() self.transformer_dim = transformer_dim self.transformer = transformer self.num_multimask_outputs = num_multimask_outputs self.iou_token = nn.Embedding(1, transformer_dim) self.num_mask_tokens = num_multimask_outputs + 1 self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) self.output_upscaling = nn.Sequential( nn.ConvTranspose2d( transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 ), LayerNorm2d(transformer_dim // 4), activation(), nn.ConvTranspose2d( transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 ), activation(), ) self.output_hypernetworks_mlps = nn.ModuleList( [ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for i in range(self.num_mask_tokens) ] ) self.iou_prediction_head = MLP( transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth ) def forward( self, image_embeddings: torch.Tensor, image_pe: torch.Tensor, sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, multimask_output: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Predict masks given image and prompt embeddings. Arguments: image_embeddings (torch.Tensor): the embeddings from the image encoder image_pe (torch.Tensor): positional encoding with the shape of image_embeddings sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs multimask_output (bool): Whether to return multiple masks or a single mask. Returns: torch.Tensor: batched predicted masks torch.Tensor: batched predictions of mask quality """ masks, iou_pred = self.predict_masks( image_embeddings=image_embeddings, image_pe=image_pe, sparse_prompt_embeddings=sparse_prompt_embeddings, dense_prompt_embeddings=dense_prompt_embeddings, ) # Select the correct mask or masks for output if multimask_output: mask_slice = slice(1, None) else: mask_slice = slice(0, 1) masks = masks[:, mask_slice, :, :] iou_pred = iou_pred[:, mask_slice] # Prepare output return masks, iou_pred def predict_masks( self, image_embeddings: torch.Tensor, image_pe: torch.Tensor, sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Predicts masks. 
See 'forward' for more details.""" # Concatenate output tokens output_tokens = torch.cat( [self.iou_token.weight, self.mask_tokens.weight], dim=0 ) output_tokens = output_tokens.unsqueeze(0).expand( sparse_prompt_embeddings.size(0), -1, -1 ) tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # image_embeddings: [1, C, H, W], tokens: [B, N, C] # dense_prompt_embeddings: [B, C, H, W] # Expand per-image data in batch direction to be per-mask src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) src = src + dense_prompt_embeddings pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) b, c, h, w = src.shape # Run the transformer hs, src = self.transformer(src, pos_src, tokens) iou_token_out = hs[:, 0, :] mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] # Upscale mask embeddings and predict masks using the mask tokens src = src.transpose(1, 2).view(b, c, h, w) upscaled_embedding = self.output_upscaling(src) hyper_in_list: List[torch.Tensor] = [] for i in range(self.num_mask_tokens): hyper_in_list.append( self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) ) hyper_in = torch.stack(hyper_in_list, dim=1) b, c, h, w = upscaled_embedding.shape masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view( b, self.num_mask_tokens, h, w ) # Generate mask quality predictions iou_pred = self.iou_prediction_head(iou_token_out) return masks, iou_pred def forward_modified_v3( self, image_embeddings : torch.Tensor, # [b, 256, 64, 64] image_pe : torch.Tensor, # [1, 256, 64, 64] sparse_prompt_embeddings: torch.Tensor, # [b, 1, 256] dense_prompt_embeddings : torch.Tensor, # [b, 256, 64, 64] ): # -> [b, 256, 256] b, _, _, _ = image_embeddings.shape # Concatenate output tokens output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) # [num_mask + 1, 256] output_tokens = output_tokens.unsqueeze(0).expand(b, -1, -1) # [b, num_mask + 1, 256] tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # [b, num_mask + 1 + 1, 256] # image_embeddings: [1, C, H, W], tokens: [B, N, C] # dense_prompt_embeddings: [B, C, H, W] # Expand per-image data in batch direction to be per-mask src = image_embeddings src = src + dense_prompt_embeddings # [b, 256, 64, 64] pos_src = torch.repeat_interleave(image_pe, b, dim=0) # [b, 256, 64, 64] _, c, h, w = src.shape # Run the transformer hs, src = self.transformer(src, pos_src, tokens) mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] # Upscale mask embeddings and predict masks using the mask tokens src = src.transpose(1, 2).view(b, c, h, w) upscaled_embedding = self.output_upscaling(src) hyper_in_list: List[torch.Tensor] = [] num_mask = 1 for i in range(num_mask): # we only need the first mask hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) hyper_in = torch.stack(hyper_in_list, dim=1) b, c, h, w = upscaled_embedding.shape masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, h, w) # [b, 256, 256] return masks # Lightly adapted from # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa class MLP(nn.Module): def __init__( self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, ) -> None: super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList( nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) ) self.sigmoid_output = sigmoid_output def 
forward(self, x): for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) if self.sigmoid_output: x = F.sigmoid(x) return x ================================================ FILE: model/segment_anything/modeling/prompt_encoder.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import Any, Optional, Tuple, Type import numpy as np import torch from torch import nn from .common import LayerNorm2d class PromptEncoder(nn.Module): def __init__( self, embed_dim: int, image_embedding_size: Tuple[int, int], input_image_size: Tuple[int, int], mask_in_chans: int, activation: Type[nn.Module] = nn.GELU, ) -> None: """ Encodes prompts for input to SAM's mask decoder. Arguments: embed_dim (int): The prompts' embedding dimension image_embedding_size (tuple(int, int)): The spatial size of the image embedding, as (H, W). input_image_size (int): The padded size of the image as input to the image encoder, as (H, W). mask_in_chans (int): The number of hidden channels used for encoding input masks. activation (nn.Module): The activation to use when encoding input masks. """ super().__init__() self.embed_dim = embed_dim self.input_image_size = input_image_size self.image_embedding_size = image_embedding_size self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners point_embeddings = [ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) ] self.point_embeddings = nn.ModuleList(point_embeddings) self.not_a_point_embed = nn.Embedding(1, embed_dim) self.mask_input_size = ( 4 * image_embedding_size[0], 4 * image_embedding_size[1], ) self.mask_downscaling = nn.Sequential( nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), LayerNorm2d(mask_in_chans // 4), activation(), nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), LayerNorm2d(mask_in_chans), activation(), nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), ) self.no_mask_embed = nn.Embedding(1, embed_dim) def get_dense_pe(self) -> torch.Tensor: """ Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the image encoding. 
        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
        text_embeds: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        elif text_embeds is not None:
            return text_embeds.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
        text_embeds: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed
          text_embeds (torch.Tensor or none): pre-computed text embeddings to
            append to the sparse prompt embeddings.

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W) """ bs = self._get_batch_size(points, boxes, masks, text_embeds) sparse_embeddings = torch.empty( (bs, 0, self.embed_dim), device=self._get_device() ) if points is not None: coords, labels = points point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) if boxes is not None: box_embeddings = self._embed_boxes(boxes) sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) if text_embeds is not None: sparse_embeddings = torch.cat([sparse_embeddings, text_embeds], dim=1) if masks is not None: dense_embeddings = self._embed_masks(masks) else: dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] ) return sparse_embeddings, dense_embeddings class PositionEmbeddingRandom(nn.Module): """ Positional encoding using random spatial frequencies. """ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: super().__init__() if scale is None or scale <= 0.0: scale = 1.0 self.register_buffer( "positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)), ) def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: """Positionally encode points that are normalized to [0,1].""" # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape coords = 2 * coords - 1 if coords.dtype != self.positional_encoding_gaussian_matrix.dtype: coords = coords.to(self.positional_encoding_gaussian_matrix.dtype) coords = coords @ self.positional_encoding_gaussian_matrix coords = 2 * np.pi * coords # outputs d_1 x ... x d_n x C shape return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) def forward(self, size: Tuple[int, int]) -> torch.Tensor: """Generate positional encoding for a grid of the specified size.""" h, w = size device: Any = self.positional_encoding_gaussian_matrix.device grid = torch.ones( (h, w), device=device, dtype=self.positional_encoding_gaussian_matrix.dtype ) y_embed = grid.cumsum(dim=0) - 0.5 x_embed = grid.cumsum(dim=1) - 0.5 y_embed = y_embed / h x_embed = x_embed / w pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) return pe.permute(2, 0, 1) # C x H x W def forward_with_coords( self, coords_input: torch.Tensor, image_size: Tuple[int, int] ) -> torch.Tensor: """Positionally encode points that are not normalized to [0,1].""" coords = coords_input.clone() coords[:, :, 0] = coords[:, :, 0] / image_size[1] coords[:, :, 1] = coords[:, :, 1] / image_size[0] return self._pe_encoding(coords.to(torch.float)) # B x N x C ================================================ FILE: model/segment_anything/modeling/sam.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
from typing import Any, Dict, List, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask
            prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input
            prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the
            input image.
          pixel_std (list(float)): Std values for normalizing pixels in the
            input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer(
            "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
        )
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point
                prompts, with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the
                model, in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is a
            dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
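        Example (illustrative sketch only; `sam`, `img_3chw`, and the prompt
        tensors are assumed to exist and to already be transformed to the
        model's input frame):
            >>> out = sam(
            ...     [{"image": img_3chw, "original_size": (h, w),
            ...       "point_coords": pts_bn2, "point_labels": lbls_bn}],
            ...     multimask_output=True,
            ... )
            >>> out[0]["masks"].shape  # B x 3 x H x W boolean masks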
""" input_images = torch.stack( [self.preprocess(x["image"]) for x in batched_input], dim=0 ) image_embeddings = self.image_encoder(input_images) outputs = [] for image_record, curr_embedding in zip(batched_input, image_embeddings): if "point_coords" in image_record: points = (image_record["point_coords"], image_record["point_labels"]) else: points = None sparse_embeddings, dense_embeddings = self.prompt_encoder( points=points, boxes=image_record.get("boxes", None), masks=image_record.get("mask_inputs", None), ) low_res_masks, iou_predictions = self.mask_decoder( image_embeddings=curr_embedding.unsqueeze(0), image_pe=self.prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=multimask_output, ) masks = self.postprocess_masks( low_res_masks, input_size=image_record["image"].shape[-2:], original_size=image_record["original_size"], ) masks = masks > self.mask_threshold outputs.append( { "masks": masks, "iou_predictions": iou_predictions, "low_res_logits": low_res_masks, } ) return outputs def postprocess_masks( self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...], ) -> torch.Tensor: """ Remove padding and upscale masks to the original image size. Arguments: masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format. input_size (tuple(int, int)): The size of the image input to the model, in (H, W) format. Used to remove padding. original_size (tuple(int, int)): The original size of the image before resizing for input to the model, in (H, W) format. Returns: (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size. """ dtype = masks.dtype masks = F.interpolate( masks.float(), (self.image_encoder.img_size, self.image_encoder.img_size), mode="bilinear", align_corners=False, ) # masks = masks.to(dtype) masks = masks[..., : input_size[0], : input_size[1]] masks = F.interpolate( masks, original_size, mode="bilinear", align_corners=False ) return masks def preprocess(self, x: torch.Tensor) -> torch.Tensor: """Normalize pixel values and pad to a square input.""" # Normalize colors x = (x - self.pixel_mean) / self.pixel_std # Pad h, w = x.shape[-2:] padh = self.image_encoder.img_size - h padw = self.image_encoder.img_size - w x = F.pad(x, (0, padw, 0, padh)) return x ================================================ FILE: model/segment_anything/modeling/transformer.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import math from typing import Tuple, Type import torch from torch import Tensor, nn from .common import MLPBlock class TwoWayTransformer(nn.Module): def __init__( self, depth: int, embedding_dim: int, num_heads: int, mlp_dim: int, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, ) -> None: """ A transformer decoder that attends to an input image using queries whose positional embedding is supplied. Args: depth (int): number of layers in the transformer embedding_dim (int): the channel dimension for the input embeddings num_heads (int): the number of heads for multihead attention. 
Must divide embedding_dim mlp_dim (int): the channel dimension internal to the MLP block activation (nn.Module): the activation to use in the MLP block """ super().__init__() self.depth = depth self.embedding_dim = embedding_dim self.num_heads = num_heads self.mlp_dim = mlp_dim self.layers = nn.ModuleList() for i in range(depth): self.layers.append( TwoWayAttentionBlock( embedding_dim=embedding_dim, num_heads=num_heads, mlp_dim=mlp_dim, activation=activation, attention_downsample_rate=attention_downsample_rate, skip_first_layer_pe=(i == 0), ) ) self.final_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.norm_final_attn = nn.LayerNorm(embedding_dim) def forward( self, image_embedding: Tensor, image_pe: Tensor, point_embedding: Tensor, ) -> Tuple[Tensor, Tensor]: """ Args: image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w. image_pe (torch.Tensor): the positional encoding to add to the image. Must have the same shape as image_embedding. point_embedding (torch.Tensor): the embedding to add to the query points. Must have shape B x N_points x embedding_dim for any N_points. Returns: torch.Tensor: the processed point_embedding torch.Tensor: the processed image_embedding """ # BxCxHxW -> BxHWxC == B x N_image_tokens x C bs, c, h, w = image_embedding.shape image_embedding = image_embedding.flatten(2).permute(0, 2, 1) image_pe = image_pe.flatten(2).permute(0, 2, 1) # Prepare queries queries = point_embedding keys = image_embedding # Apply transformer blocks and final layernorm for layer in self.layers: queries, keys = layer( queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe, ) # Apply the final attention layer from the points to the image q = queries + point_embedding k = keys + image_pe attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm_final_attn(queries) return queries, keys class TwoWayAttentionBlock(nn.Module): def __init__( self, embedding_dim: int, num_heads: int, mlp_dim: int = 2048, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, ) -> None: """ A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse inputs. 
Arguments: embedding_dim (int): the channel dimension of the embeddings num_heads (int): the number of heads in the attention layers mlp_dim (int): the hidden dimension of the mlp block activation (nn.Module): the activation of the mlp block skip_first_layer_pe (bool): skip the PE on the first layer """ super().__init__() self.self_attn = Attention(embedding_dim, num_heads) self.norm1 = nn.LayerNorm(embedding_dim) self.cross_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.norm2 = nn.LayerNorm(embedding_dim) self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) self.norm3 = nn.LayerNorm(embedding_dim) self.norm4 = nn.LayerNorm(embedding_dim) self.cross_attn_image_to_token = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.skip_first_layer_pe = skip_first_layer_pe def forward( self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor ) -> Tuple[Tensor, Tensor]: # Self attention block if self.skip_first_layer_pe: queries = self.self_attn(q=queries, k=queries, v=queries) else: q = queries + query_pe attn_out = self.self_attn(q=q, k=q, v=queries) queries = queries + attn_out queries = self.norm1(queries) # Cross attention block, tokens attending to image embedding q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm2(queries) # MLP block mlp_out = self.mlp(queries) queries = queries + mlp_out queries = self.norm3(queries) # Cross attention block, image embedding attending to tokens q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) keys = keys + attn_out keys = self.norm4(keys) return queries, keys class Attention(nn.Module): """ An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values. """ def __init__( self, embedding_dim: int, num_heads: int, downsample_rate: int = 1, ) -> None: super().__init__() self.embedding_dim = embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads assert ( self.internal_dim % num_heads == 0 ), "num_heads must divide embedding_dim." 
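        # Note (added comment): with downsample_rate > 1 the q/k/v projections
        # map embedding_dim down to internal_dim, shrinking the per-head width
        # and attention cost; out_proj restores embedding_dim at the output.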
self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(embedding_dim, self.internal_dim) self.v_proj = nn.Linear(embedding_dim, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, embedding_dim) def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head def _recombine_heads(self, x: Tensor) -> Tensor: b, n_heads, n_tokens, c_per_head = x.shape x = x.transpose(1, 2) return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # Input projections q = self.q_proj(q) k = self.k_proj(k) v = self.v_proj(v) # Separate into heads q = self._separate_heads(q, self.num_heads) k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) # Attention _, _, _, c_per_head = q.shape attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens attn = attn / math.sqrt(c_per_head) attn = torch.softmax(attn, dim=-1) # Get output out = attn @ v out = self._recombine_heads(out) out = self.out_proj(out) return out ================================================ FILE: model/segment_anything/predictor.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import Optional, Tuple import numpy as np import torch from .modeling import Sam from .utils.transforms import ResizeLongestSide class SamPredictor: def __init__( self, sam_model: Sam, ) -> None: """ Uses SAM to calculate the image embedding for an image, and then allow repeated, efficient mask prediction given prompts. Arguments: sam_model (Sam): The model to use for mask prediction. """ super().__init__() self.model = sam_model self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) self.reset_image() def set_image( self, image: np.ndarray, image_format: str = "RGB", ) -> None: """ Calculates the image embeddings for the provided image, allowing masks to be predicted with the 'predict' method. Arguments: image (np.ndarray): The image for calculating masks. Expects an image in HWC uint8 format, with pixel values in [0, 255]. image_format (str): The color format of the image, in ['RGB', 'BGR']. """ assert image_format in [ "RGB", "BGR", ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." if image_format != self.model.image_format: image = image[..., ::-1] # Transform the image to the form expected by the model input_image = self.transform.apply_image(image) input_image_torch = torch.as_tensor(input_image, device=self.device) input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[ None, :, :, : ] self.set_torch_image(input_image_torch, image.shape[:2]) @torch.no_grad() def set_torch_image( self, transformed_image: torch.Tensor, original_image_size: Tuple[int, ...], ) -> None: """ Calculates the image embeddings for the provided image, allowing masks to be predicted with the 'predict' method. Expects the input image to be already transformed to the format expected by the model. Arguments: transformed_image (torch.Tensor): The input image, with shape 1x3xHxW, which has been transformed with ResizeLongestSide. original_image_size (tuple(int, int)): The size of the image before transformation, in (H, W) format. 
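        Example (sketch; `predictor` is a SamPredictor and `image` an HWC
        uint8 array, both assumed):
            >>> t = predictor.transform.apply_image(image)
            >>> t = torch.as_tensor(t, device=predictor.device)
            >>> t = t.permute(2, 0, 1).contiguous()[None, :, :, :]
            >>> predictor.set_torch_image(t, image.shape[:2])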
""" assert ( len(transformed_image.shape) == 4 and transformed_image.shape[1] == 3 and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." self.reset_image() self.original_size = original_image_size self.input_size = tuple(transformed_image.shape[-2:]) input_image = self.model.preprocess(transformed_image) self.features = self.model.image_encoder(input_image) self.is_image_set = True def predict( self, point_coords: Optional[np.ndarray] = None, point_labels: Optional[np.ndarray] = None, box: Optional[np.ndarray] = None, mask_input: Optional[np.ndarray] = None, multimask_output: bool = True, return_logits: bool = False, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Predict masks for the given input prompts, using the currently set image. Arguments: point_coords (np.ndarray or None): A Nx2 array of point prompts to the model. Each point is in (X,Y) in pixels. point_labels (np.ndarray or None): A length N array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point. box (np.ndarray or None): A length 4 array given a box prompt to the model, in XYXY format. mask_input (np.ndarray): A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256. multimask_output (bool): If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model's predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results. return_logits (bool): If true, returns un-thresholded masks logits instead of a binary mask. Returns: (np.ndarray): The output masks in CxHxW format, where C is the number of masks, and (H, W) is the original image size. (np.ndarray): An array of length C containing the model's predictions for the quality of each mask. (np.ndarray): An array of shape CxHxW, where C is the number of masks and H=W=256. These low resolution logits can be passed to a subsequent iteration as mask input. """ if not self.is_image_set: raise RuntimeError( "An image must be set with .set_image(...) before mask prediction." ) # Transform input prompts coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None if point_coords is not None: assert ( point_labels is not None ), "point_labels must be supplied if point_coords is supplied." 
point_coords = self.transform.apply_coords(point_coords, self.original_size) coords_torch = torch.as_tensor( point_coords, dtype=torch.float, device=self.device ) labels_torch = torch.as_tensor( point_labels, dtype=torch.int, device=self.device ) coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] if box is not None: box = self.transform.apply_boxes(box, self.original_size) box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) box_torch = box_torch[None, :] if mask_input is not None: mask_input_torch = torch.as_tensor( mask_input, dtype=torch.float, device=self.device ) mask_input_torch = mask_input_torch[None, :, :, :] masks, iou_predictions, low_res_masks = self.predict_torch( coords_torch, labels_torch, box_torch, mask_input_torch, multimask_output, return_logits=return_logits, ) masks_np = masks[0].detach().cpu().numpy() iou_predictions_np = iou_predictions[0].detach().cpu().numpy() low_res_masks_np = low_res_masks[0].detach().cpu().numpy() return masks_np, iou_predictions_np, low_res_masks_np @torch.no_grad() def predict_torch( self, point_coords: Optional[torch.Tensor], point_labels: Optional[torch.Tensor], boxes: Optional[torch.Tensor] = None, mask_input: Optional[torch.Tensor] = None, multimask_output: bool = True, return_logits: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Predict masks for the given input prompts, using the currently set image. Input prompts are batched torch tensors and are expected to already be transformed to the input frame using ResizeLongestSide. Arguments: point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the model. Each point is in (X,Y) in pixels. point_labels (torch.Tensor or None): A BxN array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point. boxes (np.ndarray or None): A Bx4 array given a box prompt to the model, in XYXY format. mask_input (np.ndarray): A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form Bx1xHxW, where for SAM, H=W=256. Masks returned by a previous iteration of the predict method do not need further transformation. multimask_output (bool): If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model's predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results. return_logits (bool): If true, returns un-thresholded masks logits instead of a binary mask. Returns: (torch.Tensor): The output masks in BxCxHxW format, where C is the number of masks, and (H, W) is the original image size. (torch.Tensor): An array of shape BxC containing the model's predictions for the quality of each mask. (torch.Tensor): An array of shape BxCxHxW, where C is the number of masks and H=W=256. These low res logits can be passed to a subsequent iteration as mask input. """ if not self.is_image_set: raise RuntimeError( "An image must be set with .set_image(...) before mask prediction." 
) if point_coords is not None: points = (point_coords, point_labels) else: points = None # Embed prompts sparse_embeddings, dense_embeddings = self.model.prompt_encoder( points=points, boxes=boxes, masks=mask_input, ) # Predict masks low_res_masks, iou_predictions = self.model.mask_decoder( image_embeddings=self.features, image_pe=self.model.prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=multimask_output, ) # Upscale the masks to the original image resolution masks = self.model.postprocess_masks( low_res_masks, self.input_size, self.original_size ) if not return_logits: masks = masks > self.model.mask_threshold return masks, iou_predictions, low_res_masks def get_image_embedding(self) -> torch.Tensor: """ Returns the image embeddings for the currently set image, with shape 1xCxHxW, where C is the embedding dimension and (H,W) are the embedding spatial dimension of SAM (typically C=256, H=W=64). """ if not self.is_image_set: raise RuntimeError( "An image must be set with .set_image(...) to generate an embedding." ) assert ( self.features is not None ), "Features must exist if an image has been set." return self.features @property def device(self) -> torch.device: return self.model.device def reset_image(self) -> None: """Resets the currently set image.""" self.is_image_set = False self.features = None self.orig_h = None self.orig_w = None self.input_h = None self.input_w = None ================================================ FILE: model/segment_anything/utils/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: model/segment_anything/utils/amg.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import math from copy import deepcopy from itertools import product from typing import Any, Dict, Generator, ItemsView, List, Tuple import numpy as np import torch class MaskData: """ A structure for storing masks and their related data in batched format. Implements basic filtering and concatenation. """ def __init__(self, **kwargs) -> None: for v in kwargs.values(): assert isinstance( v, (list, np.ndarray, torch.Tensor) ), "MaskData only supports list, numpy arrays, and torch tensors." self._stats = dict(**kwargs) def __setitem__(self, key: str, item: Any) -> None: assert isinstance( item, (list, np.ndarray, torch.Tensor) ), "MaskData only supports list, numpy arrays, and torch tensors." 
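# Illustrative usage: data = MaskData(iou_preds=torch.tensor([0.9, 0.4]))
# followed by data.filter(data["iou_preds"] > 0.5) keeps only the first
# entry; __setitem__ enforces the same type constraint as the constructor.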
self._stats[key] = item def __delitem__(self, key: str) -> None: del self._stats[key] def __getitem__(self, key: str) -> Any: return self._stats[key] def items(self) -> ItemsView[str, Any]: return self._stats.items() def filter(self, keep: torch.Tensor) -> None: for k, v in self._stats.items(): if v is None: self._stats[k] = None elif isinstance(v, torch.Tensor): self._stats[k] = v[torch.as_tensor(keep, device=v.device)] elif isinstance(v, np.ndarray): self._stats[k] = v[keep.detach().cpu().numpy()] elif isinstance(v, list) and keep.dtype == torch.bool: self._stats[k] = [a for i, a in enumerate(v) if keep[i]] elif isinstance(v, list): self._stats[k] = [v[i] for i in keep] else: raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") def cat(self, new_stats: "MaskData") -> None: for k, v in new_stats.items(): if k not in self._stats or self._stats[k] is None: self._stats[k] = deepcopy(v) elif isinstance(v, torch.Tensor): self._stats[k] = torch.cat([self._stats[k], v], dim=0) elif isinstance(v, np.ndarray): self._stats[k] = np.concatenate([self._stats[k], v], axis=0) elif isinstance(v, list): self._stats[k] = self._stats[k] + deepcopy(v) else: raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") def to_numpy(self) -> None: for k, v in self._stats.items(): if isinstance(v, torch.Tensor): self._stats[k] = v.detach().cpu().numpy() def is_box_near_crop_edge( boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 ) -> torch.Tensor: """Filter masks at the edge of a crop, but not at the edge of the original image.""" crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) boxes = uncrop_boxes_xyxy(boxes, crop_box).float() near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) return torch.any(near_crop_edge, dim=1) def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: box_xywh = deepcopy(box_xyxy) box_xywh[2] = box_xywh[2] - box_xywh[0] box_xywh[3] = box_xywh[3] - box_xywh[1] return box_xywh def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: assert len(args) > 0 and all( len(a) == len(args[0]) for a in args ), "Batched iteration must have inputs of all the same size." n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) for b in range(n_batches): yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: """ Encodes masks to an uncompressed RLE, in the format expected by pycoco tools. 
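Example (illustrative): a 2x2 mask [[1, 0], [0, 1]] flattens in Fortran order to [1, 0, 0, 1] and encodes as {"size": [2, 2], "counts": [0, 1, 2, 1]}, i.e. runs of 0 zeros, 1 one, 2 zeros, 1 one.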
""" # Put in fortran order and flatten h,w b, h, w = tensor.shape tensor = tensor.permute(0, 2, 1).flatten(1) # Compute change indices diff = tensor[:, 1:] ^ tensor[:, :-1] change_indices = diff.nonzero() # Encode run length out = [] for i in range(b): cur_idxs = change_indices[change_indices[:, 0] == i, 1] cur_idxs = torch.cat( [ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), cur_idxs + 1, torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), ] ) btw_idxs = cur_idxs[1:] - cur_idxs[:-1] counts = [] if tensor[i, 0] == 0 else [0] counts.extend(btw_idxs.detach().cpu().tolist()) out.append({"size": [h, w], "counts": counts}) return out def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: """Compute a binary mask from an uncompressed RLE.""" h, w = rle["size"] mask = np.empty(h * w, dtype=bool) idx = 0 parity = False for count in rle["counts"]: mask[idx : idx + count] = parity idx += count parity ^= True mask = mask.reshape(w, h) return mask.transpose() # Put in C order def area_from_rle(rle: Dict[str, Any]) -> int: return sum(rle["counts"][1::2]) def calculate_stability_score( masks: torch.Tensor, mask_threshold: float, threshold_offset: float ) -> torch.Tensor: """ Computes the stability score for a batch of masks. The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high and low values. """ # One mask is always contained inside the other. # Save memory by preventing unnecessary cast to torch.int64 intersections = ( (masks > (mask_threshold + threshold_offset)) .sum(-1, dtype=torch.int16) .sum(-1, dtype=torch.int32) ) unions = ( (masks > (mask_threshold - threshold_offset)) .sum(-1, dtype=torch.int16) .sum(-1, dtype=torch.int32) ) return intersections / unions def build_point_grid(n_per_side: int) -> np.ndarray: """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" offset = 1 / (2 * n_per_side) points_one_side = np.linspace(offset, 1 - offset, n_per_side) points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) points_y = np.tile(points_one_side[:, None], (1, n_per_side)) points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) return points def build_all_layer_point_grids( n_per_side: int, n_layers: int, scale_per_layer: int ) -> List[np.ndarray]: """Generates point grids for all crop layers.""" points_by_layer = [] for i in range(n_layers + 1): n_points = int(n_per_side / (scale_per_layer**i)) points_by_layer.append(build_point_grid(n_points)) return points_by_layer def generate_crop_boxes( im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float ) -> Tuple[List[List[int]], List[int]]: """ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. 
""" crop_boxes, layer_idxs = [], [] im_h, im_w = im_size short_side = min(im_h, im_w) # Original image crop_boxes.append([0, 0, im_w, im_h]) layer_idxs.append(0) def crop_len(orig_len, n_crops, overlap): return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) for i_layer in range(n_layers): n_crops_per_side = 2 ** (i_layer + 1) overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) crop_w = crop_len(im_w, n_crops_per_side, overlap) crop_h = crop_len(im_h, n_crops_per_side, overlap) crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] # Crops in XYWH format for x0, y0 in product(crop_box_x0, crop_box_y0): box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] crop_boxes.append(box) layer_idxs.append(i_layer + 1) return crop_boxes, layer_idxs def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: x0, y0, _, _ = crop_box offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) # Check if boxes has a channel dimension if len(boxes.shape) == 3: offset = offset.unsqueeze(1) return boxes + offset def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: x0, y0, _, _ = crop_box offset = torch.tensor([[x0, y0]], device=points.device) # Check if points has a channel dimension if len(points.shape) == 3: offset = offset.unsqueeze(1) return points + offset def uncrop_masks( masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int ) -> torch.Tensor: x0, y0, x1, y1 = crop_box if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: return masks # Coordinate transform masks pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) pad = (x0, pad_x - x0, y0, pad_y - y0) return torch.nn.functional.pad(masks, pad, value=0) def remove_small_regions( mask: np.ndarray, area_thresh: float, mode: str ) -> Tuple[np.ndarray, bool]: """ Removes small disconnected regions and holes in a mask. Returns the mask and an indicator of if the mask has been modified. """ import cv2 # type: ignore assert mode in ["holes", "islands"] correct_holes = mode == "holes" working_mask = (correct_holes ^ mask).astype(np.uint8) n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) sizes = stats[:, -1][1:] # Row 0 is background label small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] if len(small_regions) == 0: return mask, False fill_labels = [0] + small_regions if not correct_holes: fill_labels = [i for i in range(n_labels) if i not in fill_labels] # If every region is below threshold, keep largest if len(fill_labels) == 0: fill_labels = [int(np.argmax(sizes)) + 1] mask = np.isin(regions, fill_labels) return mask, True def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: from pycocotools import mask as mask_utils # type: ignore h, w = uncompressed_rle["size"] rle = mask_utils.frPyObjects(uncompressed_rle, h, w) rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json return rle def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: """ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 
""" # torch.max below raises an error on empty inputs, just skip in this case if torch.numel(masks) == 0: return torch.zeros(*masks.shape[:-2], 4, device=masks.device) # Normalize shape to CxHxW shape = masks.shape h, w = shape[-2:] if len(shape) > 2: masks = masks.flatten(0, -3) else: masks = masks.unsqueeze(0) # Get top and bottom edges in_height, _ = torch.max(masks, dim=-1) in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] bottom_edges, _ = torch.max(in_height_coords, dim=-1) in_height_coords = in_height_coords + h * (~in_height) top_edges, _ = torch.min(in_height_coords, dim=-1) # Get left and right edges in_width, _ = torch.max(masks, dim=-2) in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] right_edges, _ = torch.max(in_width_coords, dim=-1) in_width_coords = in_width_coords + w * (~in_width) left_edges, _ = torch.min(in_width_coords, dim=-1) # If the mask is empty the right edge will be to the left of the left edge. # Replace these boxes with [0, 0, 0, 0] empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) out = out * (~empty_filter).unsqueeze(-1) # Return to original shape if len(shape) > 2: out = out.reshape(*shape[:-2], 4) else: out = out[0] return out ================================================ FILE: model/segment_anything/utils/onnx.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import Tuple import torch import torch.nn as nn from torch.nn import functional as F from ..modeling import Sam from .amg import calculate_stability_score class SamOnnxModel(nn.Module): """ This model should not be called directly, but is used in ONNX export. It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, with some functions modified to enable model tracing. Also supports extra options controlling what information. See the ONNX export script for details. 
""" def __init__( self, model: Sam, return_single_mask: bool, use_stability_score: bool = False, return_extra_metrics: bool = False, ) -> None: super().__init__() self.mask_decoder = model.mask_decoder self.model = model self.img_size = model.image_encoder.img_size self.return_single_mask = return_single_mask self.use_stability_score = use_stability_score self.stability_score_offset = 1.0 self.return_extra_metrics = return_extra_metrics @staticmethod def resize_longest_image_size( input_image_size: torch.Tensor, longest_side: int ) -> torch.Tensor: input_image_size = input_image_size.to(torch.float32) scale = longest_side / torch.max(input_image_size) transformed_size = scale * input_image_size transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) return transformed_size def _embed_points( self, point_coords: torch.Tensor, point_labels: torch.Tensor ) -> torch.Tensor: point_coords = point_coords + 0.5 point_coords = point_coords / self.img_size point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) point_embedding = point_embedding * (point_labels != -1) point_embedding = ( point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1) ) for i in range(self.model.prompt_encoder.num_point_embeddings): point_embedding = ( point_embedding + self.model.prompt_encoder.point_embeddings[i].weight * (point_labels == i) ) return point_embedding def _embed_masks( self, input_mask: torch.Tensor, has_mask_input: torch.Tensor ) -> torch.Tensor: mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling( input_mask ) mask_embedding = mask_embedding + ( 1 - has_mask_input ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) return mask_embedding def mask_postprocessing( self, masks: torch.Tensor, orig_im_size: torch.Tensor ) -> torch.Tensor: masks = F.interpolate( masks, size=(self.img_size, self.img_size), mode="bilinear", align_corners=False, ) prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to( torch.int64 ) masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore orig_im_size = orig_im_size.to(torch.int64) h, w = orig_im_size[0], orig_im_size[1] masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) return masks def select_masks( self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int ) -> Tuple[torch.Tensor, torch.Tensor]: # Determine if we should return the multiclick mask or not from the number of points. # The reweighting is used to avoid control flow. 
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks


================================================
FILE: model/segment_anything/utils/transforms.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
from typing import Tuple

import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize  # type: ignore
from torchvision.transforms.functional import to_pil_image


class ResizeLongestSide:
    """
    Resizes images to the longest side 'target_length', as well as provides
    methods for resizing coordinates and boxes. Provides methods for
    transforming both numpy array and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        target_size = self.get_preprocess_shape(
            image.shape[0], image.shape[1], self.target_length
        )
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(
        self, coords: np.ndarray, original_size: Tuple[int, ...]
    ) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(
        self, boxes: np.ndarray, original_size: Tuple[int, ...]
    ) -> np.ndarray:
        """
        Expects a numpy array shape Bx4. Requires the original image size
        in (H, W) format.
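        Example (illustrative): with original_size=(600, 800) and
        target_length=1024, every coordinate scales by 1024/800 = 1.28, so a
        box [100, 50, 200, 150] maps to [128.0, 64.0, 256.0, 192.0].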
""" boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) return boxes.reshape(-1, 4) def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: """ Expects batched images with shape BxCxHxW and float format. This transformation may not exactly match apply_image. apply_image is the transformation expected by the model. """ # Expects an image in BCHW format. May not exactly match apply_image. target_size = self.get_preprocess_shape( image.shape[0], image.shape[1], self.target_length ) return F.interpolate( image, target_size, mode="bilinear", align_corners=False, antialias=True ) def apply_coords_torch( self, coords: torch.Tensor, original_size: Tuple[int, ...] ) -> torch.Tensor: """ Expects a torch tensor with length 2 in the last dimension. Requires the original image size in (H, W) format. """ old_h, old_w = original_size new_h, new_w = self.get_preprocess_shape( original_size[0], original_size[1], self.target_length ) coords = deepcopy(coords).to(torch.float) coords[..., 0] = coords[..., 0] * (new_w / old_w) coords[..., 1] = coords[..., 1] * (new_h / old_h) return coords def apply_boxes_torch( self, boxes: torch.Tensor, original_size: Tuple[int, ...] ) -> torch.Tensor: """ Expects a torch tensor with shape Bx4. Requires the original image size in (H, W) format. """ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) return boxes.reshape(-1, 4) @staticmethod def get_preprocess_shape( oldh: int, oldw: int, long_side_length: int ) -> Tuple[int, int]: """ Compute the output size given input size and target long side length. """ scale = long_side_length * 1.0 / max(oldh, oldw) newh, neww = oldh * scale, oldw * scale neww = int(neww + 0.5) newh = int(newh + 0.5) return (newh, neww) ================================================ FILE: model/tf/modeling_outputs.py ================================================ import torch import warnings from dataclasses import dataclass from typing import Optional, Tuple, Dict, List from transformers.utils import ModelOutput @dataclass class CausalLMOutputWithPastAndLabel(ModelOutput): """ Base class for causal language model (or autoregressive) outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). labels (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, returned when `labels` is provided): logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted
            average in the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    labels: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    bs2imgs_token_list: Optional[List[List[int]]] = None


================================================
FILE: model/univi/__init__.py
================================================
from .model import ChatUniViLlamaForCausalLM


================================================
FILE: model/univi/config/__init__.py
================================================
from .dataset_config import *
from .model_config import *

ModelConfig = {
    "PRETUNE": model_config_pretune,
    "FINETUNE": model_config_finetune,
}

DataConfig = {
    "Pretrain": [Pretrain, COCO_CAP, COCO_REG, COCO_REC],
    "SQA": [SQA],
    "FINETUNE": [VIT, MIMIC_imageonly, VIDEO],
}


================================================
FILE: model/univi/config/dataset_config.py
================================================
Pretrain = {
    "chat_path": "${PATH}/CC3M-595K/chat.json",
    "CC3M": "${PATH}/CC3M-595K",
}

VIT = {
    "chat_path": "${PATH}/llava_instruct_150k.json",
    "COCO2017": "${PATH}/COCO2017/train2017",
}

MIMIC_imageonly = {
    "chat_path": "${PATH}/MIMIC-IT-imageonly.json",
    "CDG": "${PATH}/CGD/images",
    "LA": "${PATH}/LA/images",
    "SD": "${PATH}/SD/images",
}

COCO_CAP = {
    "chat_path": "${PATH}/COCO/coco_cap_chat.json",
    "COCO2014": "${PATH}/COCO2014/train2014",
}

COCO_REG = {
    "chat_path": "${PATH}/COCO/coco_reg_chat.json",
    "COCO2014": "${PATH}/COCO2014/train2014",
}

COCO_REC = {
    "chat_path": "${PATH}/COCO/coco_rec_chat.json",
    "COCO2014": "${PATH}/COCO2014/train2014",
}

VIDEO = {
    "chat_path": "${PATH}/video_chat.json",
    "VIDEO": "${PATH}/Activity_Videos",
}

SQA = {
    "chat_path": "${PATH}/llava_train_QCM-LEA.json",
    "ScienceQA": "${PATH}/scienceqa/train",
}


================================================
FILE: model/univi/config/model_config.py
================================================
model_config_pretune = {
    "use_cluster": True,
    "freeze": False,
    "vision_tune": False,
    "spatial_cluster_rate0": 64,  # 0.25
    "spatial_cluster_rate1": 32,  # 0.5
    "spatial_cluster_rate2": 16,  # 0.5
    "temporal_cluster_rate": 1 / 16,
}

model_config_finetune = {
    "use_cluster": True,
    "freeze": False,
    "mm_tune": True,
    "vision_tune": False,
    "spatial_cluster_rate0": 64,  # 0.25
    "spatial_cluster_rate1": 32,  # 0.5
    "spatial_cluster_rate2": 16,  # 0.5
    "temporal_cluster_rate": 1 / 16,
}


================================================
FILE: model/univi/constants.py
================================================
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
MAX_IMAGE_LENGTH = 64
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_VIDEO_TOKEN = "