Repository: cilinyan/VISA
Branch: main
Commit: ec158fde9cdf
Files: 276
Total size: 54.3 MB
Directory structure:
gitextract_qarvn3ca/
├── .gitignore
├── .gitmodules
├── README.md
├── XMem/
│ ├── dataset/
│ │ ├── __init__.py
│ │ ├── range_transform.py
│ │ ├── reseed.py
│ │ ├── static_dataset.py
│ │ ├── tps.py
│ │ ├── util.py
│ │ └── vos_dataset.py
│ ├── eval.py
│ ├── eval_batch.py
│ ├── generate_xmem_data_single.py
│ ├── inference/
│ │ ├── __init__.py
│ │ ├── data/
│ │ │ ├── __init__.py
│ │ │ ├── mask_mapper.py
│ │ │ ├── test_datasets.py
│ │ │ └── video_reader.py
│ │ ├── inference_core.py
│ │ ├── interact/
│ │ │ ├── __init__.py
│ │ │ ├── fbrs/
│ │ │ │ ├── LICENSE
│ │ │ │ ├── __init__.py
│ │ │ │ ├── controller.py
│ │ │ │ ├── inference/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── clicker.py
│ │ │ │ │ ├── evaluation.py
│ │ │ │ │ ├── predictors/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── base.py
│ │ │ │ │ │ ├── brs.py
│ │ │ │ │ │ ├── brs_functors.py
│ │ │ │ │ │ └── brs_losses.py
│ │ │ │ │ ├── transforms/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── base.py
│ │ │ │ │ │ ├── crops.py
│ │ │ │ │ │ ├── flip.py
│ │ │ │ │ │ ├── limit_longest_side.py
│ │ │ │ │ │ └── zoom_in.py
│ │ │ │ │ └── utils.py
│ │ │ │ ├── model/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── initializer.py
│ │ │ │ │ ├── is_deeplab_model.py
│ │ │ │ │ ├── is_hrnet_model.py
│ │ │ │ │ ├── losses.py
│ │ │ │ │ ├── metrics.py
│ │ │ │ │ ├── modeling/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── basic_blocks.py
│ │ │ │ │ │ ├── deeplab_v3.py
│ │ │ │ │ │ ├── hrnet_ocr.py
│ │ │ │ │ │ ├── ocr.py
│ │ │ │ │ │ ├── resnet.py
│ │ │ │ │ │ └── resnetv1b.py
│ │ │ │ │ ├── ops.py
│ │ │ │ │ └── syncbn/
│ │ │ │ │ ├── LICENSE
│ │ │ │ │ ├── README.md
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── modules/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── functional/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── _csrc.py
│ │ │ │ │ │ ├── csrc/
│ │ │ │ │ │ │ ├── bn.h
│ │ │ │ │ │ │ ├── cuda/
│ │ │ │ │ │ │ │ ├── bn_cuda.cu
│ │ │ │ │ │ │ │ ├── common.h
│ │ │ │ │ │ │ │ └── ext_lib.h
│ │ │ │ │ │ │ └── ext_lib.cpp
│ │ │ │ │ │ └── syncbn.py
│ │ │ │ │ └── nn/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── syncbn.py
│ │ │ │ └── utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── cython/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── _get_dist_maps.pyx
│ │ │ │ │ ├── _get_dist_maps.pyxbld
│ │ │ │ │ └── dist_maps.py
│ │ │ │ ├── misc.py
│ │ │ │ └── vis.py
│ │ │ ├── fbrs_controller.py
│ │ │ ├── gui.py
│ │ │ ├── gui_utils.py
│ │ │ ├── interaction.py
│ │ │ ├── interactive_utils.py
│ │ │ ├── resource_manager.py
│ │ │ ├── s2m/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── _deeplab.py
│ │ │ │ ├── s2m_network.py
│ │ │ │ ├── s2m_resnet.py
│ │ │ │ └── utils.py
│ │ │ ├── s2m_controller.py
│ │ │ └── timer.py
│ │ ├── kv_memory_store.py
│ │ └── memory_manager.py
│ ├── interactive_demo.py
│ ├── merge_multi_scale.py
│ ├── merge_results.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── aggregate.py
│ │ ├── cbam.py
│ │ ├── group_modules.py
│ │ ├── losses.py
│ │ ├── memory_util.py
│ │ ├── modules.py
│ │ ├── network.py
│ │ ├── resnet.py
│ │ └── trainer.py
│ ├── requirements.txt
│ ├── scripts/
│ │ ├── __init__.py
│ │ ├── download_bl30k.py
│ │ ├── download_datasets.py
│ │ ├── download_models.sh
│ │ ├── download_models_demo.sh
│ │ ├── expand_long_vid.py
│ │ └── resize_youtube.py
│ ├── tracking.py
│ ├── train.py
│ └── util/
│ ├── __init__.py
│ ├── configuration.py
│ ├── davis_subset.txt
│ ├── image_saver.py
│ ├── load_subset.py
│ ├── log_integrator.py
│ ├── logger.py
│ ├── palette.py
│ ├── tensor_util.py
│ └── yv_subset.txt
├── merge_lora_weights_and_save_hf_model.py
├── model/
│ ├── VISA.py
│ ├── llava/
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── conversation.py
│ │ ├── mm_utils.py
│ │ ├── model/
│ │ │ ├── __init__.py
│ │ │ ├── apply_delta.py
│ │ │ ├── builder.py
│ │ │ ├── consolidate.py
│ │ │ ├── language_model/
│ │ │ │ ├── llava_llama.py
│ │ │ │ ├── llava_mpt.py
│ │ │ │ └── mpt/
│ │ │ │ ├── adapt_tokenizer.py
│ │ │ │ ├── attention.py
│ │ │ │ ├── blocks.py
│ │ │ │ ├── configuration_mpt.py
│ │ │ │ ├── custom_embedding.py
│ │ │ │ ├── flash_attn_triton.py
│ │ │ │ ├── hf_prefixlm_converter.py
│ │ │ │ ├── meta_init_context.py
│ │ │ │ ├── modeling_mpt.py
│ │ │ │ ├── norm.py
│ │ │ │ └── param_init_fns.py
│ │ │ ├── llava_arch.py
│ │ │ ├── make_delta.py
│ │ │ ├── multimodal_encoder/
│ │ │ │ ├── builder.py
│ │ │ │ └── clip_encoder.py
│ │ │ └── utils.py
│ │ ├── train/
│ │ │ ├── llama_flash_attn_monkey_patch.py
│ │ │ ├── llava_trainer.py
│ │ │ ├── train.py
│ │ │ └── train_mem.py
│ │ └── utils.py
│ ├── segment_anything/
│ │ ├── __init__.py
│ │ ├── automatic_mask_generator.py
│ │ ├── build_sam.py
│ │ ├── modeling/
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── image_encoder.py
│ │ │ ├── mask_decoder.py
│ │ │ ├── prompt_encoder.py
│ │ │ ├── sam.py
│ │ │ └── transformer.py
│ │ ├── predictor.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── amg.py
│ │ ├── onnx.py
│ │ └── transforms.py
│ ├── tf/
│ │ └── modeling_outputs.py
│ └── univi/
│ ├── __init__.py
│ ├── config/
│ │ ├── __init__.py
│ │ ├── dataset_config.py
│ │ └── model_config.py
│ ├── constants.py
│ ├── conversation.py
│ ├── demo.py
│ ├── eval/
│ │ ├── evaluate/
│ │ │ ├── evaluate_benchmark_1_correctness.py
│ │ │ ├── evaluate_benchmark_2_detailed_orientation.py
│ │ │ ├── evaluate_benchmark_3_context.py
│ │ │ ├── evaluate_benchmark_4_temporal.py
│ │ │ ├── evaluate_benchmark_5_consistency.py
│ │ │ ├── evaluate_gpt_review_visual.py
│ │ │ ├── evaluate_science_qa.py
│ │ │ ├── evaluate_video_qa.py
│ │ │ └── summarize_gpt_review.py
│ │ ├── model_coco_vqa.py
│ │ ├── model_video_consistency.py
│ │ ├── model_video_general.py
│ │ ├── model_video_qa.py
│ │ ├── model_vqa.py
│ │ ├── model_vqa_scienceqa.py
│ │ ├── questions/
│ │ │ ├── coco2014_val_qa_eval/
│ │ │ │ ├── qa90_gpt4_answer.jsonl
│ │ │ │ └── qa90_questions.jsonl
│ │ │ ├── coco_pope/
│ │ │ │ ├── coco_pope_adversarial.jsonl
│ │ │ │ ├── coco_pope_popular.jsonl
│ │ │ │ └── coco_pope_random.jsonl
│ │ │ ├── scienceqa/
│ │ │ │ ├── pid_splits.json
│ │ │ │ ├── problems.json
│ │ │ │ └── test_QCM-LEA.json
│ │ │ └── video_qa/
│ │ │ ├── activitynet_a_list.json
│ │ │ ├── activitynet_qa.json
│ │ │ ├── consistency_qa.json
│ │ │ ├── generic_qa.json
│ │ │ ├── msrvtt_a_list.json
│ │ │ ├── msrvtt_qa.json
│ │ │ ├── msvd_a_list.json
│ │ │ ├── msvd_qa.json
│ │ │ ├── temporal_qa.json
│ │ │ ├── tgif_a_list.json
│ │ │ └── tgif_qa.json
│ │ └── table/
│ │ ├── caps_boxes_coco2014_val_80.jsonl
│ │ ├── model.jsonl
│ │ ├── question.jsonl
│ │ ├── reviewer.jsonl
│ │ └── rule.json
│ ├── mm_utils.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── apply_delta.py
│ │ ├── arch.py
│ │ ├── builder.py
│ │ ├── cluster.py
│ │ ├── consolidate.py
│ │ ├── dataloader.py
│ │ ├── language_model/
│ │ │ └── llama.py
│ │ ├── make_delta.py
│ │ └── multimodal_encoder/
│ │ ├── builder.py
│ │ ├── clip_encoder.py
│ │ ├── eva_encoder.py
│ │ ├── eva_vit.py
│ │ ├── processor.py
│ │ └── utils.py
│ ├── train/
│ │ ├── llama_flash_attn_monkey_patch.py
│ │ ├── train.py
│ │ ├── train_mem.py
│ │ └── trainer.py
│ └── utils.py
├── requirements.txt
├── scripts/
│ ├── train_13b.sh
│ ├── train_7b.sh
│ └── val_7b_video.sh
├── tools/
│ ├── eval_davis17.py
│ ├── eval_mevis.py
│ ├── eval_revos.py
│ ├── generate_foreground_mask.py
│ ├── metrics.py
│ ├── zip_mp_mevis.py
│ └── zip_mp_refytvos.py
├── train_ds.py
├── utils/
│ ├── ade20k_classes.json
│ ├── chatunivi_dataset.py
│ ├── cocostuff_classes.txt
│ ├── conversation.py
│ ├── d2_datasets/
│ │ ├── categories.py
│ │ ├── mevis_utils.py
│ │ ├── refytvos_utils.py
│ │ ├── refytvos_val_videos.py
│ │ └── ytvis_api/
│ │ ├── __init__.py
│ │ ├── ytvos.py
│ │ └── ytvoseval.py
│ ├── data_processing.py
│ ├── dataset.py
│ ├── dataset_config.py
│ ├── grefcoco.py
│ ├── grefer.py
│ ├── random_list.py
│ ├── reason_seg_dataset.py
│ ├── refer.py
│ ├── refer_seg_dataset.py
│ ├── rvos_dataset.py
│ ├── rvos_eval_dataset.py
│ ├── sem_seg_dataset.py
│ ├── utils.py
│ └── vqa_dataset.py
└── utils_llamavid/
├── llamavid_client.py
└── llamavid_server.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
**/__pycache__
runs/
models/
datasets/
datasets
.vscode/
core*
vis_output/
test_vis/
openai
.DS_Store
XMem/weights/
================================================
FILE: .gitmodules
================================================
[submodule "LLaMA-VID"]
path = LLaMA-VID
url = git@github.com:dvlab-research/LLaMA-VID.git
================================================
FILE: README.md
================================================
# VISA: Reasoning Video Object Segmentation via Large Language Model
<font size=7><div align='center' >
[Code](https://github.com/cilinyan/VISA)
[Paper](http://arxiv.org/abs/2407.11325)
[ReVOS Benchmark](https://github.com/cilinyan/ReVOS-api)
</div></font>
<div align=center>
<img src="assert/architecture.png" style="width:100%;">
</div>
## 🚀 Performance
<div style="text-align: justify;">
VISA demonstrates remarkable proficiency in handling complex segmentation tasks that require: (a) reasoning based on world knowledge; (b) inference of future events; and (c) a comprehensive understanding of video content.
</div>
<div align=center>
<img src="assert/performance.png" style="width:50%;">
</div>
## 🛠️ Installation
```shell
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
```
## 🦄 Training and Validation
### 1. Training Data Preparation
Before training, please download the datasets, and then configure the dataset paths in [dataset_config.py](utils/dataset_config.py).
<details open>
<summary> <strong>LISA's Dataset</strong> </summary>
Follow [LISA](https://github.com/dvlab-research/LISA/tree/main) to prepare LISA's datasets. The dataset folder should be stored in the `$LISA_ROOT` folder.
```
LISA_ROOT
├── ade20k
├── coco
├── cocostuff
├── llava_dataset
├── mapillary
├── reason_seg
├── refer_seg
└── vlpart
```
</details>
<details open>
<summary> <strong>Chat-UniVi's Dataset</strong> </summary>
Follow [Chat-UniVi/Chat-UniVi-Instruct](https://huggingface.co/datasets/Chat-UniVi/Chat-UniVi-Instruct/tree/main) to prepare `Chat-UniVi-Instruct` datasets. The dataset folder should be stored in the `$ChatUniVi_ROOT` folder.
```
ChatUniVi_ROOT
├── Fine-tuning
│ ├── MIMIC_imageonly
│ └── VIDEO
└── ScienceQA_tuning
```
</details>
<details open>
<summary> <strong>RVOS's Dataset</strong> </summary>
1. Reasoning Video Segmentation Datasets: [ReVOS](https://github.com/cilinyan/ReVOS-api).
2. Referring Video Segmentation Datasets: [Ref-Youtube-VOS](https://github.com/wjn922/ReferFormer/blob/main/docs/data.md), [Ref-DAVIS17](https://github.com/wjn922/ReferFormer/blob/main/docs/data.md), [MeViS](https://github.com/henghuiding/MeViS).
- Ref-Youtube-VOS: Download `mask_dict.pkl` from [OneDrive](https://mailsjlueducn-my.sharepoint.com/:f:/g/personal/yancl9918_mails_jlu_edu_cn/EqR9g3yWG5pPtVoil0EfsbgBJhCZ7YwaRG9w9GsYy1_N5g?e=JLaJfc) or [BaiduPan](https://pan.baidu.com/s/1mbJaDDy0UTlA7sysp0zypg?pwd=visa).
- Ref-DAVIS17: Download `mask_dict.pkl` from [OneDrive](https://mailsjlueducn-my.sharepoint.com/:f:/g/personal/yancl9918_mails_jlu_edu_cn/Eq8bmGqNcYxGhQ1bioN65q4B_gPxIabpJUjGaV5uqcaq3w?e=2J6Ldp) or [BaiduPan](https://pan.baidu.com/s/1Gg5qPvxRZMKDp0JrVRJ75w?pwd=visa).
3. Open-Vocabulary Video Instance Segmentation Dataset: [LV-VIS](https://github.com/haochenheheda/LVVIS/tree/main).
Download `mask_dict.json` and `meta_expressions.json` from [OneDrive](https://mailsjlueducn-my.sharepoint.com/:f:/g/personal/yancl9918_mails_jlu_edu_cn/EttXAjMV8yFJhHMQwX3mIw0BP7dymKV-cuw4uAotDaAwYw?e=j6Y44X) or [BaiduPan](https://pan.baidu.com/s/1LOWPnuxXF_LXGSL7osRptA?pwd=visa). Then, put the annotation files in the `$RVOS_ROOT/lvvis/train` directory as follows.
```
RVOS_ROOT
├── ReVOS
│ ├── JPEGImages
│ ├── mask_dict.json
│ ├── mask_dict_foreground.json
│ ├── meta_expressions_train_.json
│ └── meta_expressions_valid_.json
├── lvvis
│ └── train
| ├── JPEGImages
| ├── mask_dict.json
| └── meta_expressions.json
├── Ref-Youtube-VOS
│ ├── meta_expressions
| | ├── train/meta_expressions.json
| | └── valid/meta_expressions.json
│ ├── train
| | ├── JPEGImages
| | └── mask_dict.pkl
│ └── valid
| └── JPEGImages
├── davis17
│ ├── meta_expressions
| | ├── train/meta_expressions.json
| | └── valid/meta_expressions.json
│ ├── train
| | ├── JPEGImages
| | └── mask_dict.pkl
│ └── valid
| ├── JPEGImages
| └── mask_dict.pkl
└── mevis
```
</details>
### 2. Pre-trained weights
<details open>
<summary> <strong>Chat-UniVi</strong> </summary>
To train VISA-7B or 13B, you need to download Chat-UniVi weights from [Chat-UniVi-7B](https://huggingface.co/Chat-UniVi/Chat-UniVi) and [Chat-UniVi-13B](https://huggingface.co/Chat-UniVi/Chat-UniVi-13B).
</details>
<details open>
<summary> <strong>SAM</strong> </summary>
Download SAM ViT-H pre-trained weights from the [link](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth).
</details>
### 3. Training VISA
```shell
# Training VISA-7B
bash scripts/train_7b.sh
# Extract fp32 consolidated weights from a ZeRO stage 1, 2, or 3 DeepSpeed checkpoint.
cd /PATH/TO/VISA-7B/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin
# Merge LoRA Weight
CUDA_VISIBLE_DEVICES="" python merge_lora_weights_and_save_hf_model.py \
--version Chat-UniVi/Chat-UniVi \
--weight /PATH/TO/VISA-7B/pytorch_model.bin \
--save_path /PATH/TO/VISA-7B/hf_model
```
### 4. Validation
<details open>
<summary> <strong>1. Using `VISA` to generate the predicted mask for each video <a href="https://github.com/cilinyan/VISA/blob/main/scripts/val_7b_video.sh">[demo]</a></strong> </summary>
```shell
deepspeed --master_port=24999 train_ds.py \
--version="/PATH/TO/VISA-7B/hf_model" \
--vision_pretrained="/PATH/TO/sam_vit_h_4b8939.pth" \
--log_base_dir="/PATH/TO/LOG_BASE_DIR" \
--exp_name="val_7b" \
--balance_sample \
--dataset="reason_seg" \
--sample_rates="13" \
--val_dataset "revos_valid" \
--eval_only
```
</details>
<details open>
<summary> <strong>2. Using <a href="https://github.com/dvlab-research/LLaMA-VID">LLaMA-VID</a> to generate the target frame for each video</strong> </summary>
> You can directly download the results of our run from [OneDrive](https://mailsjlueducn-my.sharepoint.com/:u:/g/personal/yancl9918_mails_jlu_edu_cn/ETmoJF2i8ZZBsgIwdELiL8gBfptZZoPWjx6Y0eH6Myr3sw?e=mTt6rO) or [BaiduPan](https://pan.baidu.com/s/1YWs6NLPvANfhgUBHKQwnBg?pwd=visa)
- Run [llamavid_server.py](https://github.com/cilinyan/VISA/blob/main/utils_llamavid/llamavid_server.py) to build the API server for LLaMA-VID [`[demo]`](https://github.com/cilinyan/VISA/blob/c53d2cd31407eab583c5eb04f84fd95b4694f2ce/utils_llamavid/llamavid_server.py#L215-L220)
```shell
python utils_llamavid/llamavid_server.py \
--vision_tower /PATH/TO/eva_vit_g.pth \
--image_processor /PATH/TO/openai/clip-vit-large-patch14 \
--model-path /PATH/TO/YanweiLi/llama-vid-13b-full-224-video-fps-1
```
- Using the API for inference [`[demo]`](https://github.com/cilinyan/VISA/blob/c53d2cd31407eab583c5eb04f84fd95b4694f2ce/utils_llamavid/llamavid_client.py#L58-L63)
```shell
python utils_llamavid/llamavid_client.py \
--video_root /PATH/TO/ReVOS/JPEGImages \
--data_json_file /PATH/TO/ReVOS/meta_expressions_valid_.json
```
</details>
<details open>
<summary> <strong>3. Using <a href="https://github.com/cilinyan/VISA/blob/main/XMem/tracking.py">XMem</a> for mask propagation <a href="https://github.com/cilinyan/VISA/blob/c53d2cd31407eab583c5eb04f84fd95b4694f2ce/XMem/tracking.py#L103-L110">[demo]</a> </strong> </summary>
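The repository lists no command for this step; below is a hedged sketch assembled from the usage notes in `XMem/generate_xmem_data_single.py` and the argument list of `XMem/eval_batch.py` found later in this repository. All paths are placeholders, and the exact flags of `tracking.py` may differ.

```shell
cd XMem
# 1) Export one target-frame mask per object (usage from generate_xmem_data_single.py)
python generate_xmem_data_single.py \
    --video_root /PATH/TO/VISA_exp/revos_valid/Annotations \
    --temp_xmem_anno /PATH/TO/VISA_exp/revos_valid_XMem_temp/Annotations \
    --final_xmem_anno /PATH/TO/VISA_exp/revos_valid_XMem_final/Annotations \
    --llama_vid_meta /PATH/TO/ReVOS/meta_expressions_valid__llamavid.json
# 2) Propagate each mask forward and backward with XMem (arguments from eval_batch.py)
python eval_batch.py \
    --meta_expression /PATH/TO/ReVOS/meta_expressions_valid__llamavid.json \
    --temp_xmem_anno /PATH/TO/VISA_exp/revos_valid_XMem_temp/Annotations \
    --final_xmem_anno /PATH/TO/VISA_exp/revos_valid_XMem_final/Annotations \
    --img_dir /PATH/TO/ReVOS/JPEGImages
```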
</details>
<details open>
<summary> <strong>4. Evaluate ReVOS's performance <a href="https://github.com/cilinyan/VISA/blob/main/tools/eval_revos.py#L74-L81">[demo]</a> </strong> </summary>
```shell
cd tools
python eval_revos.py /PATH/TO/FINAL_ANNOTATION [ARGS]
```
</details>
## 📑 Todo list
- [x] Release code with `Text-guided Frame Sampler`'s Local Sampling
- [ ] Release VISA model weights [issue #6](https://github.com/cilinyan/VISA/issues/6)
- [ ] Release code with `Text-guided Frame Sampler`'s Global-Local Sampling
## ⭐ Cite
If you find this project useful in your research, please consider citing:
```
@article{yan2024visa,
title={VISA: Reasoning Video Object Segmentation via Large Language Models},
author={Yan, Cilin and Wang, Haochen and Yan, Shilin and Jiang, Xiaolong and Hu, Yao and Kang, Guoliang and Xie, Weidi and Gavves, Efstratios},
journal={arXiv preprint arXiv:2407.11325},
year={2024}
}
```
## 🎖️ Acknowledgement
This work is built upon the [LLaVA](https://github.com/haotian-liu/LLaVA), [SAM](https://github.com/facebookresearch/segment-anything), [LISA](https://github.com/dvlab-research/LISA), [Chat-UniVi](https://github.com/PKU-YuanGroup/Chat-UniVi), [MeViS](https://github.com/henghuiding/MeViS), [LLaMA-VID](https://github.com/dvlab-research/LLaMA-VID) and [XMem](https://github.com/hkchengrex/XMem).
================================================
FILE: XMem/dataset/__init__.py
================================================
================================================
FILE: XMem/dataset/range_transform.py
================================================
import torchvision.transforms as transforms
im_mean = (124, 116, 104)
im_normalization = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
inv_im_trans = transforms.Normalize(
mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
std=[1/0.229, 1/0.224, 1/0.225])
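# --- Hedged sanity sketch (not part of the original file): inv_im_trans undoes
# im_normalization, since Normalize(-m/s, 1/s) inverts Normalize(m, s).
if __name__ == '__main__':
    import torch
    x = torch.rand(3, 4, 4)
    assert torch.allclose(inv_im_trans(im_normalization(x)), x, atol=1e-5)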
================================================
FILE: XMem/dataset/reseed.py
================================================
import torch
import random
def reseed(seed):
random.seed(seed)
torch.manual_seed(seed)
================================================
FILE: XMem/dataset/static_dataset.py
================================================
import os
from os import path
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from PIL import Image
import numpy as np
from dataset.range_transform import im_normalization, im_mean
from dataset.tps import random_tps_warp
from dataset.reseed import reseed
class StaticTransformDataset(Dataset):
"""
Generate pseudo VOS data by applying random transforms on static images.
Single-object only.
Method 0 - FSS style (class/1.jpg class/1.png)
Method 1 - Others style (XXX.jpg XXX.png)
"""
def __init__(self, parameters, num_frames=3, max_num_obj=1):
self.num_frames = num_frames
self.max_num_obj = max_num_obj
self.im_list = []
for parameter in parameters:
root, method, multiplier = parameter
if method == 0:
# Get images
classes = os.listdir(root)
for c in classes:
imgs = os.listdir(path.join(root, c))
jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()]
joint_list = [path.join(root, c, im) for im in jpg_list]
self.im_list.extend(joint_list * multiplier)
elif method == 1:
self.im_list.extend([path.join(root, im) for im in os.listdir(root) if '.jpg' in im] * multiplier)
print(f'{len(self.im_list)} images found.')
        # This set of transforms is the same for im/gt pairs, but differs among the 3 sampled frames
self.pair_im_lone_transform = transforms.Compose([
transforms.ColorJitter(0.1, 0.05, 0.05, 0), # No hue change here as that's not realistic
])
self.pair_im_dual_transform = transforms.Compose([
transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=im_mean),
transforms.Resize(384, InterpolationMode.BICUBIC),
transforms.RandomCrop((384, 384), pad_if_needed=True, fill=im_mean),
])
self.pair_gt_dual_transform = transforms.Compose([
transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=0),
transforms.Resize(384, InterpolationMode.NEAREST),
transforms.RandomCrop((384, 384), pad_if_needed=True, fill=0),
])
        # These transforms are the same for all pairs in the sampled sequence
self.all_im_lone_transform = transforms.Compose([
transforms.ColorJitter(0.1, 0.05, 0.05, 0.05),
transforms.RandomGrayscale(0.05),
])
self.all_im_dual_transform = transforms.Compose([
transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=im_mean),
transforms.RandomHorizontalFlip(),
])
self.all_gt_dual_transform = transforms.Compose([
transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=0),
transforms.RandomHorizontalFlip(),
])
# Final transform without randomness
self.final_im_transform = transforms.Compose([
transforms.ToTensor(),
im_normalization,
])
self.final_gt_transform = transforms.Compose([
transforms.ToTensor(),
])
def _get_sample(self, idx):
im = Image.open(self.im_list[idx]).convert('RGB')
gt = Image.open(self.im_list[idx][:-3]+'png').convert('L')
sequence_seed = np.random.randint(2147483647)
images = []
masks = []
for _ in range(self.num_frames):
reseed(sequence_seed)
this_im = self.all_im_dual_transform(im)
this_im = self.all_im_lone_transform(this_im)
reseed(sequence_seed)
this_gt = self.all_gt_dual_transform(gt)
pairwise_seed = np.random.randint(2147483647)
reseed(pairwise_seed)
this_im = self.pair_im_dual_transform(this_im)
this_im = self.pair_im_lone_transform(this_im)
reseed(pairwise_seed)
this_gt = self.pair_gt_dual_transform(this_gt)
# Use TPS only some of the times
# Not because TPS is bad -- just that it is too slow and I need to speed up data loading
if np.random.rand() < 0.33:
this_im, this_gt = random_tps_warp(this_im, this_gt, scale=0.02)
this_im = self.final_im_transform(this_im)
this_gt = self.final_gt_transform(this_gt)
images.append(this_im)
masks.append(this_gt)
images = torch.stack(images, 0)
masks = torch.stack(masks, 0)
return images, masks.numpy()
def __getitem__(self, idx):
additional_objects = np.random.randint(self.max_num_obj)
indices = [idx, *np.random.randint(self.__len__(), size=additional_objects)]
merged_images = None
merged_masks = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
for i, list_id in enumerate(indices):
images, masks = self._get_sample(list_id)
if merged_images is None:
merged_images = images
else:
merged_images = merged_images*(1-masks) + images*masks
merged_masks[masks[:,0]>0.5] = (i+1)
masks = merged_masks
labels = np.unique(masks[0])
# Remove background
labels = labels[labels!=0]
target_objects = labels.tolist()
# Generate one-hot ground-truth
cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
for i, l in enumerate(target_objects):
this_mask = (masks==l)
cls_gt[this_mask] = i+1
first_frame_gt[0,i] = (this_mask[0])
cls_gt = np.expand_dims(cls_gt, 1)
info = {}
info['name'] = self.im_list[idx]
info['num_objects'] = max(1, len(target_objects))
        # 1 if an object exists, 0 otherwise
selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
selector = torch.FloatTensor(selector)
data = {
'rgb': merged_images,
'first_frame_gt': first_frame_gt,
'cls_gt': cls_gt,
'selector': selector,
'info': info
}
return data
def __len__(self):
return len(self.im_list)
================================================
FILE: XMem/dataset/tps.py
================================================
import numpy as np
from PIL import Image
import cv2
import thinplate as tps
cv2.setNumThreads(0)
def pick_random_points(h, w, n_samples):
y_idx = np.random.choice(np.arange(h), size=n_samples, replace=False)
x_idx = np.random.choice(np.arange(w), size=n_samples, replace=False)
return y_idx/h, x_idx/w
def warp_dual_cv(img, mask, c_src, c_dst):
dshape = img.shape
theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True)
grid = tps.tps_grid(theta, c_dst, dshape)
mapx, mapy = tps.tps_grid_to_remap(grid, img.shape)
return cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR), cv2.remap(mask, mapx, mapy, cv2.INTER_NEAREST)
def random_tps_warp(img, mask, scale, n_ctrl_pts=12):
"""
Apply a random TPS warp of the input image and mask
Uses randomness from numpy
"""
img = np.asarray(img)
mask = np.asarray(mask)
h, w = mask.shape
points = pick_random_points(h, w, n_ctrl_pts)
c_src = np.stack(points, 1)
c_dst = c_src + np.random.normal(scale=scale, size=c_src.shape)
warp_im, warp_gt = warp_dual_cv(img, mask, c_src, c_dst)
return Image.fromarray(warp_im), Image.fromarray(warp_gt)
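# --- Hedged usage sketch (not part of the original file; assumes the `thinplate`
# package is installed): warp a toy PIL image/mask pair with a small random TPS.
if __name__ == '__main__':
    toy_img = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
    toy_mask = Image.fromarray(np.zeros((64, 64), dtype=np.uint8))
    warped_img, warped_mask = random_tps_warp(toy_img, toy_mask, scale=0.02)
    assert warped_img.size == toy_img.size  # remap preserves the image size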
================================================
FILE: XMem/dataset/util.py
================================================
import numpy as np
def all_to_onehot(masks, labels):
if len(masks.shape) == 3:
Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
else:
Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
for ni, l in enumerate(labels):
Ms[ni] = (masks == l).astype(np.uint8)
return Ms
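# --- Hedged usage sketch (not part of the original file): all_to_onehot turns an
# index mask with labels {1, 2} into a (num_labels, H, W) one-hot stack.
if __name__ == '__main__':
    demo_mask = np.array([[0, 1], [2, 1]], dtype=np.uint8)
    onehot = all_to_onehot(demo_mask, labels=[1, 2])
    assert onehot.shape == (2, 2, 2)
    assert onehot[0, 0, 1] == 1  # label 1 is present at position (0, 1)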
================================================
FILE: XMem/dataset/vos_dataset.py
================================================
import os
from os import path, replace
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from PIL import Image
import numpy as np
from dataset.range_transform import im_normalization, im_mean
from dataset.reseed import reseed
class VOSDataset(Dataset):
"""
Works for DAVIS/YouTubeVOS/BL30K training
For each sequence:
- Pick three frames
- Pick two objects
- Apply some random transforms that are the same for all frames
    - Apply a random transform to each frame
- The distance between frames is controlled
"""
def __init__(self, im_root, gt_root, max_jump, is_bl, subset=None, num_frames=3, max_num_obj=3, finetune=False):
self.im_root = im_root
self.gt_root = gt_root
self.max_jump = max_jump
self.is_bl = is_bl
self.num_frames = num_frames
self.max_num_obj = max_num_obj
self.videos = []
self.frames = {}
vid_list = sorted(os.listdir(self.im_root))
# Pre-filtering
for vid in vid_list:
if subset is not None:
if vid not in subset:
continue
frames = sorted(os.listdir(os.path.join(self.im_root, vid)))
if len(frames) < num_frames:
continue
self.frames[vid] = frames
self.videos.append(vid)
print('%d out of %d videos accepted in %s.' % (len(self.videos), len(vid_list), im_root))
        # This set of transforms is the same for im/gt pairs, but differs among the 3 sampled frames
self.pair_im_lone_transform = transforms.Compose([
transforms.ColorJitter(0.01, 0.01, 0.01, 0),
])
self.pair_im_dual_transform = transforms.Compose([
transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.BILINEAR, fill=im_mean),
])
self.pair_gt_dual_transform = transforms.Compose([
transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.NEAREST, fill=0),
])
        # These transforms are the same for all pairs in the sampled sequence
self.all_im_lone_transform = transforms.Compose([
transforms.ColorJitter(0.1, 0.03, 0.03, 0),
transforms.RandomGrayscale(0.05),
])
if self.is_bl:
# Use a different cropping scheme for the blender dataset because the image size is different
self.all_im_dual_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.BILINEAR)
])
self.all_gt_dual_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.NEAREST)
])
else:
self.all_im_dual_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.BILINEAR)
])
self.all_gt_dual_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.NEAREST)
])
# Final transform without randomness
self.final_im_transform = transforms.Compose([
transforms.ToTensor(),
im_normalization,
])
def __getitem__(self, idx):
video = self.videos[idx]
info = {}
info['name'] = video
vid_im_path = path.join(self.im_root, video)
vid_gt_path = path.join(self.gt_root, video)
frames = self.frames[video]
trials = 0
while trials < 5:
info['frames'] = [] # Appended with actual frames
num_frames = self.num_frames
length = len(frames)
this_max_jump = min(len(frames), self.max_jump)
# iterative sampling
frames_idx = [np.random.randint(length)]
acceptable_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1))).difference(set(frames_idx))
while(len(frames_idx) < num_frames):
idx = np.random.choice(list(acceptable_set))
frames_idx.append(idx)
new_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1)))
acceptable_set = acceptable_set.union(new_set).difference(set(frames_idx))
frames_idx = sorted(frames_idx)
if np.random.rand() < 0.5:
# Reverse time
frames_idx = frames_idx[::-1]
sequence_seed = np.random.randint(2147483647)
images = []
masks = []
target_objects = []
for f_idx in frames_idx:
jpg_name = frames[f_idx][:-4] + '.jpg'
png_name = frames[f_idx][:-4] + '.png'
info['frames'].append(jpg_name)
reseed(sequence_seed)
this_im = Image.open(path.join(vid_im_path, jpg_name)).convert('RGB')
this_im = self.all_im_dual_transform(this_im)
this_im = self.all_im_lone_transform(this_im)
reseed(sequence_seed)
this_gt = Image.open(path.join(vid_gt_path, png_name)).convert('P')
this_gt = self.all_gt_dual_transform(this_gt)
pairwise_seed = np.random.randint(2147483647)
reseed(pairwise_seed)
this_im = self.pair_im_dual_transform(this_im)
this_im = self.pair_im_lone_transform(this_im)
reseed(pairwise_seed)
this_gt = self.pair_gt_dual_transform(this_gt)
this_im = self.final_im_transform(this_im)
this_gt = np.array(this_gt)
images.append(this_im)
masks.append(this_gt)
images = torch.stack(images, 0)
labels = np.unique(masks[0])
# Remove background
labels = labels[labels!=0]
if self.is_bl:
# Find large enough labels
                good_labels = []
                for l in labels:
                    pixel_sum = (masks[0]==l).sum()
                    if pixel_sum > 10*10:
                        # OK if the object is always this small
                        # Not OK if it is actually much bigger
                        if pixel_sum > 30*30:
                            good_labels.append(l)
                        elif max((masks[1]==l).sum(), (masks[2]==l).sum()) < 20*20:
                            good_labels.append(l)
                labels = np.array(good_labels, dtype=np.uint8)
if len(labels) == 0:
target_objects = []
trials += 1
else:
target_objects = labels.tolist()
break
if len(target_objects) > self.max_num_obj:
target_objects = np.random.choice(target_objects, size=self.max_num_obj, replace=False)
info['num_objects'] = max(1, len(target_objects))
masks = np.stack(masks, 0)
# Generate one-hot ground-truth
cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
for i, l in enumerate(target_objects):
this_mask = (masks==l)
cls_gt[this_mask] = i+1
first_frame_gt[0,i] = (this_mask[0])
cls_gt = np.expand_dims(cls_gt, 1)
        # 1 if an object exists, 0 otherwise
selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
selector = torch.FloatTensor(selector)
data = {
'rgb': images,
'first_frame_gt': first_frame_gt,
'cls_gt': cls_gt,
'selector': selector,
'info': info,
}
return data
def __len__(self):
return len(self.videos)
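# --- Hedged sketch (not part of the original file): the iterative frame-sampling
# scheme from __getitem__ above, isolated for clarity. Every new pick must lie
# within max_jump of an already-picked frame, so temporal distance is controlled.
def _sample_frame_indices(length, num_frames=3, max_jump=20):
    this_max_jump = min(length, max_jump)
    frames_idx = [np.random.randint(length)]
    acceptable_set = set(range(max(0, frames_idx[-1]-this_max_jump),
                               min(length, frames_idx[-1]+this_max_jump+1))).difference(frames_idx)
    while len(frames_idx) < num_frames:
        idx = np.random.choice(list(acceptable_set))
        frames_idx.append(idx)
        new_set = set(range(max(0, idx-this_max_jump), min(length, idx+this_max_jump+1)))
        acceptable_set = acceptable_set.union(new_set).difference(frames_idx)
    return sorted(frames_idx)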
================================================
FILE: XMem/eval.py
================================================
import os
from os import path
from argparse import ArgumentParser
import shutil
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image
from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset
from inference.data.mask_mapper import MaskMapper
from model.network import XMem
from inference.inference_core import InferenceCore
from tqdm import tqdm
try:
import hickle as hkl
except ImportError:
print('Failed to import hickle. Fine if not using multi-scale testing.')
parser = ArgumentParser()
parser.add_argument('--model', default='./saves/XMem.pth')
parser.add_argument('--meta_exp', type=str)
# Data options
parser.add_argument('--d16_path', default='../DAVIS/2016')
parser.add_argument('--d17_path', default='../DAVIS/2017')
parser.add_argument('--y18_path', default='../YouTube2018')
parser.add_argument('--y19_path', default='../YouTube')
parser.add_argument('--lv_path', default='../long_video_set')
# For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations"
parser.add_argument('--generic_path')
parser.add_argument('--img_dir')
parser.add_argument('--reversed', action='store_true')
parser.add_argument('--split_part', type=int, default=0)
parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='D17')
parser.add_argument('--split', help='val/test', default='val')
parser.add_argument('--output', default=None)
parser.add_argument('--save_all', action='store_true',
help='Save all frames. Useful only in YouTubeVOS/long-time video', )
parser.add_argument('--benchmark', action='store_true', help='enable to disable amp for FPS benchmarking')
# Long-term memory options
parser.add_argument('--disable_long_term', action='store_true')
parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time', type=int, default=10000)
parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)
parser.add_argument('--top_k', type=int, default=30)
parser.add_argument('--mem_every', help='r in paper. Increase to improve running speed.', type=int, default=5)
parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
# Multi-scale options
parser.add_argument('--save_scores', action='store_true')
parser.add_argument('--flip', action='store_true')
parser.add_argument('--size', default=480, type=int, help='Resize the shorter side to this size. -1 to use original resolution. ')
args = parser.parse_args()
config = vars(args)
config['enable_long_term'] = not config['disable_long_term']
if args.output is None:
args.output = f'../output/{args.dataset}_{args.split}'
print(f'Output path not provided. Defaulting to {args.output}')
"""
Data preparation
"""
is_youtube = args.dataset.startswith('Y')
is_davis = args.dataset.startswith('D')
is_lv = args.dataset.startswith('LV')
if is_youtube or args.save_scores:
out_path = path.join(args.output, 'Annotations')
else:
out_path = args.output
if is_youtube:
if args.dataset == 'Y18':
yv_path = args.y18_path
elif args.dataset == 'Y19':
yv_path = args.y19_path
if args.split == 'val':
args.split = 'valid'
meta_dataset = YouTubeVOSTestDataset(data_root=yv_path, split='valid', size=args.size)
elif args.split == 'test':
meta_dataset = YouTubeVOSTestDataset(data_root=yv_path, split='test', size=args.size)
else:
raise NotImplementedError
elif is_davis:
if args.dataset == 'D16':
if args.split == 'val':
# Set up Dataset, a small hack to use the image set in the 2017 folder because the 2016 one is of a different format
meta_dataset = DAVISTestDataset(args.d16_path, imset='../../2017/trainval/ImageSets/2016/val.txt', size=args.size)
else:
raise NotImplementedError
palette = None
elif args.dataset == 'D17':
if args.split == 'val':
meta_dataset = DAVISTestDataset(path.join(args.d17_path, 'trainval'), imset='2017/val.txt', size=args.size)
elif args.split == 'test':
meta_dataset = DAVISTestDataset(path.join(args.d17_path, 'test-dev'), imset='2017/test-dev.txt', size=args.size)
else:
raise NotImplementedError
elif is_lv:
if args.dataset == 'LV1':
meta_dataset = LongTestDataset(args.meta_exp, path.join(args.lv_path, 'long_video'))
elif args.dataset == 'LV3':
meta_dataset = LongTestDataset(args.meta_exp, path.join(args.lv_path, 'long_video_x3'))
else:
raise NotImplementedError
elif args.dataset == 'G':
meta_dataset = LongTestDataset(args.meta_exp, path.join(args.generic_path), size=args.size, img_dir=args.img_dir, reversed_=args.reversed, split_part=args.split_part)
if not args.save_all:
args.save_all = True
print('save_all is forced to be true in generic evaluation mode.')
else:
raise NotImplementedError
torch.autograd.set_grad_enabled(False)
# Set up loader
meta_loader = meta_dataset.get_datasets()
# Load our checkpoint
network = XMem(config, args.model).cuda().eval()
if args.model is not None:
model_weights = torch.load(args.model)
network.load_weights(model_weights, init_as_zero_if_needed=True)
else:
print('No model loaded.')
total_process_time = 0
total_frames = 0
# Start eval
for vid_reader in tqdm(meta_loader, total=len(meta_dataset)):
loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2)
vid_name = vid_reader.vid_name
vid_length = len(loader)
# no need to count usage for LT if the video is not that long anyway
config['enable_long_term_count_usage'] = (
config['enable_long_term'] and
(vid_length
/ (config['max_mid_term_frames']-config['min_mid_term_frames'])
* config['num_prototypes'])
>= config['max_long_term_elements']
)
mapper = MaskMapper()
processor = InferenceCore(network, config=config)
first_mask_loaded = False
for ti, data in enumerate(loader):
with torch.cuda.amp.autocast(enabled=not args.benchmark):
rgb = data['rgb'].cuda()[0]
msk = data.get('mask')
info = data['info']
frame = info['frame'][0]
shape = info['shape']
need_resize = info['need_resize'][0]
"""
For timing see https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964
Seems to be very similar in testing as my previous timing method
with two cuda sync + time.time() in STCN though
"""
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
if not first_mask_loaded:
if msk is not None:
first_mask_loaded = True
else:
# no point to do anything without a mask
continue
if args.flip:
rgb = torch.flip(rgb, dims=[-1])
msk = torch.flip(msk, dims=[-1]) if msk is not None else None
# Map possibly non-continuous labels to continuous ones
if msk is not None:
msk, labels = mapper.convert_mask(msk[0].numpy())
msk = torch.Tensor(msk).cuda()
if need_resize:
if msk.shape[0] == 0:
print(vid_name)
msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
processor.set_all_labels(list(mapper.remappings.values()))
else:
labels = None
# Run the model on this frame
prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1))
# Upsample to original size if needed
if need_resize:
prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0]
end.record()
torch.cuda.synchronize()
total_process_time += (start.elapsed_time(end)/1000)
total_frames += 1
if args.flip:
prob = torch.flip(prob, dims=[-1])
# Probability mask -> index mask
out_mask = torch.max(prob, dim=0).indices
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
if args.save_scores:
prob = (prob.detach().cpu().numpy()*255).astype(np.uint8)
# Save the mask
if args.save_all or info['save'][0]:
this_out_path = path.join(out_path, vid_name)
os.makedirs(this_out_path, exist_ok=True)
out_mask = mapper.remap_index_mask(out_mask)
out_img = Image.fromarray(out_mask)
if vid_reader.get_palette() is not None:
out_img.putpalette(vid_reader.get_palette())
out_img.save(os.path.join(this_out_path, frame[:-4]+'.png'))
if args.save_scores:
np_path = path.join(args.output, 'Scores', vid_name)
os.makedirs(np_path, exist_ok=True)
if ti==len(loader)-1:
hkl.dump(mapper.remappings, path.join(np_path, f'backward.hkl'), mode='w')
if args.save_all or info['save'][0]:
hkl.dump(prob, path.join(np_path, f'{frame[:-4]}.hkl'), mode='w', compression='lzf')
print(f'Total processing time: {total_process_time}')
print(f'Total processed frames: {total_frames}')
print(f'FPS: {total_frames / total_process_time}')
print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')
if not args.save_scores:
if is_youtube:
print('Making zip for YouTubeVOS...')
shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
elif is_davis and args.split == 'test':
print('Making zip for DAVIS test-dev...')
shutil.make_archive(args.output, 'zip', args.output)
================================================
FILE: XMem/eval_batch.py
================================================
import os
import time
import torch
import argparse
import multiprocessing as mp
from termcolor import colored
from datetime import datetime
from importlib.util import find_spec
if find_spec("GPUtil") is None: os.system("pip install gputil")
import GPUtil
_GPU_LIST = [_.id for _ in GPUtil.getGPUs()]
_GPU_QUEUE = mp.Queue()
for _ in _GPU_LIST: _GPU_QUEUE.put(_)
def run_eval(meta_expression, temp_xmem_anno, final_xmem_anno, img_dir, split_part, cfgs=" --reversed "):
gpu_id = _GPU_QUEUE.get()
cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} python eval.py --meta_exp {meta_expression} --output {final_xmem_anno} --generic_path {temp_xmem_anno} --img_dir {img_dir} --split_part {split_part} --dataset G {cfgs}"
print(f"Running: {cmd}")
os.system(cmd)
_GPU_QUEUE.put(gpu_id)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--meta_expression", type=str, help='/PATH/TO/ReVOS/meta_expressions_valid__llamavid.json')
parser.add_argument("--temp_xmem_anno", type=str, help='/PATH/TO/VISA_exp/revos_valid_XMem_temp/Annotations')
parser.add_argument("--final_xmem_anno", type=str, help='/PATH/TO/VISA_exp/revos_valid_XMem_final/Annotations')
parser.add_argument("--img_dir", type=str, help='/PATH/TO/ReVOS/JPEGImages')
args = parser.parse_args()
p = mp.Pool(8)
for split_part in [0, 1, 2, 3]:
for cfgs in [" ", " --reversed "]:
p.apply_async(
run_eval,
args=(args.meta_expression, args.temp_xmem_anno, args.final_xmem_anno, args.img_dir, split_part, cfgs),
error_callback=lambda e: print(colored(e, "red"))
)
p.close()
p.join()
if __name__ == "__main__":
main()
================================================
FILE: XMem/generate_xmem_data_single.py
================================================
import sys
import os
import os.path as osp
import glob
import cv2
import multiprocessing
import json
import argparse
from tqdm import tqdm
from termcolor import colored
"""
python generate_xmem_data_single.py \
--video_root /PATH/TO/VISA_exp/revos_valid/Annotations \
    --temp_xmem_anno /PATH/TO/VISA_exp/revos_valid_XMem_temp/Annotations \
--final_xmem_anno /PATH/TO/VISA_exp/revos_valid_XMem_final/Annotations \
--llama_vid_meta /PATH/TO/ReVOS/meta_expressions_valid__llamavid.json
"""
def generate(obj, temp_xmem_anno, final_xmem_anno):
obj_dir, video_name, obj_id, tp = obj
img_list = glob.glob(obj_dir + '/*.png') # Mask
img_list.sort()
frame_id = int(len(img_list) * tp)
if frame_id == len(img_list):
frame_id -= 1
used_img = img_list[frame_id]
img_output_path = osp.join(temp_xmem_anno, video_name, obj_id, osp.basename(used_img))
final_img_output_dir = osp.join(final_xmem_anno, video_name, obj_id)
img_output_dir = osp.dirname(img_output_path)
os.makedirs(img_output_dir, exist_ok=True)
os.makedirs(final_img_output_dir, exist_ok=True)
os.system('cp {} {}'.format(used_img, img_output_path))
img = cv2.imread(img_output_path)
if img.sum() == 0:
target_img_list = [i.split('/')[-1] for i in img_list]
for img_ in target_img_list:
print(os.path.join(final_img_output_dir, img_))
os.system('cp {} {}'.format(img_output_path, os.path.join(img_output_dir, img_)))
os.system('cp {} {}'.format(img_output_path, os.path.join(final_img_output_dir, img_)))
return 0
def main():
parser = argparse.ArgumentParser(description='rgvos')
parser.add_argument('--video_root', type=str, help='/PATH/TO/VISA_exp/revos_valid/Annotations', )
    parser.add_argument('--temp_xmem_anno', type=str, help='/PATH/TO/VISA_exp/revos_valid_XMem_temp/Annotations', ) # path for saving the single-frame masks
    parser.add_argument('--final_xmem_anno', type=str, help='/PATH/TO/VISA_exp/revos_valid_XMem_final/Annotations', ) # path for saving XMem's final outputs
parser.add_argument("--llama_vid_meta", type=str, help='/PATH/TO/ReVOS/meta_expressions_valid__llamavid.json', )
args = parser.parse_args()
video_root = args.video_root
temp_xmem_anno = args.temp_xmem_anno
final_xmem_anno = args.final_xmem_anno
os.makedirs(temp_xmem_anno, exist_ok=True)
data = json.load(open(args.llama_vid_meta, 'r'))['videos']
all_obj_list = []
for video_name in data.keys():
exps = data[video_name]['expressions']
for obj_id in exps.keys():
tp = exps[obj_id]['tp']
obj_dir = os.path.join(video_root, video_name, obj_id)
all_obj_list.append([obj_dir, video_name, obj_id, tp])
print('start')
cpu_num = multiprocessing.cpu_count()-1
print("cpu_num:", cpu_num)
pool = multiprocessing.Pool(cpu_num)
pbar = tqdm(total=len(all_obj_list))
for obj in all_obj_list:
pool.apply_async(
generate,
args = (obj, temp_xmem_anno, final_xmem_anno ),
callback = lambda *a: pbar.update(1),
error_callback = lambda e: print(colored(e, "red"))
)
pool.close()
pool.join()
pbar.close()
if __name__ == '__main__':
main()
================================================
FILE: XMem/inference/__init__.py
================================================
================================================
FILE: XMem/inference/data/__init__.py
================================================
================================================
FILE: XMem/inference/data/mask_mapper.py
================================================
import numpy as np
import torch
from dataset.util import all_to_onehot
class MaskMapper:
"""
    This class is used to convert an indexed mask to a one-hot representation.
    It also takes care of remapping non-continuous indices.
It has two modes:
1. Default. Only masks with new indices are supposed to go into the remapper.
This is also the case for YouTubeVOS.
i.e., regions with index 0 are not "background", but "don't care".
2. Exhaustive. Regions with index 0 are considered "background".
Every single pixel is considered to be "labeled".
"""
def __init__(self):
self.labels = []
self.remappings = {}
# if coherent, no mapping is required
self.coherent = True
def convert_mask(self, mask, exhaustive=False):
# mask is in index representation, H*W numpy array
labels = np.unique(mask).astype(np.uint8)
labels = labels[labels!=0].tolist()
new_labels = list(set(labels) - set(self.labels))
if not exhaustive:
assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode'
# add new remappings
for i, l in enumerate(new_labels):
self.remappings[l] = i+len(self.labels)+1
if self.coherent and i+len(self.labels)+1 != l:
self.coherent = False
if exhaustive:
new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1)
else:
if self.coherent:
new_mapped_labels = new_labels
else:
new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1)
self.labels.extend(new_labels)
mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()
# mask num_objects*H*W
return mask, new_mapped_labels
def remap_index_mask(self, mask):
# mask is in index representation, H*W numpy array
if self.coherent:
return mask
new_mask = np.zeros_like(mask)
for l, i in self.remappings.items():
new_mask[mask==i] = l
return new_mask
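# --- Hedged usage sketch (not part of the original file): round-trip a mask whose
# object index (5) is non-continuous, so the mapper must remap it.
if __name__ == '__main__':
    demo = np.zeros((4, 4), dtype=np.uint8)
    demo[1, 1] = 5                                   # one object, non-continuous index 5
    mapper = MaskMapper()
    onehot, new_labels = mapper.convert_mask(demo)   # onehot: 1*H*W, label 5 -> index 1
    index_mask = (onehot[0].numpy() > 0).astype(np.uint8)  # internal index representation
    restored = mapper.remap_index_mask(index_mask)
    assert restored[1, 1] == 5                       # original label recovered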
================================================
FILE: XMem/inference/data/test_datasets.py
================================================
import os
from os import path
import json
import glob
from inference.data.video_reader import VideoReader
class LongTestDataset:
def __init__(self, meta_expression, data_root, size=-1, img_dir = '', reversed_ = False, split_part = 0):
self.image_dir = img_dir
self.mask_dir = data_root
self.size = size
self.reversed = reversed_
self.split_part = split_part
self.vid_list = []
videos_names = json.load(open(meta_expression, 'r'))['videos']
for video_name in videos_names:
video_mask_dir = path.join(self.mask_dir, video_name)
obj_ids = [d for d in os.listdir(video_mask_dir) if os.path.isdir(path.join(video_mask_dir, d))]
for obj_id in obj_ids:
obj_dir = path.join(video_mask_dir, obj_id)
img_list = glob.glob(obj_dir + '/*')
if len(img_list) == 1:
self.vid_list.append(path.join(video_name, obj_id))
self.vid_list.sort()
self.vid_list = [i for idx, i in enumerate(self.vid_list) if idx % 4 == self.split_part]
def get_datasets(self):
for video in self.vid_list:
yield VideoReader(video,
path.join(self.image_dir, '/'.join(video.split('/')[:-1])),
path.join(self.mask_dir, video),
to_save = [
name[:-4] for name in os.listdir(path.join(self.mask_dir, video)) # remove .png
],
size=self.size,
reversed=self.reversed,
)
def __len__(self):
return len(self.vid_list)
class DAVISTestDataset:
def __init__(self, data_root, imset='2017/val.txt', size=-1):
if size != 480:
self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution')
self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution')
if not path.exists(self.image_dir):
print(f'{self.image_dir} not found. Look at other options.')
self.image_dir = path.join(data_root, 'JPEGImages', '1080p')
self.mask_dir = path.join(data_root, 'Annotations', '1080p')
assert path.exists(self.image_dir), 'path not found'
else:
self.image_dir = path.join(data_root, 'JPEGImages', '480p')
self.mask_dir = path.join(data_root, 'Annotations', '480p')
self.size_dir = path.join(data_root, 'JPEGImages', '480p')
self.size = size
with open(path.join(data_root, 'ImageSets', imset)) as f:
self.vid_list = sorted([line.strip() for line in f])
def get_datasets(self):
for video in self.vid_list:
yield VideoReader(video,
path.join(self.image_dir, video),
path.join(self.mask_dir, video),
size=self.size,
size_dir=path.join(self.size_dir, video),
)
def __len__(self):
return len(self.vid_list)
class YouTubeVOSTestDataset:
def __init__(self, data_root, split, size=480):
self.image_dir = path.join(data_root, 'all_frames', split+'_all_frames', 'JPEGImages')
self.mask_dir = path.join(data_root, split, 'Annotations')
self.size = size
self.vid_list = sorted(os.listdir(self.image_dir))
self.req_frame_list = {}
with open(path.join(data_root, split, 'meta.json')) as f:
# read meta.json to know which frame is required for evaluation
meta = json.load(f)['videos']
for vid in self.vid_list:
req_frames = []
objects = meta[vid]['objects']
for value in objects.values():
req_frames.extend(value['frames'])
req_frames = list(set(req_frames))
self.req_frame_list[vid] = req_frames
def get_datasets(self):
for video in self.vid_list:
yield VideoReader(video,
path.join(self.image_dir, video),
path.join(self.mask_dir, video),
size=self.size,
to_save=self.req_frame_list[video],
use_all_mask=True
)
def __len__(self):
return len(self.vid_list)
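# --- Hedged sketch (not part of the original file): LongTestDataset shards its
# video list four ways via `idx % 4 == split_part`, matching the four split_part
# processes launched by eval_batch.py. A toy check that the shards partition the list:
if __name__ == '__main__':
    vid_list = [f'video_{i:02d}/obj_0' for i in range(10)]
    shards = [[v for i, v in enumerate(vid_list) if i % 4 == p] for p in range(4)]
    assert sorted(sum(shards, [])) == sorted(vid_list)  # every video in exactly one shard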
================================================
FILE: XMem/inference/data/video_reader.py
================================================
import os
from os import path
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import torch.nn.functional as F
from PIL import Image
import numpy as np
from dataset.range_transform import im_normalization
class VideoReader(Dataset):
"""
This class is used to read a video, one frame at a time
"""
def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None, reversed = False):
"""
image_dir - points to a directory of jpg images
mask_dir - points to a directory of png masks
size - resize min. side to size. Does nothing if <0.
to_save - optionally contains a list of file names without extensions
where the segmentation mask is required
        use_all_mask - when true, read all available masks in mask_dir.
Default false. Set to true for YouTubeVOS validation.
"""
self.vid_name = vid_name
self.image_dir = image_dir
self.mask_dir = mask_dir
self.to_save = to_save
self.use_all_mask = use_all_mask
self.reversed = reversed
if size_dir is None:
self.size_dir = self.image_dir
else:
self.size_dir = size_dir
self.frames = sorted(os.listdir(self.image_dir))
if self.reversed:
self.frames = self.frames[::-1]
self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette()
self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0])
if size < 0:
self.im_transform = transforms.Compose([
transforms.ToTensor(),
im_normalization,
])
else:
self.im_transform = transforms.Compose([
transforms.ToTensor(),
im_normalization,
transforms.Resize(size, interpolation=InterpolationMode.BILINEAR),
])
self.size = size
def __getitem__(self, idx):
frame = self.frames[idx]
info = {}
data = {}
info['frame'] = frame
info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)
im_path = path.join(self.image_dir, frame)
img = Image.open(im_path).convert('RGB')
if self.image_dir == self.size_dir:
shape = np.array(img).shape[:2]
else:
size_path = path.join(self.size_dir, frame)
size_im = Image.open(size_path).convert('RGB')
shape = np.array(size_im).shape[:2]
gt_path = path.join(self.mask_dir, frame[:-4]+'.png')
img = self.im_transform(img)
load_mask = self.use_all_mask or (gt_path == self.first_gt_path)
if load_mask and path.exists(gt_path):
mask = Image.open(gt_path).convert('P')
mask = np.array(mask, dtype=np.uint8)
data['mask'] = mask
info['shape'] = shape
info['need_resize'] = not (self.size < 0)
data['rgb'] = img
data['info'] = info
return data
def resize_mask(self, mask):
# mask transform is applied AFTER mapper, so we need to post-process it in eval.py
h, w = mask.shape[-2:]
min_hw = min(h, w)
return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
mode='nearest')
def get_palette(self):
return self.palette
def __len__(self):
return len(self.frames)
================================================
FILE: XMem/inference/inference_core.py
================================================
from inference.memory_manager import MemoryManager
from model.network import XMem
from model.aggregate import aggregate
from util.tensor_util import pad_divide_by, unpad
class InferenceCore:
def __init__(self, network:XMem, config):
self.config = config
self.network = network
self.mem_every = config['mem_every']
self.deep_update_every = config['deep_update_every']
self.enable_long_term = config['enable_long_term']
# if deep_update_every < 0, synchronize deep update with memory frame
self.deep_update_sync = (self.deep_update_every < 0)
self.clear_memory()
self.all_labels = None
def clear_memory(self):
self.curr_ti = -1
self.last_mem_ti = 0
if not self.deep_update_sync:
self.last_deep_update_ti = -self.deep_update_every
self.memory = MemoryManager(config=self.config)
def update_config(self, config):
self.mem_every = config['mem_every']
self.deep_update_every = config['deep_update_every']
self.enable_long_term = config['enable_long_term']
# if deep_update_every < 0, synchronize deep update with memory frame
self.deep_update_sync = (self.deep_update_every < 0)
self.memory.update_config(config)
def set_all_labels(self, all_labels):
# self.all_labels = [l.item() for l in all_labels]
self.all_labels = all_labels
def step(self, image, mask=None, valid_labels=None, end=False):
# image: 3*H*W
# mask: num_objects*H*W or None
self.curr_ti += 1
image, self.pad = pad_divide_by(image, 16)
image = image.unsqueeze(0) # add the batch dimension
is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end)
need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels)))
is_deep_update = (
(self.deep_update_sync and is_mem_frame) or # synchronized
(not self.deep_update_sync and self.curr_ti-self.last_deep_update_ti >= self.deep_update_every) # no-sync
) and (not end)
is_normal_update = (not self.deep_update_sync or not is_deep_update) and (not end)
key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image,
need_ek=(self.enable_long_term or need_segment),
need_sk=is_mem_frame)
multi_scale_features = (f16, f8, f4)
        # segment the current frame if needed
if need_segment:
memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout,
self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False)
# remove batch dim
pred_prob_with_bg = pred_prob_with_bg[0]
pred_prob_no_bg = pred_prob_with_bg[1:]
if is_normal_update:
self.memory.set_hidden(hidden)
else:
pred_prob_no_bg = pred_prob_with_bg = None
# use the input mask if any
if mask is not None:
mask, _ = pad_divide_by(mask, 16)
if pred_prob_no_bg is not None:
# if we have a predicted mask, we work on it
# make pred_prob_no_bg consistent with the input mask
mask_regions = (mask.sum(0) > 0.5)
pred_prob_no_bg[:, mask_regions] = 0
# shift by 1 because mask/pred_prob_no_bg do not contain background
mask = mask.type_as(pred_prob_no_bg)
if valid_labels is not None:
shift_by_one_non_labels = [i for i in range(pred_prob_no_bg.shape[0]) if (i+1) not in valid_labels]
# non-labelled objects are copied from the predicted mask
mask[shift_by_one_non_labels] = pred_prob_no_bg[shift_by_one_non_labels]
pred_prob_with_bg = aggregate(mask, dim=0)
# also create new hidden states
self.memory.create_hidden_state(len(self.all_labels), key)
# save as memory if needed
if is_mem_frame:
value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(),
pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=is_deep_update)
self.memory.add_memory(key, shrinkage, value, self.all_labels,
selection=selection if self.enable_long_term else None)
self.last_mem_ti = self.curr_ti
if is_deep_update:
self.memory.set_hidden(hidden)
self.last_deep_update_ti = self.curr_ti
return unpad(pred_prob_with_bg, self.pad)
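# -----------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): driving
# InferenceCore frame by frame. `network`, `frames` and `first_mask`
# are stand-ins for a weight-loaded XMem model, an iterable of
# [3, H, W] float tensors and a [num_objects, H, W] mask for frame 0.
# Note that MemoryManager may read further config keys (e.g. top_k);
# the minimal dict below only covers what InferenceCore itself uses.
# -----------------------------------------------------------------
def _demo_inference_core(network, frames, first_mask):
    config = {
        'mem_every': 5,             # insert a memory frame every 5 frames
        'deep_update_every': -1,    # < 0: deep updates follow memory frames
        'enable_long_term': False,  # short-term memory only in this sketch
    }
    processor = InferenceCore(network, config)
    processor.set_all_labels(list(range(1, first_mask.shape[0] + 1)))
    outputs = []
    for ti, frame in enumerate(frames):
        # the mask is only supplied on the first frame; afterwards the
        # memory readout alone drives the segmentation
        prob = processor.step(frame, first_mask if ti == 0 else None)
        outputs.append(prob.argmax(dim=0))  # 0 = background, k = object k
    return outputs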
================================================
FILE: XMem/inference/interact/__init__.py
================================================
================================================
FILE: XMem/inference/interact/fbrs/LICENSE
================================================
Mozilla Public License Version 2.0
==================================
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.
================================================
FILE: XMem/inference/interact/fbrs/__init__.py
================================================
================================================
FILE: XMem/inference/interact/fbrs/controller.py
================================================
import torch
try:
from torch import mps
except ImportError:
    pass  # torch.mps is unavailable on this torch build; the CUDA/CPU paths below do not need it
from ..fbrs.inference import clicker
from ..fbrs.inference.predictors import get_predictor
class InteractiveController:
def __init__(self, net, device, predictor_params, prob_thresh=0.5):
self.net = net.to(device)
self.prob_thresh = prob_thresh
self.clicker = clicker.Clicker()
self.states = []
self.probs_history = []
self.object_count = 0
self._result_mask = None
self.image = None
self.predictor = None
self.device = device
self.predictor_params = predictor_params
self.reset_predictor()
def set_image(self, image):
self.image = image
self._result_mask = torch.zeros(image.shape[-2:], dtype=torch.uint8)
self.object_count = 0
self.reset_last_object()
def add_click(self, x, y, is_positive):
self.states.append({
'clicker': self.clicker.get_state(),
'predictor': self.predictor.get_states()
})
click = clicker.Click(is_positive=is_positive, coords=(y, x))
self.clicker.add_click(click)
pred = self.predictor.get_prediction(self.clicker)
if self.device.type == 'cuda':
torch.cuda.empty_cache()
elif self.device.type == 'mps':
mps.empty_cache()
if self.probs_history:
self.probs_history.append((self.probs_history[-1][0], pred))
else:
self.probs_history.append((torch.zeros_like(pred), pred))
def undo_click(self):
if not self.states:
return
prev_state = self.states.pop()
self.clicker.set_state(prev_state['clicker'])
self.predictor.set_states(prev_state['predictor'])
self.probs_history.pop()
def partially_finish_object(self):
object_prob = self.current_object_prob
if object_prob is None:
return
self.probs_history.append((object_prob, torch.zeros_like(object_prob)))
self.states.append(self.states[-1])
self.clicker.reset_clicks()
self.reset_predictor()
def finish_object(self):
object_prob = self.current_object_prob
if object_prob is None:
return
self.object_count += 1
object_mask = object_prob > self.prob_thresh
self._result_mask[object_mask] = self.object_count
self.reset_last_object()
def reset_last_object(self):
self.states = []
self.probs_history = []
self.clicker.reset_clicks()
self.reset_predictor()
def reset_predictor(self, predictor_params=None):
if predictor_params is not None:
self.predictor_params = predictor_params
self.predictor = get_predictor(self.net, device=self.device,
**self.predictor_params)
if self.image is not None:
self.predictor.set_input_image(self.image)
@property
def current_object_prob(self):
if self.probs_history:
current_prob_total, current_prob_additive = self.probs_history[-1]
return torch.maximum(current_prob_total, current_prob_additive)
else:
return None
@property
def is_incomplete_mask(self):
return len(self.probs_history) > 0
@property
def result_mask(self):
return self._result_mask.clone()
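# -----------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). `net`
# and `image` are stand-ins for a model loaded with
# fbrs.inference.utils.load_is_model and a [3, H, W] image tensor.
# -----------------------------------------------------------------
def _demo_controller(net, image):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    controller = InteractiveController(
        net, device,
        predictor_params={'brs_mode': 'NoBRS'},  # simplest mode, see predictors/__init__.py
        prob_thresh=0.5)
    controller.set_image(image)
    controller.add_click(x=120, y=80, is_positive=True)   # foreground click
    controller.add_click(x=40, y=40, is_positive=False)   # background correction
    controller.undo_click()                               # roll back the last click
    controller.finish_object()                            # commit the thresholded mask
    return controller.result_mask                         # uint8 labels, 0 = background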
================================================
FILE: XMem/inference/interact/fbrs/inference/__init__.py
================================================
================================================
FILE: XMem/inference/interact/fbrs/inference/clicker.py
================================================
from collections import namedtuple
import numpy as np
from copy import deepcopy
from scipy.ndimage import distance_transform_edt
Click = namedtuple('Click', ['is_positive', 'coords'])
class Clicker(object):
def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1):
if gt_mask is not None:
self.gt_mask = gt_mask == 1
self.not_ignore_mask = gt_mask != ignore_label
else:
self.gt_mask = None
self.reset_clicks()
if init_clicks is not None:
for click in init_clicks:
self.add_click(click)
def make_next_click(self, pred_mask):
assert self.gt_mask is not None
click = self._get_click(pred_mask)
self.add_click(click)
def get_clicks(self, clicks_limit=None):
return self.clicks_list[:clicks_limit]
def _get_click(self, pred_mask, padding=True):
fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)
if padding:
fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant')
fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant')
fn_mask_dt = distance_transform_edt(fn_mask)
fp_mask_dt = distance_transform_edt(fp_mask)
if padding:
fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
fn_mask_dt = fn_mask_dt * self.not_clicked_map
fp_mask_dt = fp_mask_dt * self.not_clicked_map
fn_max_dist = np.max(fn_mask_dt)
fp_max_dist = np.max(fp_mask_dt)
is_positive = fn_max_dist > fp_max_dist
if is_positive:
coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x]
else:
coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x]
return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))
def add_click(self, click):
coords = click.coords
if click.is_positive:
self.num_pos_clicks += 1
else:
self.num_neg_clicks += 1
self.clicks_list.append(click)
if self.gt_mask is not None:
self.not_clicked_map[coords[0], coords[1]] = False
def _remove_last_click(self):
click = self.clicks_list.pop()
coords = click.coords
if click.is_positive:
self.num_pos_clicks -= 1
else:
self.num_neg_clicks -= 1
if self.gt_mask is not None:
self.not_clicked_map[coords[0], coords[1]] = True
def reset_clicks(self):
if self.gt_mask is not None:
            self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool)  # np.bool was removed in NumPy 1.24
self.num_pos_clicks = 0
self.num_neg_clicks = 0
self.clicks_list = []
def get_state(self):
return deepcopy(self.clicks_list)
def set_state(self, state):
self.reset_clicks()
for click in state:
self.add_click(click)
def __len__(self):
return len(self.clicks_list)
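# -----------------------------------------------------------------
# Illustrative check (not part of the original file): simulate the
# automatic clicking used during evaluation against a synthetic
# ground-truth square. With an empty prediction the whole square is
# a false-negative region, so the first click is positive and lands
# at its most interior point.
# -----------------------------------------------------------------
if __name__ == '__main__':
    gt = np.zeros((32, 32), dtype=np.int32)
    gt[8:24, 8:24] = 1
    sim = Clicker(gt_mask=gt)
    pred = np.zeros_like(gt, dtype=bool)  # no prediction yet
    sim.make_next_click(pred)
    first = sim.get_clicks()[0]
    assert first.is_positive and gt[first.coords] == 1
    assert len(sim) == 1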
================================================
FILE: XMem/inference/interact/fbrs/inference/evaluation.py
================================================
from time import time
import numpy as np
import torch
from ..inference import utils
from ..inference.clicker import Clicker
try:
get_ipython()
from tqdm import tqdm_notebook as tqdm
except NameError:
from tqdm import tqdm
def evaluate_dataset(dataset, predictor, oracle_eval=False, **kwargs):
all_ious = []
start_time = time()
for index in tqdm(range(len(dataset)), leave=False):
sample = dataset.get_sample(index)
item = dataset[index]
if oracle_eval:
gt_mask = torch.tensor(sample['instances_mask'], dtype=torch.float32)
gt_mask = gt_mask.unsqueeze(0).unsqueeze(0)
predictor.opt_functor.mask_loss.set_gt_mask(gt_mask)
_, sample_ious, _ = evaluate_sample(item['images'], sample['instances_mask'], predictor, **kwargs)
all_ious.append(sample_ious)
end_time = time()
elapsed_time = end_time - start_time
return all_ious, elapsed_time
def evaluate_sample(image_nd, instances_mask, predictor, max_iou_thr,
pred_thr=0.49, max_clicks=20):
clicker = Clicker(gt_mask=instances_mask)
pred_mask = np.zeros_like(instances_mask)
ious_list = []
with torch.no_grad():
predictor.set_input_image(image_nd)
for click_number in range(max_clicks):
clicker.make_next_click(pred_mask)
pred_probs = predictor.get_prediction(clicker)
pred_mask = pred_probs > pred_thr
iou = utils.get_iou(instances_mask, pred_mask)
ious_list.append(iou)
if iou >= max_iou_thr:
break
return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs
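# -----------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). `image`,
# `gt_mask` and `predictor` are stand-ins for a preprocessed image
# tensor, a binary instance mask and a predictor built with
# predictors.get_predictor.
# -----------------------------------------------------------------
def _demo_evaluate_sample(image, gt_mask, predictor):
    clicks, ious, _ = evaluate_sample(image, gt_mask, predictor,
                                      max_iou_thr=0.9,  # stop once IoU >= 0.9
                                      pred_thr=0.49,
                                      max_clicks=20)
    # len(clicks) is the number of simulated clicks spent reaching
    # (or failing to reach) the target IoU
    return len(clicks), ious[-1]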
================================================
FILE: XMem/inference/interact/fbrs/inference/predictors/__init__.py
================================================
from .base import BasePredictor
from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor
from .brs_functors import InputOptimizer, ScaleBiasOptimizer
from ..transforms import ZoomIn
from ...model.is_hrnet_model import DistMapsHRNetModel
def get_predictor(net, brs_mode, device,
prob_thresh=0.49,
with_flip=True,
zoom_in_params=dict(),
predictor_params=None,
brs_opt_func_params=None,
lbfgs_params=None):
lbfgs_params_ = {
'm': 20,
'factr': 0,
'pgtol': 1e-8,
'maxfun': 20,
}
predictor_params_ = {
'optimize_after_n_clicks': 1
}
if zoom_in_params is not None:
zoom_in = ZoomIn(**zoom_in_params)
else:
zoom_in = None
if lbfgs_params is not None:
lbfgs_params_.update(lbfgs_params)
lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun']
if brs_opt_func_params is None:
brs_opt_func_params = dict()
if brs_mode == 'NoBRS':
if predictor_params is not None:
predictor_params_.update(predictor_params)
predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_)
elif brs_mode.startswith('f-BRS'):
predictor_params_.update({
'net_clicks_limit': 8,
})
if predictor_params is not None:
predictor_params_.update(predictor_params)
insertion_mode = {
'f-BRS-A': 'after_c4',
'f-BRS-B': 'after_aspp',
'f-BRS-C': 'after_deeplab'
}[brs_mode]
opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh,
with_flip=with_flip,
optimizer_params=lbfgs_params_,
**brs_opt_func_params)
if isinstance(net, DistMapsHRNetModel):
FeaturePredictor = HRNetFeatureBRSPredictor
insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode]
else:
FeaturePredictor = FeatureBRSPredictor
predictor = FeaturePredictor(net, device,
opt_functor=opt_functor,
with_flip=with_flip,
insertion_mode=insertion_mode,
zoom_in=zoom_in,
**predictor_params_)
elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS':
use_dmaps = brs_mode == 'DistMap-BRS'
predictor_params_.update({
'net_clicks_limit': 5,
})
if predictor_params is not None:
predictor_params_.update(predictor_params)
opt_functor = InputOptimizer(prob_thresh=prob_thresh,
with_flip=with_flip,
optimizer_params=lbfgs_params_,
**brs_opt_func_params)
predictor = InputBRSPredictor(net, device,
optimize_target='dmaps' if use_dmaps else 'rgb',
opt_functor=opt_functor,
with_flip=with_flip,
zoom_in=zoom_in,
**predictor_params_)
else:
raise NotImplementedError
return predictor
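# -----------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). `net` is
# a stand-in for a model returned by ..utils.load_is_model. 'NoBRS'
# skips test-time optimization entirely; the f-BRS variants optimize
# a per-channel scale/bias at different insertion points; 'RGB-BRS'
# and 'DistMap-BRS' optimize the network input instead.
# -----------------------------------------------------------------
def _demo_get_predictor(net):
    import torch
    predictor = get_predictor(net, brs_mode='f-BRS-B',
                              device=torch.device('cpu'),
                              prob_thresh=0.49,
                              zoom_in_params={'target_size': 480})
    return predictor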
================================================
FILE: XMem/inference/interact/fbrs/inference/predictors/base.py
================================================
import torch
import torch.nn.functional as F
from ..transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide
class BasePredictor(object):
def __init__(self, net, device,
net_clicks_limit=None,
with_flip=False,
zoom_in=None,
max_size=None,
**kwargs):
self.net = net
self.with_flip = with_flip
self.net_clicks_limit = net_clicks_limit
self.original_image = None
self.device = device
self.zoom_in = zoom_in
self.transforms = [zoom_in] if zoom_in is not None else []
if max_size is not None:
self.transforms.append(LimitLongestSide(max_size=max_size))
self.transforms.append(SigmoidForPred())
if with_flip:
self.transforms.append(AddHorizontalFlip())
def set_input_image(self, image_nd):
for transform in self.transforms:
transform.reset()
self.original_image = image_nd.to(self.device)
if len(self.original_image.shape) == 3:
self.original_image = self.original_image.unsqueeze(0)
def get_prediction(self, clicker):
clicks_list = clicker.get_clicks()
image_nd, clicks_lists, is_image_changed = self.apply_transforms(
self.original_image, [clicks_list]
)
pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
size=image_nd.size()[2:])
for t in reversed(self.transforms):
prediction = t.inv_transform(prediction)
if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
print('zooming')
return self.get_prediction(clicker)
# return prediction.cpu().numpy()[0, 0]
return prediction
def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
points_nd = self.get_points_nd(clicks_lists)
return self.net(image_nd, points_nd)['instances']
def _get_transform_states(self):
return [x.get_state() for x in self.transforms]
def _set_transform_states(self, states):
assert len(states) == len(self.transforms)
for state, transform in zip(states, self.transforms):
transform.set_state(state)
def apply_transforms(self, image_nd, clicks_lists):
is_image_changed = False
for t in self.transforms:
image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
is_image_changed |= t.image_changed
return image_nd, clicks_lists, is_image_changed
def get_points_nd(self, clicks_lists):
total_clicks = []
num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
num_max_points = max(num_pos_clicks + num_neg_clicks)
if self.net_clicks_limit is not None:
num_max_points = min(self.net_clicks_limit, num_max_points)
num_max_points = max(1, num_max_points)
for clicks_list in clicks_lists:
clicks_list = clicks_list[:self.net_clicks_limit]
pos_clicks = [click.coords for click in clicks_list if click.is_positive]
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1)]
neg_clicks = [click.coords for click in clicks_list if not click.is_positive]
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1)]
total_clicks.append(pos_clicks + neg_clicks)
return torch.tensor(total_clicks, device=self.device)
def get_states(self):
return {'transform_states': self._get_transform_states()}
def set_states(self, states):
self._set_transform_states(states['transform_states'])
================================================
FILE: XMem/inference/interact/fbrs/inference/predictors/brs.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
from .base import BasePredictor
from ...model.is_hrnet_model import DistMapsHRNetModel
class BRSBasePredictor(BasePredictor):
def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs):
super().__init__(model, device, **kwargs)
self.optimize_after_n_clicks = optimize_after_n_clicks
self.opt_functor = opt_functor
self.opt_data = None
self.input_data = None
def set_input_image(self, image_nd):
super().set_input_image(image_nd)
self.opt_data = None
self.input_data = None
def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
for list_indx, clicks_list in enumerate(clicks_lists):
for click in clicks_list:
y, x = click.coords
y, x = int(round(y)), int(round(x))
y1, x1 = y - radius, x - radius
y2, x2 = y + radius + 1, x + radius + 1
if click.is_positive:
pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = True
else:
neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = True
with torch.no_grad():
pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device)
neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device)
return pos_clicks_map, neg_clicks_map
def get_states(self):
return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data}
def set_states(self, states):
self._set_transform_states(states['transform_states'])
self.opt_data = states['opt_data']
class FeatureBRSPredictor(BRSBasePredictor):
def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs):
super().__init__(model, device, opt_functor=opt_functor, **kwargs)
self.insertion_mode = insertion_mode
self._c1_features = None
if self.insertion_mode == 'after_deeplab':
self.num_channels = model.feature_extractor.ch
elif self.insertion_mode == 'after_c4':
self.num_channels = model.feature_extractor.aspp_in_channels
elif self.insertion_mode == 'after_aspp':
self.num_channels = model.feature_extractor.ch + 32
else:
raise NotImplementedError
def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
points_nd = self.get_points_nd(clicks_lists)
pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
num_clicks = len(clicks_lists[0])
bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
self.input_data = self._get_head_input(image_nd, points_nd)
def get_prediction_logits(scale, bias):
scale = scale.view(bs, -1, 1, 1)
bias = bias.view(bs, -1, 1, 1)
if self.with_flip:
scale = scale.repeat(2, 1, 1, 1)
bias = bias.repeat(2, 1, 1, 1)
scaled_backbone_features = self.input_data * scale
scaled_backbone_features = scaled_backbone_features + bias
if self.insertion_mode == 'after_c4':
x = self.net.feature_extractor.aspp(scaled_backbone_features)
x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:],
align_corners=True)
x = torch.cat((x, self._c1_features), dim=1)
scaled_backbone_features = self.net.feature_extractor.head(x)
elif self.insertion_mode == 'after_aspp':
scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features)
pred_logits = self.net.head(scaled_backbone_features)
pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
align_corners=True)
return pred_logits
self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
if num_clicks > self.optimize_after_n_clicks:
opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
**self.opt_functor.optimizer_params)
self.opt_data = opt_result[0]
with torch.no_grad():
if self.opt_functor.best_prediction is not None:
opt_pred_logits = self.opt_functor.best_prediction
else:
opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
opt_pred_logits = get_prediction_logits(*opt_vars)
return opt_pred_logits
def _get_head_input(self, image_nd, points):
with torch.no_grad():
coord_features = self.net.dist_maps(image_nd, points)
x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp':
c1, _, c3, c4 = self.net.feature_extractor.backbone(x)
c1 = self.net.feature_extractor.skip_project(c1)
if self.insertion_mode == 'after_aspp':
x = self.net.feature_extractor.aspp(c4)
x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x, c1), dim=1)
backbone_features = x
else:
backbone_features = c4
self._c1_features = c1
else:
backbone_features = self.net.feature_extractor(x)[0]
return backbone_features
class HRNetFeatureBRSPredictor(BRSBasePredictor):
def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs):
super().__init__(model, device, opt_functor=opt_functor, **kwargs)
self.insertion_mode = insertion_mode
self._c1_features = None
if self.insertion_mode == 'A':
self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8])
elif self.insertion_mode == 'C':
self.num_channels = 2 * model.feature_extractor.ocr_width
else:
raise NotImplementedError
def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
points_nd = self.get_points_nd(clicks_lists)
pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
num_clicks = len(clicks_lists[0])
bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
self.input_data = self._get_head_input(image_nd, points_nd)
def get_prediction_logits(scale, bias):
scale = scale.view(bs, -1, 1, 1)
bias = bias.view(bs, -1, 1, 1)
if self.with_flip:
scale = scale.repeat(2, 1, 1, 1)
bias = bias.repeat(2, 1, 1, 1)
scaled_backbone_features = self.input_data * scale
scaled_backbone_features = scaled_backbone_features + bias
if self.insertion_mode == 'A':
out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features)
feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features)
context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
feats = self.net.feature_extractor.ocr_distri_head(feats, context)
pred_logits = self.net.feature_extractor.cls_head(feats)
elif self.insertion_mode == 'C':
pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features)
else:
raise NotImplementedError
pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
align_corners=True)
return pred_logits
self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
if num_clicks > self.optimize_after_n_clicks:
opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
**self.opt_functor.optimizer_params)
self.opt_data = opt_result[0]
with torch.no_grad():
if self.opt_functor.best_prediction is not None:
opt_pred_logits = self.opt_functor.best_prediction
else:
opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
opt_pred_logits = get_prediction_logits(*opt_vars)
return opt_pred_logits
def _get_head_input(self, image_nd, points):
with torch.no_grad():
coord_features = self.net.dist_maps(image_nd, points)
x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
feats = self.net.feature_extractor.compute_hrnet_feats(x)
if self.insertion_mode == 'A':
backbone_features = feats
elif self.insertion_mode == 'C':
out_aux = self.net.feature_extractor.aux_head(feats)
feats = self.net.feature_extractor.conv3x3_ocr(feats)
context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context)
else:
raise NotImplementedError
return backbone_features
class InputBRSPredictor(BRSBasePredictor):
def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs):
super().__init__(model, device, opt_functor=opt_functor, **kwargs)
self.optimize_target = optimize_target
def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
points_nd = self.get_points_nd(clicks_lists)
pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
num_clicks = len(clicks_lists[0])
if self.opt_data is None or is_image_changed:
opt_channels = 2 if self.optimize_target == 'dmaps' else 3
bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
device=self.device, dtype=torch.float32)
def get_prediction_logits(opt_bias):
input_image = image_nd
if self.optimize_target == 'rgb':
input_image = input_image + opt_bias
dmaps = self.net.dist_maps(input_image, points_nd)
if self.optimize_target == 'dmaps':
dmaps = dmaps + opt_bias
x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
if self.optimize_target == 'all':
x = x + opt_bias
if isinstance(self.net, DistMapsHRNetModel):
pred_logits = self.net.feature_extractor(x)[0]
else:
backbone_features = self.net.feature_extractor(x)
pred_logits = self.net.head(backbone_features[0])
pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True)
return pred_logits
self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device,
shape=self.opt_data.shape)
if num_clicks > self.optimize_after_n_clicks:
opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(),
**self.opt_functor.optimizer_params)
self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device)
with torch.no_grad():
if self.opt_functor.best_prediction is not None:
opt_pred_logits = self.opt_functor.best_prediction
else:
opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data)
opt_pred_logits = get_prediction_logits(*opt_vars)
return opt_pred_logits
================================================
FILE: XMem/inference/interact/fbrs/inference/predictors/brs_functors.py
================================================
import torch
import numpy as np
from ...model.metrics import _compute_iou
from .brs_losses import BRSMaskLoss
class BaseOptimizer:
def __init__(self, optimizer_params,
prob_thresh=0.49,
reg_weight=1e-3,
min_iou_diff=0.01,
brs_loss=BRSMaskLoss(),
with_flip=False,
flip_average=False,
**kwargs):
self.brs_loss = brs_loss
self.optimizer_params = optimizer_params
self.prob_thresh = prob_thresh
self.reg_weight = reg_weight
self.min_iou_diff = min_iou_diff
self.with_flip = with_flip
self.flip_average = flip_average
self.best_prediction = None
self._get_prediction_logits = None
self._opt_shape = None
self._best_loss = None
self._click_masks = None
self._last_mask = None
self.device = None
def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None):
self.best_prediction = None
self._get_prediction_logits = get_prediction_logits
self._click_masks = (pos_mask, neg_mask)
self._opt_shape = shape
self._last_mask = None
self.device = device
def __call__(self, x):
opt_params = torch.from_numpy(x).float().to(self.device)
opt_params.requires_grad_(True)
with torch.enable_grad():
opt_vars, reg_loss = self.unpack_opt_params(opt_params)
result_before_sigmoid = self._get_prediction_logits(*opt_vars)
result = torch.sigmoid(result_before_sigmoid)
pos_mask, neg_mask = self._click_masks
if self.with_flip and self.flip_average:
result, result_flipped = torch.chunk(result, 2, dim=0)
result = 0.5 * (result + torch.flip(result_flipped, dims=[3]))
pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]]
loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask)
loss = loss + reg_loss
f_val = loss.detach().cpu().numpy()
if self.best_prediction is None or f_val < self._best_loss:
self.best_prediction = result_before_sigmoid.detach()
self._best_loss = f_val
if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh:
return [f_val, np.zeros_like(x)]
current_mask = result > self.prob_thresh
if self._last_mask is not None and self.min_iou_diff > 0:
diff_iou = _compute_iou(current_mask, self._last_mask)
if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff:
return [f_val, np.zeros_like(x)]
self._last_mask = current_mask
loss.backward()
f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float32)
return [f_val, f_grad]
def unpack_opt_params(self, opt_params):
raise NotImplementedError
class InputOptimizer(BaseOptimizer):
def unpack_opt_params(self, opt_params):
opt_params = opt_params.view(self._opt_shape)
if self.with_flip:
opt_params_flipped = torch.flip(opt_params, dims=[3])
opt_params = torch.cat([opt_params, opt_params_flipped], dim=0)
reg_loss = self.reg_weight * torch.sum(opt_params**2)
return (opt_params,), reg_loss
class ScaleBiasOptimizer(BaseOptimizer):
def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs):
super().__init__(*args, **kwargs)
self.scale_act = scale_act
self.reg_bias_weight = reg_bias_weight
def unpack_opt_params(self, opt_params):
scale, bias = torch.chunk(opt_params, 2, dim=0)
reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2))
if self.scale_act == 'tanh':
scale = torch.tanh(scale)
elif self.scale_act == 'sin':
scale = torch.sin(scale)
return (1 + scale, bias), reg_loss
================================================
FILE: XMem/inference/interact/fbrs/inference/predictors/brs_losses.py
================================================
import torch
from ...model.losses import SigmoidBinaryCrossEntropyLoss
class BRSMaskLoss(torch.nn.Module):
def __init__(self, eps=1e-5):
super().__init__()
self._eps = eps
def forward(self, result, pos_mask, neg_mask):
pos_diff = (1 - result) * pos_mask
pos_target = torch.sum(pos_diff ** 2)
pos_target = pos_target / (torch.sum(pos_mask) + self._eps)
neg_diff = result * neg_mask
neg_target = torch.sum(neg_diff ** 2)
neg_target = neg_target / (torch.sum(neg_mask) + self._eps)
loss = pos_target + neg_target
with torch.no_grad():
f_max_pos = torch.max(torch.abs(pos_diff)).item()
f_max_neg = torch.max(torch.abs(neg_diff)).item()
return loss, f_max_pos, f_max_neg
class OracleMaskLoss(torch.nn.Module):
def __init__(self):
super().__init__()
self.gt_mask = None
self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)
self.predictor = None
self.history = []
def set_gt_mask(self, gt_mask):
self.gt_mask = gt_mask
self.history = []
def forward(self, result, pos_mask, neg_mask):
gt_mask = self.gt_mask.to(result.device)
if self.predictor.object_roi is not None:
r1, r2, c1, c2 = self.predictor.object_roi[:4]
gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1]
gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True)
if result.shape[0] == 2:
gt_mask_flipped = torch.flip(gt_mask, dims=[3])
gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0)
loss = self.loss(result, gt_mask)
self.history.append(loss.detach().cpu().numpy()[0])
if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5:
return 0, 0, 0
return loss, 1.0, 1.0
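# -----------------------------------------------------------------
# Illustrative numeric check (not part of the original file): with a
# uniform 0.5 prediction, one positive and one negative click each
# contribute (1 - 0.5)^2 / 1 = 0.25 to the loss, so it sums to ~0.5
# and both reported maxima are exactly 0.5.
# -----------------------------------------------------------------
if __name__ == '__main__':
    result = torch.full((1, 1, 8, 8), 0.5)
    pos_mask = torch.zeros_like(result)
    pos_mask[0, 0, 2, 2] = 1.0
    neg_mask = torch.zeros_like(result)
    neg_mask[0, 0, 6, 6] = 1.0
    loss, f_max_pos, f_max_neg = BRSMaskLoss()(result, pos_mask, neg_mask)
    assert abs(loss.item() - 0.5) < 1e-3
    assert f_max_pos == 0.5 and f_max_neg == 0.5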
================================================
FILE: XMem/inference/interact/fbrs/inference/transforms/__init__.py
================================================
from .base import SigmoidForPred
from .flip import AddHorizontalFlip
from .zoom_in import ZoomIn
from .limit_longest_side import LimitLongestSide
from .crops import Crops
================================================
FILE: XMem/inference/interact/fbrs/inference/transforms/base.py
================================================
import torch
class BaseTransform(object):
def __init__(self):
self.image_changed = False
def transform(self, image_nd, clicks_lists):
raise NotImplementedError
def inv_transform(self, prob_map):
raise NotImplementedError
def reset(self):
raise NotImplementedError
def get_state(self):
raise NotImplementedError
def set_state(self, state):
raise NotImplementedError
class SigmoidForPred(BaseTransform):
def transform(self, image_nd, clicks_lists):
return image_nd, clicks_lists
def inv_transform(self, prob_map):
return torch.sigmoid(prob_map)
def reset(self):
pass
def get_state(self):
return None
def set_state(self, state):
pass
================================================
FILE: XMem/inference/interact/fbrs/inference/transforms/crops.py
================================================
import math
import torch
import numpy as np
from ...inference.clicker import Click
from .base import BaseTransform
class Crops(BaseTransform):
def __init__(self, crop_size=(320, 480), min_overlap=0.2):
super().__init__()
self.crop_height, self.crop_width = crop_size
self.min_overlap = min_overlap
self.x_offsets = None
self.y_offsets = None
self._counts = None
def transform(self, image_nd, clicks_lists):
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
image_height, image_width = image_nd.shape[2:4]
self._counts = None
if image_height < self.crop_height or image_width < self.crop_width:
return image_nd, clicks_lists
self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap)
self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap)
self._counts = np.zeros((image_height, image_width))
image_crops = []
for dy in self.y_offsets:
for dx in self.x_offsets:
self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1
image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width]
image_crops.append(image_crop)
image_crops = torch.cat(image_crops, dim=0)
self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32)
clicks_list = clicks_lists[0]
clicks_lists = []
for dy in self.y_offsets:
for dx in self.x_offsets:
crop_clicks = [Click(is_positive=x.is_positive, coords=(x.coords[0] - dy, x.coords[1] - dx))
for x in clicks_list]
clicks_lists.append(crop_clicks)
return image_crops, clicks_lists
def inv_transform(self, prob_map):
if self._counts is None:
return prob_map
new_prob_map = torch.zeros((1, 1, *self._counts.shape),
dtype=prob_map.dtype, device=prob_map.device)
crop_indx = 0
for dy in self.y_offsets:
for dx in self.x_offsets:
new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0]
crop_indx += 1
new_prob_map = torch.div(new_prob_map, self._counts)
return new_prob_map
def get_state(self):
return self.x_offsets, self.y_offsets, self._counts
def set_state(self, state):
self.x_offsets, self.y_offsets, self._counts = state
def reset(self):
self.x_offsets = None
self.y_offsets = None
self._counts = None
def get_offsets(length, crop_size, min_overlap_ratio=0.2):
if length == crop_size:
return [0]
N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio)
N = math.ceil(N)
overlap_ratio = (N - length / crop_size) / (N - 1)
overlap_width = int(crop_size * overlap_ratio)
offsets = [0]
for i in range(1, N):
new_offset = offsets[-1] + crop_size - overlap_width
if new_offset + crop_size > length:
new_offset = length - crop_size
offsets.append(new_offset)
return offsets
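# -----------------------------------------------------------------
# Illustrative numeric check (not part of the original file): tiling
# a 1000-px side with 480-px crops at >= 20% overlap needs three
# crops; the first starts at 0 and the last is clamped to end exactly
# at the image border.
# -----------------------------------------------------------------
if __name__ == '__main__':
    offsets = get_offsets(1000, 480, min_overlap_ratio=0.2)
    assert len(offsets) == 3
    assert offsets[0] == 0 and offsets[-1] + 480 == 1000
    assert all(b - a < 480 for a, b in zip(offsets, offsets[1:]))  # neighbours overlap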
================================================
FILE: XMem/inference/interact/fbrs/inference/transforms/flip.py
================================================
import torch
from ..clicker import Click
from .base import BaseTransform
class AddHorizontalFlip(BaseTransform):
def transform(self, image_nd, clicks_lists):
assert len(image_nd.shape) == 4
image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0)
image_width = image_nd.shape[3]
clicks_lists_flipped = []
for clicks_list in clicks_lists:
clicks_list_flipped = [Click(is_positive=click.is_positive,
coords=(click.coords[0], image_width - click.coords[1] - 1))
for click in clicks_list]
clicks_lists_flipped.append(clicks_list_flipped)
clicks_lists = clicks_lists + clicks_lists_flipped
return image_nd, clicks_lists
def inv_transform(self, prob_map):
assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0
num_maps = prob_map.shape[0] // 2
prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:]
return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3]))
def get_state(self):
return None
def set_state(self, state):
pass
def reset(self):
pass
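# -----------------------------------------------------------------
# Illustrative numeric check (not part of the original file): the
# transform doubles the batch with a mirrored copy and reflects the
# click x-coordinate; inv_transform averages both predictions back
# into a single map.
# -----------------------------------------------------------------
if __name__ == '__main__':
    t = AddHorizontalFlip()
    image = torch.arange(12, dtype=torch.float32).view(1, 1, 3, 4)
    clicks = [[Click(is_positive=True, coords=(1, 0))]]
    image2, clicks2 = t.transform(image, clicks)
    assert image2.shape[0] == 2
    assert clicks2[1][0].coords == (1, 3)  # x mirrored: 4 - 0 - 1
    assert t.inv_transform(torch.rand(2, 1, 3, 4)).shape == (1, 1, 3, 4)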
================================================
FILE: XMem/inference/interact/fbrs/inference/transforms/limit_longest_side.py
================================================
from .zoom_in import ZoomIn, get_roi_image_nd
class LimitLongestSide(ZoomIn):
def __init__(self, max_size=800):
super().__init__(target_size=max_size, skip_clicks=0)
def transform(self, image_nd, clicks_lists):
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
image_max_size = max(image_nd.shape[2:4])
self.image_changed = False
if image_max_size <= self.target_size:
return image_nd, clicks_lists
self._input_image = image_nd
self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1)
self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
self.image_changed = True
tclicks_lists = [self._transform_clicks(clicks_lists[0])]
return self._roi_image, tclicks_lists
================================================
FILE: XMem/inference/interact/fbrs/inference/transforms/zoom_in.py
================================================
import torch
from ..clicker import Click
from ...utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox
from .base import BaseTransform
class ZoomIn(BaseTransform):
def __init__(self,
target_size=400,
skip_clicks=1,
expansion_ratio=1.4,
min_crop_size=200,
recompute_thresh_iou=0.5,
prob_thresh=0.50):
super().__init__()
self.target_size = target_size
self.min_crop_size = min_crop_size
self.skip_clicks = skip_clicks
self.expansion_ratio = expansion_ratio
self.recompute_thresh_iou = recompute_thresh_iou
self.prob_thresh = prob_thresh
self._input_image_shape = None
self._prev_probs = None
self._object_roi = None
self._roi_image = None
def transform(self, image_nd, clicks_lists):
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
self.image_changed = False
clicks_list = clicks_lists[0]
if len(clicks_list) <= self.skip_clicks:
return image_nd, clicks_lists
self._input_image_shape = image_nd.shape
current_object_roi = None
if self._prev_probs is not None:
current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
if current_pred_mask.sum() > 0:
current_object_roi = get_object_roi(current_pred_mask, clicks_list,
self.expansion_ratio, self.min_crop_size)
if current_object_roi is None:
return image_nd, clicks_lists
update_object_roi = False
if self._object_roi is None:
update_object_roi = True
elif not check_object_roi(self._object_roi, clicks_list):
update_object_roi = True
elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou:
update_object_roi = True
if update_object_roi:
self._object_roi = current_object_roi
self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
self.image_changed = True
tclicks_lists = [self._transform_clicks(clicks_list)]
return self._roi_image.to(image_nd.device), tclicks_lists
def inv_transform(self, prob_map):
if self._object_roi is None:
self._prev_probs = prob_map.cpu().numpy()
return prob_map
assert prob_map.shape[0] == 1
rmin, rmax, cmin, cmax = self._object_roi
prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1),
mode='bilinear', align_corners=True)
if self._prev_probs is not None:
new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype)
new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map
else:
new_prob_map = prob_map
self._prev_probs = new_prob_map.cpu().numpy()
return new_prob_map
def check_possible_recalculation(self):
if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0:
return False
pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
if pred_mask.sum() > 0:
possible_object_roi = get_object_roi(pred_mask, [],
self.expansion_ratio, self.min_crop_size)
image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1)
if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
return True
return False
def get_state(self):
roi_image = self._roi_image.cpu() if self._roi_image is not None else None
return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed
def set_state(self, state):
self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state
def reset(self):
self._input_image_shape = None
self._object_roi = None
self._prev_probs = None
self._roi_image = None
self.image_changed = False
def _transform_clicks(self, clicks_list):
if self._object_roi is None:
return clicks_list
rmin, rmax, cmin, cmax = self._object_roi
crop_height, crop_width = self._roi_image.shape[2:]
transformed_clicks = []
for click in clicks_list:
new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1)
new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1)
transformed_clicks.append(Click(is_positive=click.is_positive, coords=(new_r, new_c)))
return transformed_clicks
def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size):
pred_mask = pred_mask.copy()
for click in clicks_list:
if click.is_positive:
pred_mask[int(click.coords[0]), int(click.coords[1])] = 1
bbox = get_bbox_from_mask(pred_mask)
bbox = expand_bbox(bbox, expansion_ratio, min_crop_size)
h, w = pred_mask.shape[0], pred_mask.shape[1]
bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1)
return bbox
def get_roi_image_nd(image_nd, object_roi, target_size):
rmin, rmax, cmin, cmax = object_roi
height = rmax - rmin + 1
width = cmax - cmin + 1
if isinstance(target_size, tuple):
new_height, new_width = target_size
else:
scale = target_size / max(height, width)
new_height = int(round(height * scale))
new_width = int(round(width * scale))
with torch.no_grad():
roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1]
roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width),
mode='bilinear', align_corners=True)
return roi_image_nd
def check_object_roi(object_roi, clicks_list):
for click in clicks_list:
if click.is_positive:
if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]:
return False
if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]:
return False
return True
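
A minimal usage sketch for the ROI helpers above: a predicted mask is expanded into a crop rectangle, and the image batch is cropped and resized so its longest side matches target_size. The import path and the dummy data are assumptions for illustration.

import numpy as np
import torch

from XMem.inference.interact.fbrs.inference.transforms.zoom_in import (
    get_object_roi, get_roi_image_nd)

# Fake predicted object occupying a small region of a 240x320 mask.
pred_mask = np.zeros((240, 320), dtype=bool)
pred_mask[100:140, 150:200] = True
roi = get_object_roi(pred_mask, [], expansion_ratio=1.4, min_crop_size=200)

# Dummy normalized image batch; the crop is resized so max(h, w) == 400.
image_nd = torch.rand(1, 3, 240, 320)
roi_image = get_roi_image_nd(image_nd, roi, target_size=400)
print(roi, roi_image.shape)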
================================================
FILE: XMem/inference/interact/fbrs/inference/utils.py
================================================
from datetime import timedelta
from pathlib import Path
import torch
import numpy as np
from ..model.is_deeplab_model import get_deeplab_model
from ..model.is_hrnet_model import get_hrnet_model
def get_time_metrics(all_ious, elapsed_time):
n_images = len(all_ious)
n_clicks = sum(map(len, all_ious))
mean_spc = elapsed_time / n_clicks
mean_spi = elapsed_time / n_images
return mean_spc, mean_spi
def load_is_model(checkpoint, device, backbone='auto', **kwargs):
if isinstance(checkpoint, (str, Path)):
state_dict = torch.load(checkpoint, map_location='cpu')
else:
state_dict = checkpoint
if backbone == 'auto':
for k in state_dict.keys():
if 'feature_extractor.stage2.0.branches' in k:
return load_hrnet_is_model(state_dict, device, backbone, **kwargs)
return load_deeplab_is_model(state_dict, device, backbone, **kwargs)
elif 'resnet' in backbone:
return load_deeplab_is_model(state_dict, device, backbone, **kwargs)
elif 'hrnet' in backbone:
return load_hrnet_is_model(state_dict, device, backbone, **kwargs)
else:
raise NotImplementedError('Unknown backbone')
def load_hrnet_is_model(state_dict, device, backbone='auto', width=48, ocr_width=256,
small=False, cpu_dist_maps=False, norm_radius=260):
if backbone == 'auto':
num_fe_weights = len([x for x in state_dict.keys() if 'feature_extractor.' in x])
small = num_fe_weights < 1800
ocr_f_down = [v for k, v in state_dict.items() if 'object_context_block.f_down.1.0.bias' in k]
assert len(ocr_f_down) == 1
ocr_width = ocr_f_down[0].shape[0]
s2_conv1_w = [v for k, v in state_dict.items() if 'stage2.0.branches.0.0.conv1.weight' in k]
assert len(s2_conv1_w) == 1
width = s2_conv1_w[0].shape[0]
model = get_hrnet_model(width=width, ocr_width=ocr_width, small=small,
with_aux_output=False, cpu_dist_maps=cpu_dist_maps,
norm_radius=norm_radius)
model.load_state_dict(state_dict, strict=False)
for param in model.parameters():
param.requires_grad = False
model.to(device)
model.eval()
return model
def load_deeplab_is_model(state_dict, device, backbone='auto', deeplab_ch=128, aspp_dropout=0.2,
cpu_dist_maps=False, norm_radius=260):
if backbone == 'auto':
num_backbone_params = len([x for x in state_dict.keys()
if 'feature_extractor.backbone' in x and not('num_batches_tracked' in x)])
if num_backbone_params <= 181:
backbone = 'resnet34'
elif num_backbone_params <= 276:
backbone = 'resnet50'
elif num_backbone_params <= 531:
backbone = 'resnet101'
else:
raise NotImplementedError('Unknown backbone')
if 'aspp_dropout' in state_dict:
aspp_dropout = float(state_dict['aspp_dropout'].cpu().numpy())
else:
aspp_project_weight = [v for k, v in state_dict.items() if 'aspp.project.0.weight' in k][0]
deeplab_ch = aspp_project_weight.size(0)
if deeplab_ch == 256:
aspp_dropout = 0.5
model = get_deeplab_model(backbone=backbone, deeplab_ch=deeplab_ch,
aspp_dropout=aspp_dropout, cpu_dist_maps=cpu_dist_maps,
norm_radius=norm_radius)
model.load_state_dict(state_dict, strict=False)
for param in model.parameters():
param.requires_grad = False
model.to(device)
model.eval()
return model
def get_iou(gt_mask, pred_mask, ignore_label=-1):
ignore_gt_mask_inv = gt_mask != ignore_label
obj_gt_mask = gt_mask == 1
intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
return intersection / union
def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
def _get_noc(iou_arr, iou_thr):
vals = iou_arr >= iou_thr
return np.argmax(vals) + 1 if np.any(vals) else max_clicks
noc_list = []
over_max_list = []
for iou_thr in iou_thrs:
scores_arr = np.array([_get_noc(iou_arr, iou_thr)
for iou_arr in all_ious], dtype=np.int32)
score = scores_arr.mean()
over_max = (scores_arr == max_clicks).sum()
noc_list.append(score)
over_max_list.append(over_max)
return noc_list, over_max_list
def find_checkpoint(weights_folder, checkpoint_name):
weights_folder = Path(weights_folder)
if ':' in checkpoint_name:
model_name, checkpoint_name = checkpoint_name.split(':')
models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()]
assert len(models_candidates) == 1
model_folder = models_candidates[0]
else:
model_folder = weights_folder
if checkpoint_name.endswith('.pth'):
if Path(checkpoint_name).exists():
checkpoint_path = checkpoint_name
else:
checkpoint_path = weights_folder / checkpoint_name
else:
model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth'))
assert len(model_checkpoints) == 1
checkpoint_path = model_checkpoints[0]
return str(checkpoint_path)
def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time,
n_clicks=20, model_name=None):
table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|'
f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
f'{"SPC,s":^7}|{"Time":^9}|')
row_width = len(table_header)
header = f'Eval results for model: {model_name}\n' if model_name is not None else ''
header += '-' * row_width + '\n'
header += table_header + '\n' + '-' * row_width
eval_time = str(timedelta(seconds=int(elapsed_time)))
table_row = f'|{brs_type:^13}|{dataset_name:^11}|'
table_row += f'{noc_list[0]:^9.2f}|'
table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|'
table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|'
table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|'
table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|'
table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|'
return header, table_row
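
As a quick sanity check of the evaluation helpers above, the sketch below feeds two made-up per-image IoU-per-click curves through compute_noc_metric and get_time_metrics; the import path is an assumption.

import numpy as np

from XMem.inference.interact.fbrs.inference.utils import (
    compute_noc_metric, get_time_metrics)

# One IoU value per click, per image (synthetic numbers).
all_ious = [np.array([0.55, 0.78, 0.86, 0.91]),
            np.array([0.40, 0.62, 0.81, 0.84])]

noc_list, over_max_list = compute_noc_metric(all_ious, iou_thrs=[0.8, 0.85, 0.9],
                                             max_clicks=20)
mean_spc, mean_spi = get_time_metrics(all_ious, elapsed_time=3.2)
print(noc_list)       # mean clicks to reach each threshold: [3.0, 11.5, 12.0]
print(over_max_list)  # images that never reached it within 20 clicks: [0, 1, 1]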
================================================
FILE: XMem/inference/interact/fbrs/model/__init__.py
================================================
================================================
FILE: XMem/inference/interact/fbrs/model/initializer.py
================================================
import torch
import torch.nn as nn
import numpy as np
class Initializer(object):
def __init__(self, local_init=True, gamma=None):
self.local_init = local_init
self.gamma = gamma
def __call__(self, m):
if getattr(m, '__initialized', False):
return
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__:
if m.weight is not None:
self._init_gamma(m.weight.data)
if m.bias is not None:
self._init_beta(m.bias.data)
else:
if getattr(m, 'weight', None) is not None:
self._init_weight(m.weight.data)
if getattr(m, 'bias', None) is not None:
self._init_bias(m.bias.data)
if self.local_init:
object.__setattr__(m, '__initialized', True)
def _init_weight(self, data):
nn.init.uniform_(data, -0.07, 0.07)
def _init_bias(self, data):
nn.init.constant_(data, 0)
def _init_gamma(self, data):
if self.gamma is None:
nn.init.constant_(data, 1.0)
else:
nn.init.normal_(data, 1.0, self.gamma)
def _init_beta(self, data):
nn.init.constant_(data, 0)
class Bilinear(Initializer):
def __init__(self, scale, groups, in_channels, **kwargs):
super().__init__(**kwargs)
self.scale = scale
self.groups = groups
self.in_channels = in_channels
def _init_weight(self, data):
"""Reset the weight and bias."""
bilinear_kernel = self.get_bilinear_kernel(self.scale)
weight = torch.zeros_like(data)
for i in range(self.in_channels):
if self.groups == 1:
j = i
else:
j = 0
weight[i, j] = bilinear_kernel
data[:] = weight
@staticmethod
def get_bilinear_kernel(scale):
"""Generate a bilinear upsampling kernel."""
kernel_size = 2 * scale - scale % 2
scale = (kernel_size + 1) // 2
center = scale - 0.5 * (1 + kernel_size % 2)
og = np.ogrid[:kernel_size, :kernel_size]
kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale)
return torch.tensor(kernel, dtype=torch.float32)
class XavierGluon(Initializer):
def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs):
super().__init__(**kwargs)
self.rnd_type = rnd_type
self.factor_type = factor_type
self.magnitude = float(magnitude)
def _init_weight(self, arr):
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)
if self.factor_type == 'avg':
factor = (fan_in + fan_out) / 2.0
elif self.factor_type == 'in':
factor = fan_in
elif self.factor_type == 'out':
factor = fan_out
else:
raise ValueError('Incorrect factor type')
scale = np.sqrt(self.magnitude / factor)
if self.rnd_type == 'uniform':
nn.init.uniform_(arr, -scale, scale)
elif self.rnd_type == 'gaussian':
nn.init.normal_(arr, 0, scale)
else:
raise ValueError('Unknown random type')
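
These initializers are per-module callables, so the usual pattern is Module.apply. A small sketch (import path assumed): conv weights get the Xavier-style fill, while the isinstance branch above routes BatchNorm weights and biases to the gamma/beta rules.

import torch.nn as nn

from XMem.inference.interact.fbrs.model.initializer import XavierGluon

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 1, 1),
)
net.apply(XavierGluon(rnd_type='uniform', factor_type='avg', magnitude=3))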
================================================
FILE: XMem/inference/interact/fbrs/model/is_deeplab_model.py
================================================
import torch
import torch.nn as nn
from .ops import DistMaps
from .modeling.deeplab_v3 import DeepLabV3Plus
from .modeling.basic_blocks import SepConvHead
def get_deeplab_model(backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5,
norm_layer=nn.BatchNorm2d, backbone_norm_layer=None,
use_rgb_conv=True, cpu_dist_maps=False,
norm_radius=260):
model = DistMapsModel(
feature_extractor=DeepLabV3Plus(backbone=backbone,
ch=deeplab_ch,
project_dropout=aspp_dropout,
norm_layer=norm_layer,
backbone_norm_layer=backbone_norm_layer),
head=SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2,
num_layers=2, norm_layer=norm_layer),
use_rgb_conv=use_rgb_conv,
norm_layer=norm_layer,
norm_radius=norm_radius,
cpu_dist_maps=cpu_dist_maps
)
return model
class DistMapsModel(nn.Module):
def __init__(self, feature_extractor, head, norm_layer=nn.BatchNorm2d, use_rgb_conv=True,
cpu_dist_maps=False, norm_radius=260):
super(DistMapsModel, self).__init__()
if use_rgb_conv:
self.rgb_conv = nn.Sequential(
nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1),
nn.LeakyReLU(negative_slope=0.2),
norm_layer(8),
nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1),
)
else:
self.rgb_conv = None
self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0,
cpu_mode=cpu_dist_maps)
self.feature_extractor = feature_extractor
self.head = head
def forward(self, image, points):
coord_features = self.dist_maps(image, points)
if self.rgb_conv is not None:
x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
else:
c1, c2 = torch.chunk(coord_features, 2, dim=1)
c3 = torch.ones_like(c1)
coord_features = torch.cat((c1, c2, c3), dim=1)
x = 0.8 * image * coord_features + 0.2 * image
backbone_features = self.feature_extractor(x)
instance_out = self.head(backbone_features[0])
instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:],
mode='bilinear', align_corners=True)
return {'instances': instance_out}
def load_weights(self, path_to_weights):
current_state_dict = self.state_dict()
new_state_dict = torch.load(path_to_weights, map_location='cpu')
current_state_dict.update(new_state_dict)
self.load_state_dict(current_state_dict)
def get_trainable_params(self):
backbone_params = nn.ParameterList()
other_params = nn.ParameterList()
for name, param in self.named_parameters():
if param.requires_grad:
if 'backbone' in name:
backbone_params.append(param)
else:
other_params.append(param)
return backbone_params, other_params
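
get_trainable_params splits parameters so the pretrained backbone and the freshly initialized layers can be trained at different learning rates. A sketch of that pattern; the import path and the learning-rate values are assumptions.

import torch

from XMem.inference.interact.fbrs.model.is_deeplab_model import get_deeplab_model

model = get_deeplab_model(backbone='resnet34', deeplab_ch=128)
backbone_params, other_params = model.get_trainable_params()
optimizer = torch.optim.Adam([
    {'params': backbone_params, 'lr': 1e-5},  # gentle updates for the backbone
    {'params': other_params, 'lr': 1e-4},     # larger steps for the new head
])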
================================================
FILE: XMem/inference/interact/fbrs/model/is_hrnet_model.py
================================================
import torch
import torch.nn as nn
from .ops import DistMaps
from .modeling.hrnet_ocr import HighResolutionNet
def get_hrnet_model(width=48, ocr_width=256, small=False, norm_radius=260,
use_rgb_conv=True, with_aux_output=False, cpu_dist_maps=False,
norm_layer=nn.BatchNorm2d):
model = DistMapsHRNetModel(
feature_extractor=HighResolutionNet(width=width, ocr_width=ocr_width, small=small,
num_classes=1, norm_layer=norm_layer),
use_rgb_conv=use_rgb_conv,
with_aux_output=with_aux_output,
norm_layer=norm_layer,
norm_radius=norm_radius,
cpu_dist_maps=cpu_dist_maps
)
return model
class DistMapsHRNetModel(nn.Module):
def __init__(self, feature_extractor, use_rgb_conv=True, with_aux_output=False,
norm_layer=nn.BatchNorm2d, norm_radius=260, cpu_dist_maps=False):
super(DistMapsHRNetModel, self).__init__()
self.with_aux_output = with_aux_output
if use_rgb_conv:
self.rgb_conv = nn.Sequential(
nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1),
nn.LeakyReLU(negative_slope=0.2),
norm_layer(8),
nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1),
)
else:
self.rgb_conv = None
self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, cpu_mode=cpu_dist_maps)
self.feature_extractor = feature_extractor
def forward(self, image, points):
coord_features = self.dist_maps(image, points)
if self.rgb_conv is not None:
x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
else:
c1, c2 = torch.chunk(coord_features, 2, dim=1)
c3 = torch.ones_like(c1)
coord_features = torch.cat((c1, c2, c3), dim=1)
x = 0.8 * image * coord_features + 0.2 * image
feature_extractor_out = self.feature_extractor(x)
instance_out = feature_extractor_out[0]
instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:],
mode='bilinear', align_corners=True)
outputs = {'instances': instance_out}
if self.with_aux_output:
instance_aux_out = feature_extractor_out[1]
instance_aux_out = nn.functional.interpolate(instance_aux_out, size=image.size()[2:],
mode='bilinear', align_corners=True)
outputs['instances_aux'] = instance_aux_out
return outputs
def load_weights(self, path_to_weights):
current_state_dict = self.state_dict()
new_state_dict = torch.load(path_to_weights)
current_state_dict.update(new_state_dict)
self.load_state_dict(current_state_dict)
def get_trainable_params(self):
backbone_params = nn.ParameterList()
other_params = nn.ParameterList()
other_params_keys = []
nonbackbone_keywords = ['rgb_conv', 'aux_head', 'cls_head', 'conv3x3_ocr', 'ocr_distri_head']
for name, param in self.named_parameters():
if param.requires_grad:
if any(x in name for x in nonbackbone_keywords):
other_params.append(param)
other_params_keys.append(name)
else:
backbone_params.append(param)
print('Nonbackbone params:', sorted(other_params_keys))
return backbone_params, other_params
================================================
FILE: XMem/inference/interact/fbrs/model/losses.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import misc
class NormalizedFocalLossSigmoid(nn.Module):
def __init__(self, axis=-1, alpha=0.25, gamma=2,
from_logits=False, batch_axis=0,
weight=None, size_average=True, detach_delimeter=True,
eps=1e-12, scale=1.0,
ignore_label=-1):
super(NormalizedFocalLossSigmoid, self).__init__()
self._axis = axis
self._alpha = alpha
self._gamma = gamma
self._ignore_label = ignore_label
self._weight = weight if weight is not None else 1.0
self._batch_axis = batch_axis
self._scale = scale
self._from_logits = from_logits
self._eps = eps
self._size_average = size_average
self._detach_delimeter = detach_delimeter
self._k_sum = 0
def forward(self, pred, label, sample_weight=None):
one_hot = label > 0
sample_weight = label != self._ignore_label
if not self._from_logits:
pred = torch.sigmoid(pred)
alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
pt = torch.where(one_hot, pred, 1 - pred)
pt = torch.where(sample_weight, pt, torch.ones_like(pt))
beta = (1 - pt) ** self._gamma
sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
mult = sw_sum / (beta_sum + self._eps)
if self._detach_delimeter:
mult = mult.detach()
beta = beta * mult
ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy()
sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
if np.any(ignore_area == 0):
self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
loss = self._weight * (loss * sample_weight)
if self._size_average:
bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis))
loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps)
else:
loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
return self._scale * loss
def log_states(self, sw, name, global_step):
sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step)
class FocalLoss(nn.Module):
def __init__(self, axis=-1, alpha=0.25, gamma=2,
from_logits=False, batch_axis=0,
weight=None, num_class=None,
eps=1e-9, size_average=True, scale=1.0):
super(FocalLoss, self).__init__()
self._axis = axis
self._alpha = alpha
self._gamma = gamma
self._weight = weight if weight is not None else 1.0
self._batch_axis = batch_axis
self._scale = scale
self._num_class = num_class
self._from_logits = from_logits
self._eps = eps
self._size_average = size_average
def forward(self, pred, label, sample_weight=None):
if not self._from_logits:
pred = torch.sigmoid(pred)
one_hot = label > 0
pt = torch.where(one_hot, pred, 1 - pred)
t = label != -1
alpha = torch.where(one_hot, self._alpha * t, (1 - self._alpha) * t)
beta = (1 - pt) ** self._gamma
loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
sample_weight = label != -1
loss = self._weight * (loss * sample_weight)
if self._size_average:
tsum = torch.sum(label == 1, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis))
loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps)
else:
loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
return self._scale * loss
class SigmoidBinaryCrossEntropyLoss(nn.Module):
def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
super(SigmoidBinaryCrossEntropyLoss, self).__init__()
self._from_sigmoid = from_sigmoid
self._ignore_label = ignore_label
self._weight = weight if weight is not None else 1.0
self._batch_axis = batch_axis
def forward(self, pred, label):
label = label.view(pred.size())
sample_weight = label != self._ignore_label
label = torch.where(sample_weight, label, torch.zeros_like(label))
if not self._from_sigmoid:
loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
else:
eps = 1e-12
loss = -(torch.log(pred + eps) * label
+ torch.log(1. - pred + eps) * (1. - label))
loss = self._weight * (loss * sample_weight)
return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
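
A small smoke test for SigmoidBinaryCrossEntropyLoss, showing how ignore_label pixels are zeroed out of the loss before the per-sample reduction; the shapes and values are made up.

import torch

from XMem.inference.interact.fbrs.model.losses import SigmoidBinaryCrossEntropyLoss

criterion = SigmoidBinaryCrossEntropyLoss(ignore_label=-1)
logits = torch.randn(2, 1, 8, 8)                    # raw scores (from_sigmoid=False)
labels = torch.randint(0, 2, (2, 1, 8, 8)).float()
labels[0, :, :2] = -1                               # ignored rows contribute nothing
loss = criterion(logits, labels)
print(loss.shape)                                   # torch.Size([2]): one value per sample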
================================================
FILE: XMem/inference/interact/fbrs/model/metrics.py
================================================
import torch
import numpy as np
from ..utils import misc
class TrainMetric(object):
def __init__(self, pred_outputs, gt_outputs):
self.pred_outputs = pred_outputs
self.gt_outputs = gt_outputs
def update(self, *args, **kwargs):
raise NotImplementedError
def get_epoch_value(self):
raise NotImplementedError
def reset_epoch_stats(self):
raise NotImplementedError
def log_states(self, sw, tag_prefix, global_step):
pass
@property
def name(self):
return type(self).__name__
class AdaptiveIoU(TrainMetric):
def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9,
ignore_label=-1, from_logits=True,
pred_output='instances', gt_output='instances'):
super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
self._ignore_label = ignore_label
self._from_logits = from_logits
self._iou_thresh = init_thresh
self._thresh_step = thresh_step
self._thresh_beta = thresh_beta
self._iou_beta = iou_beta
self._ema_iou = 0.0
self._epoch_iou_sum = 0.0
self._epoch_batch_count = 0
def update(self, pred, gt):
gt_mask = gt > 0
if self._from_logits:
pred = torch.sigmoid(pred)
gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy()
if np.all(gt_mask_area == 0):
return
ignore_mask = gt == self._ignore_label
max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean()
best_thresh = self._iou_thresh
for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]:
temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean()
if temp_iou > max_iou:
max_iou = temp_iou
best_thresh = t
self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh
self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou
self._epoch_iou_sum += max_iou
self._epoch_batch_count += 1
def get_epoch_value(self):
if self._epoch_batch_count > 0:
return self._epoch_iou_sum / self._epoch_batch_count
else:
return 0.0
def reset_epoch_stats(self):
self._epoch_iou_sum = 0.0
self._epoch_batch_count = 0
def log_states(self, sw, tag_prefix, global_step):
sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step)
sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step)
@property
def iou_thresh(self):
return self._iou_thresh
def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
if ignore_mask is not None:
pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)
reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)
union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
nonzero = union > 0
iou = intersection[nonzero] / union[nonzero]
if not keep_ignore:
return iou
else:
result = np.full_like(intersection, -1)
result[nonzero] = iou
return result
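
AdaptiveIoU keeps a running IoU while nudging its binarization threshold toward whichever neighboring threshold scores better. A minimal sketch with random tensors (import path assumed; from_logits defaults to True, so raw scores go in).

import torch

from XMem.inference.interact.fbrs.model.metrics import AdaptiveIoU

metric = AdaptiveIoU(init_thresh=0.4)
logits = torch.randn(4, 24, 24)               # per-pixel logits, one map per sample
gt = (torch.rand(4, 24, 24) > 0.5).long()
gt[0, :4] = -1                                # an ignore_label region
metric.update(logits, gt)
print(metric.get_epoch_value(), metric.iou_thresh)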
================================================
FILE: XMem/inference/interact/fbrs/model/modeling/__init__.py
================================================
================================================
FILE: XMem/inference/interact/fbrs/model/modeling/basic_blocks.py
================================================
import torch.nn as nn
from ...model import ops
class ConvHead(nn.Module):
def __init__(self, out_channels, in_channels=32, num_layers=1,
kernel_size=3, padding=1,
norm_layer=nn.BatchNorm2d):
super(ConvHead, self).__init__()
convhead = []
for i in range(num_layers):
convhead.extend([
nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding),
nn.ReLU(),
norm_layer(in_channels) if norm_layer is not None else nn.Identity()
])
convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))
self.convhead = nn.Sequential(*convhead)
def forward(self, *inputs):
return self.convhead(inputs[0])
class SepConvHead(nn.Module):
def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1,
kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0,
norm_layer=nn.BatchNorm2d):
super(SepConvHead, self).__init__()
sepconvhead = []
for i in range(num_layers):
sepconvhead.append(
SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels,
out_channels=mid_channels,
dw_kernel=kernel_size, dw_padding=padding,
norm_layer=norm_layer, activation='relu')
)
if dropout_ratio > 0 and dropout_indx == i:
sepconvhead.append(nn.Dropout(dropout_ratio))
sepconvhead.append(
nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0)
)
self.layers = nn.Sequential(*sepconvhead)
def forward(self, *inputs):
x = inputs[0]
return self.layers(x)
class SeparableConv2d(nn.Module):
def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1,
activation=None, use_bias=False, norm_layer=None):
super(SeparableConv2d, self).__init__()
_activation = ops.select_activation_function(activation)
self.body = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride,
padding=dw_padding, bias=use_bias, groups=in_channels),
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias),
norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
_activation()
)
def forward(self, x):
return self.body(x)
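
A shape check for SepConvHead: each SeparableConv2d is a depthwise 3x3 followed by a pointwise 1x1, so spatial size is preserved and only channel counts change (import path assumed).

import torch

from XMem.inference.interact.fbrs.model.modeling.basic_blocks import SepConvHead

head = SepConvHead(num_outputs=1, in_channels=128, mid_channels=64, num_layers=2)
out = head(torch.randn(2, 128, 32, 32))
print(out.shape)  # torch.Size([2, 1, 32, 32])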
================================================
FILE: XMem/inference/interact/fbrs/model/modeling/deeplab_v3.py
================================================
from contextlib import ExitStack
import torch
from torch import nn
import torch.nn.functional as F
from .basic_blocks import SeparableConv2d
from .resnet import ResNetBackbone
from ...model import ops
class DeepLabV3Plus(nn.Module):
def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
backbone_norm_layer=None,
ch=256,
project_dropout=0.5,
inference_mode=False,
**kwargs):
super(DeepLabV3Plus, self).__init__()
if backbone_norm_layer is None:
backbone_norm_layer = norm_layer
self.backbone_name = backbone
self.norm_layer = norm_layer
self.backbone_norm_layer = backbone_norm_layer
self.inference_mode = False
self.ch = ch
self.aspp_in_channels = 2048
self.skip_project_in_channels = 256 # layer 1 out_channels
self._kwargs = kwargs
if backbone == 'resnet34':
self.aspp_in_channels = 512
self.skip_project_in_channels = 64
self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False,
norm_layer=self.backbone_norm_layer, **kwargs)
self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch,
norm_layer=self.norm_layer)
self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer)
self.aspp = _ASPP(in_channels=self.aspp_in_channels,
atrous_rates=[12, 24, 36],
out_channels=ch,
project_dropout=project_dropout,
norm_layer=self.norm_layer)
if inference_mode:
self.set_prediction_mode()
def load_pretrained_weights(self):
pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True,
norm_layer=self.backbone_norm_layer, **self._kwargs)
backbone_state_dict = self.backbone.state_dict()
pretrained_state_dict = pretrained.state_dict()
backbone_state_dict.update(pretrained_state_dict)
self.backbone.load_state_dict(backbone_state_dict)
if self.inference_mode:
for param in self.backbone.parameters():
param.requires_grad = False
def set_prediction_mode(self):
self.inference_mode = True
self.eval()
def forward(self, x):
with ExitStack() as stack:
if self.inference_mode:
stack.enter_context(torch.no_grad())
c1, _, c3, c4 = self.backbone(x)
c1 = self.skip_project(c1)
x = self.aspp(c4)
x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x, c1), dim=1)
x = self.head(x)
return x,
class _SkipProject(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
super(_SkipProject, self).__init__()
_activation = ops.select_activation_function("relu")
self.skip_project = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
norm_layer(out_channels),
_activation()
)
def forward(self, x):
return self.skip_project(x)
class _DeepLabHead(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d):
super(_DeepLabHead, self).__init__()
self.block = nn.Sequential(
SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3,
dw_padding=1, activation='relu', norm_layer=norm_layer),
SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3,
dw_padding=1, activation='relu', norm_layer=norm_layer),
nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
)
def forward(self, x):
return self.block(x)
class _ASPP(nn.Module):
def __init__(self, in_channels, atrous_rates, out_channels=256,
project_dropout=0.5, norm_layer=nn.BatchNorm2d):
super(_ASPP, self).__init__()
b0 = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False),
norm_layer(out_channels),
nn.ReLU()
)
rate1, rate2, rate3 = tuple(atrous_rates)
b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])
project = [
nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
kernel_size=1, bias=False),
norm_layer(out_channels),
nn.ReLU()
]
if project_dropout > 0:
project.append(nn.Dropout(project_dropout))
self.project = nn.Sequential(*project)
def forward(self, x):
x = torch.cat([block(x) for block in self.concurent], dim=1)
return self.project(x)
class _AsppPooling(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer):
super(_AsppPooling, self).__init__()
self.gap = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
kernel_size=1, bias=False),
norm_layer(out_channels),
nn.ReLU()
)
def forward(self, x):
pool = self.gap(x)
return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True)
def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
block = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
kernel_size=3, padding=atrous_rate,
dilation=atrous_rate, bias=False),
norm_layer(out_channels),
nn.ReLU()
)
return block
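
A rough shape sketch, assuming the import path below: the dilated backbone runs at stride 8, the ASPP output is upsampled to the stride-4 skip features, and the head returns a single stride-4 feature map. eval() is needed here because _AsppPooling's 1x1 global-pooling output would otherwise hit train-mode BatchNorm with batch size 1.

import torch

from XMem.inference.interact.fbrs.model.modeling.deeplab_v3 import DeepLabV3Plus

net = DeepLabV3Plus(backbone='resnet34', ch=128, project_dropout=0.2)
net.eval()
with torch.no_grad():
    feats, = net(torch.randn(1, 3, 256, 256))  # forward returns a 1-tuple
print(feats.shape)  # torch.Size([1, 128, 64, 64])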
================================================
FILE: XMem/inference/interact/fbrs/model/modeling/hrnet_ocr.py
================================================
import os
import numpy as np
import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
from .ocr import SpatialOCR_Module, SpatialGather_Module
from .resnetv1b import BasicBlockV1b, BottleneckV1b
relu_inplace = True
class HighResolutionModule(nn.Module):
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
num_channels, fuse_method,multi_scale_output=True,
norm_layer=nn.BatchNorm2d, align_corners=True):
super(HighResolutionModule, self).__init__()
self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
self.num_inchannels = num_inchannels
self.fuse_method = fuse_method
self.num_branches = num_branches
self.norm_layer = norm_layer
self.align_corners = align_corners
self.multi_scale_output = multi_scale_output
self.branches = self._make_branches(
num_branches, blocks, num_blocks, num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(inplace=relu_inplace)
def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
if num_branches != len(num_blocks):
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
num_branches, len(num_blocks))
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
num_branches, len(num_channels))
raise ValueError(error_msg)
if num_branches != len(num_inchannels):
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
num_branches, len(num_inchannels))
raise ValueError(error_msg)
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
stride=1):
downsample = None
if stride != 1 or \
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.num_inchannels[branch_index],
num_channels[branch_index] * block.expansion,
kernel_size=1, stride=stride, bias=False),
self.norm_layer(num_channels[branch_index] * block.expansion),
)
layers = []
layers.append(block(self.num_inchannels[branch_index],
num_channels[branch_index], stride,
downsample=downsample, norm_layer=self.norm_layer))
self.num_inchannels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
layers.append(block(self.num_inchannels[branch_index],
num_channels[branch_index],
norm_layer=self.norm_layer))
return nn.Sequential(*layers)
def _make_branches(self, num_branches, block, num_blocks, num_channels):
branches = []
for i in range(num_branches):
branches.append(
self._make_one_branch(i, block, num_blocks, num_channels))
return nn.ModuleList(branches)
def _make_fuse_layers(self):
if self.num_branches == 1:
return None
num_branches = self.num_branches
num_inchannels = self.num_inchannels
fuse_layers = []
for i in range(num_branches if self.multi_scale_output else 1):
fuse_layer = []
for j in range(num_branches):
if j > i:
fuse_layer.append(nn.Sequential(
nn.Conv2d(in_channels=num_inchannels[j],
out_channels=num_inchannels[i],
kernel_size=1,
bias=False),
self.norm_layer(num_inchannels[i])))
elif j == i:
fuse_layer.append(None)
else:
conv3x3s = []
for k in range(i - j):
if k == i - j - 1:
num_outchannels_conv3x3 = num_inchannels[i]
conv3x3s.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_outchannels_conv3x3,
kernel_size=3, stride=2, padding=1, bias=False),
self.norm_layer(num_outchannels_conv3x3)))
else:
num_outchannels_conv3x3 = num_inchannels[j]
conv3x3s.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_outchannels_conv3x3,
kernel_size=3, stride=2, padding=1, bias=False),
self.norm_layer(num_outchannels_conv3x3),
nn.ReLU(inplace=relu_inplace)))
fuse_layer.append(nn.Sequential(*conv3x3s))
fuse_layers.append(nn.ModuleList(fuse_layer))
return nn.ModuleList(fuse_layers)
def get_num_inchannels(self):
return self.num_inchannels
def forward(self, x):
if self.num_branches == 1:
return [self.branches[0](x[0])]
for i in range(self.num_branches):
x[i] = self.branches[i](x[i])
x_fuse = []
for i in range(len(self.fuse_layers)):
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
for j in range(1, self.num_branches):
if i == j:
y = y + x[j]
elif j > i:
width_output = x[i].shape[-1]
height_output = x[i].shape[-2]
y = y + F.interpolate(
self.fuse_layers[i][j](x[j]),
size=[height_output, width_output],
mode='bilinear', align_corners=self.align_corners)
else:
y = y + self.fuse_layers[i][j](x[j])
x_fuse.append(self.relu(y))
return x_fuse
class HighResolutionNet(nn.Module):
def __init__(self, width, num_classes, ocr_width=256, small=False,
norm_layer=nn.BatchNorm2d, align_corners=True):
super(HighResolutionNet, self).__init__()
self.norm_layer = norm_layer
self.width = width
self.ocr_width = ocr_width
self.align_corners = align_corners
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = norm_layer(64)
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn2 = norm_layer(64)
self.relu = nn.ReLU(inplace=relu_inplace)
num_blocks = 2 if small else 4
stage1_num_channels = 64
self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
self.stage2_num_branches = 2
num_channels = [width, 2 * width]
num_inchannels = [
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
self.transition1 = self._make_transition_layer(
[stage1_out_channel], num_inchannels)
self.stage2, pre_stage_channels = self._make_stage(
BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
num_blocks=2 * [num_blocks], num_channels=num_channels)
self.stage3_num_branches = 3
num_channels = [width, 2 * width, 4 * width]
num_inchannels = [
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
self.transition2 = self._make_transition_layer(
pre_stage_channels, num_inchannels)
self.stage3, pre_stage_channels = self._make_stage(
BasicBlockV1b, num_inchannels=num_inchannels,
num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
num_blocks=3 * [num_blocks], num_channels=num_channels)
self.stage4_num_branches = 4
num_channels = [width, 2 * width, 4 * width, 8 * width]
num_inchannels = [
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
self.transition3 = self._make_transition_layer(
pre_stage_channels, num_inchannels)
self.stage4, pre_stage_channels = self._make_stage(
BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
num_branches=self.stage4_num_branches,
num_blocks=4 * [num_blocks], num_channels=num_channels)
last_inp_channels = np.int32(np.sum(pre_stage_channels))
ocr_mid_channels = 2 * ocr_width
ocr_key_channels = ocr_width
self.conv3x3_ocr = nn.Sequential(
nn.Conv2d(last_inp_channels, ocr_mid_channels,
kernel_size=3, stride=1, padding=1),
norm_layer(ocr_mid_channels),
nn.ReLU(inplace=relu_inplace),
)
self.ocr_gather_head = SpatialGather_Module(num_classes)
self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
key_channels=ocr_key_channels,
out_channels=ocr_mid_channels,
scale=1,
dropout=0.05,
norm_layer=norm_layer,
align_corners=align_corners)
self.cls_head = nn.Conv2d(
ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
self.aux_head = nn.Sequential(
nn.Conv2d(last_inp_channels, last_inp_channels,
kernel_size=1, stride=1, padding=0),
norm_layer(last_inp_channels),
nn.ReLU(inplace=relu_inplace),
nn.Conv2d(last_inp_channels, num_classes,
kernel_size=1, stride=1, padding=0, bias=True)
)
def _make_transition_layer(
self, num_channels_pre_layer, num_channels_cur_layer):
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
transition_layers = []
for i in range(num_branches_cur):
if i < num_branches_pre:
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
transition_layers.append(nn.Sequential(
nn.Conv2d(num_channels_pre_layer[i],
num_channels_cur_layer[i],
kernel_size=3,
stride=1,
padding=1,
bias=False),
self.norm_layer(num_channels_cur_layer[i]),
nn.ReLU(inplace=relu_inplace)))
else:
transition_layers.append(None)
else:
conv3x3s = []
for j in range(i + 1 - num_branches_pre):
inchannels = num_channels_pre_layer[-1]
outchannels = num_channels_cur_layer[i] \
if j == i - num_branches_pre else inchannels
conv3x3s.append(nn.Sequential(
nn.Conv2d(inchannels, outchannels,
kernel_size=3, stride=2, padding=1, bias=False),
self.norm_layer(outchannels),
nn.ReLU(inplace=relu_inplace)))
transition_layers.append(nn.Sequential(*conv3x3s))
return nn.ModuleList(transition_layers)
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
self.norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(inplanes, planes, stride,
downsample=downsample, norm_layer=self.norm_layer))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
return nn.Sequential(*layers)
def _make_stage(self, block, num_inchannels,
num_modules, num_branches, num_blocks, num_channels,
fuse_method='SUM',
multi_scale_output=True):
modules = []
for i in range(num_modules):
# multi_scale_output only applies to the last module in the stage
if not multi_scale_output and i == num_modules - 1:
reset_multi_scale_output = False
else:
reset_multi_scale_output = True
modules.append(
HighResolutionModule(num_branches,
block,
num_blocks,
num_inchannels,
num_channels,
fuse_method,
reset_multi_scale_output,
norm_layer=self.norm_layer,
align_corners=self.align_corners)
)
num_inchannels = modules[-1].get_num_inchannels()
return nn.Sequential(*modules), num_inchannels
def forward(self, x):
feats = self.compute_hrnet_feats(x)
out_aux = self.aux_head(feats)
feats = self.conv3x3_ocr(feats)
context = self.ocr_gather_head(feats, out_aux)
feats = self.ocr_distri_head(feats, context)
out = self.cls_head(feats)
return [out, out_aux]
def compute_hrnet_feats(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.layer1(x)
x_list = []
for i in range(self.stage2_num_branches):
if self.transition1[i] is not None:
x_list.append(self.transition1[i](x))
else:
x_list.append(x)
y_list = self.stage2(x_list)
x_list = []
for i in range(self.stage3_num_branches):
if self.transition2[i] is not None:
if i < self.stage2_num_branches:
x_list.append(self.transition2[i](y_list[i]))
else:
x_list.append(self.transition2[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage3(x_list)
x_list = []
for i in range(self.stage4_num_branches):
if self.transition3[i] is not None:
if i < self.stage3_num_branches:
x_list.append(self.transition3[i](y_list[i]))
else:
x_list.append(self.transition3[i](y_list[-1]))
else:
x_list.append(y_list[i])
x = self.stage4(x_list)
# Upsampling
x0_h, x0_w = x[0].size(2), x[0].size(3)
x1 = F.interpolate(x[1], size=(x0_h, x0_w),
mode='bilinear', align_corners=self.align_corners)
x2 = F.interpolate(x[2], size=(x0_h, x0_w),
mode='bilinear', align_corners=self.align_corners)
x3 = F.interpolate(x[3], size=(x0_h, x0_w),
mode='bilinear', align_corners=self.align_corners)
return torch.cat([x[0], x1, x2, x3], 1)
def load_pretrained_weights(self, pretrained_path=''):
model_dict = self.state_dict()
if not os.path.exists(pretrained_path):
print(f'\nFile "{pretrained_path}" does not exist.')
print('You need to specify the correct path to the pre-trained weights.\n'
'You can download the weights for HRNet from the repository:\n'
'https://github.com/HRNet/HRNet-Image-Classification')
exit(1)
pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'})
pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
pretrained_dict.items()}
print('model_dict-pretrained_dict:', sorted(list(set(model_dict) - set(pretrained_dict))))
print('pretrained_dict-model_dict:', sorted(list(set(pretrained_dict) - set(model_dict))))
pretrained_dict = {k: v for k, v in pretrained_dict.items()
if k in model_dict.keys()}
model_dict.update(pretrained_dict)
self.load_state_dict(model_dict)
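
A small forward sketch for the HRNet+OCR trunk above (the width, input size, and import path are assumptions): all branches are fused at the highest-resolution stream, so both the main and auxiliary outputs come out at 1/4 of the input resolution.

import torch

from XMem.inference.interact.fbrs.model.modeling.hrnet_ocr import HighResolutionNet

net = HighResolutionNet(width=18, num_classes=1, ocr_width=64, small=True)
net.eval()
with torch.no_grad():
    out, out_aux = net(torch.randn(1, 3, 128, 128))
print(out.shape, out_aux.shape)  # both torch.Size([1, 1, 32, 32])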
================================================
FILE: XMem/inference/interact/fbrs/model/modeling/ocr.py
================================================
import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
class SpatialGather_Module(nn.Module):
"""
Aggregate the context features according to the initial
predicted probability distribution.
Employ the soft-weighted method to aggregate the context.
"""
def __init__(self, cls_num=0, scale=1):
super(SpatialGather_Module, self).__init__()
self.cls_num = cls_num
self.scale = scale
def forward(self, feats, probs):
batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
probs = probs.view(batch_size, c, -1)
feats = feats.view(batch_size, feats.size(1), -1)
feats = feats.permute(0, 2, 1) # batch x hw x c
probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
ocr_context = torch.matmul(probs, feats) \
.permute(0, 2, 1).unsqueeze(3)  # batch x c x k x 1
return ocr_context
class SpatialOCR_Module(nn.Module):
"""
Implementation of the OCR module:
We aggregate the global object representation to update the representation for each pixel.
"""
def __init__(self,
in_channels,
key_channels,
out_channels,
scale=1,
dropout=0.1,
norm_layer=nn.BatchNorm2d,
align_corners=True):
super(SpatialOCR_Module, self).__init__()
self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
norm_layer, align_corners)
_in_channels = 2 * in_channels
self.conv_bn_dropout = nn.Sequential(
nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
nn.Dropout2d(dropout)
)
def forward(self, feats, proxy_feats):
context = self.object_context_block(feats, proxy_feats)
output = self.conv_bn_dropout(torch.cat([context, feats], 1))
return output
class ObjectAttentionBlock2D(nn.Module):
'''
The basic implementation for object context block
Input:
N X C X H X W
Parameters:
in_channels : the dimension of the input feature map
key_channels : the dimension after the key/query transform
scale : choose the scale to downsample the input feature maps (save memory cost)
norm_layer : normalization layer used inside the key/query/value transforms
Return:
N X C X H X W
'''
def __init__(self,
in_channels,
key_channels,
scale=1,
norm_layer=nn.BatchNorm2d,
align_corners=True):
super(ObjectAttentionBlock2D, self).__init__()
self.scale = scale
self.in_channels = in_channels
self.key_channels = key_channels
self.align_corners = align_corners
self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
self.f_pixel = nn.Sequential(
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0, bias=False),
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0, bias=False),
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
)
self.f_object = nn.Sequential(
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0, bias=False),
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0, bias=False),
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
)
self.f_down = nn.Sequential(
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0, bias=False),
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
)
self.f_up = nn.Sequential(
nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0, bias=False),
nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
)
def forward(self, x, proxy):
batch_size, h, w = x.size(0), x.size(2), x.size(3)
if self.scale > 1:
x = self.pool(x)
query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
query = query.permute(0, 2, 1)
key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
value = value.permute(0, 2, 1)
sim_map = torch.matmul(query, key)
sim_map = (self.key_channels ** -.5) * sim_map
sim_map = F.softmax(sim_map, dim=-1)
# add bg context ...
context = torch.matmul(sim_map, value)
context = context.permute(0, 2, 1).contiguous()
context = context.view(batch_size, self.key_channels, *x.size()[2:])
context = self.f_up(context)
if self.scale > 1:
context = F.interpolate(input=context, size=(h, w),
mode='bilinear', align_corners=self.align_corners)
return context
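
A shape sketch for SpatialGather_Module: softmax-normalized class maps pool the pixel features into one descriptor per class, yielding a (batch, channels, classes, 1) context tensor (import path assumed).

import torch

from XMem.inference.interact.fbrs.model.modeling.ocr import SpatialGather_Module

gather = SpatialGather_Module(cls_num=1)
feats = torch.randn(2, 64, 32, 32)   # pixel features
probs = torch.randn(2, 1, 32, 32)    # coarse per-class logits (k = 1 here)
context = gather(feats, probs)
print(context.shape)  # torch.Size([2, 64, 1, 1])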
================================================
FILE: XMem/inference/interact/fbrs/model/modeling/resnet.py
================================================
import torch
from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s
class ResNetBackbone(torch.nn.Module):
def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs):
super(ResNetBackbone, self).__init__()
if backbone == 'resnet34':
pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs)
elif backbone == 'resnet50':
pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
elif backbone == 'resnet101':
pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
elif backbone == 'resnet152':
pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
else:
raise RuntimeError(f'unknown backbone: {backbone}')
self.conv1 = pretrained.conv1
self.bn1 = pretrained.bn1
self.relu = pretrained.relu
self.maxpool = pretrained.maxpool
self.layer1 = pretrained.layer1
self.layer2 = pretrained.layer2
self.layer3 = pretrained.layer3
self.layer4 = pretrained.layer4
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
c1 = self.layer1(x)
c2 = self.layer2(c1)
c3 = self.layer3(c2)
c4 = self.layer4(c3)
return c1, c2, c3, c4
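
A quick check of the feature pyramid this wrapper exposes; with dilated=True, layers 3 and 4 keep stride-8 resolution instead of downsampling further (pretrained_base=False avoids any download; import path assumed).

import torch

from XMem.inference.interact.fbrs.model.modeling.resnet import ResNetBackbone

backbone = ResNetBackbone(backbone='resnet34', pretrained_base=False, dilated=True)
backbone.eval()
with torch.no_grad():
    c1, c2, c3, c4 = backbone(torch.randn(1, 3, 224, 224))
print(c1.shape, c2.shape, c3.shape, c4.shape)
# [1, 64, 56, 56] [1, 128, 28, 28] [1, 256, 28, 28] [1, 512, 28, 28]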
================================================
FILE: XMem/inference/interact/fbrs/model/modeling/resnetv1b.py
================================================
import torch
import torch.nn as nn
GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
class BasicBlockV1b(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
previous_dilation=1, norm_layer=nn.BatchNorm2d):
super(BasicBlockV1b, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, bias=False)
self.bn1 = norm_layer(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
padding=previous_dilation, dilation=previous_dilation, bias=False)
self.bn2 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out = out + residual
out = self.relu(out)
return out
class BottleneckV1b(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
previous_dilation=1, norm_layer=nn.BatchNorm2d):
super(BottleneckV1b, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = norm_layer(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, bias=False)
self.bn2 = norm_layer(planes)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out = out + residual
out = self.relu(out)
return out
class ResNetV1b(nn.Module):
""" Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5.
Parameters
----------
block : Block
Class for the residual block. Options are BasicBlockV1, BottleneckV1.
layers : list of int
Numbers of layers in each block
classes : int, default 1000
Number of classification classes.
dilated : bool, default False
Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
typically used in Semantic Segmentation.
norm_layer : object
Normalization layer used (default: :class:`nn.BatchNorm2d`)
deep_stem : bool, default False
Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
avg_down : bool, default False
Whether to use average pooling for projection skip connection between stages/downsample.
final_drop : float, default 0.0
Dropout ratio before the final classification layer.
Reference:
- He, Kaiming, et al. "Deep residual learning for image recognition."
Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
- Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
"""
def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
self.inplanes = stem_width*2 if deep_stem else 64
super(ResNetV1b, self).__init__()
if not deep_stem:
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
else:
self.conv1 = nn.Sequential(
nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
norm_layer(stem_width),
nn.ReLU(True),
nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
norm_layer(stem_width),
nn.ReLU(True),
nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(True)
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
norm_layer=norm_layer)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
norm_layer=norm_layer)
if dilated:
self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
avg_down=avg_down, norm_layer=norm_layer)
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
avg_down=avg_down, norm_layer=norm_layer)
else:
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
avg_down=avg_down, norm_layer=norm_layer)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
avg_down=avg_down, norm_layer=norm_layer)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.drop = None
if final_drop > 0.0:
self.drop = nn.Dropout(final_drop)
self.fc = nn.Linear(512 * block.expansion, classes)
def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
avg_down=False, norm_layer=nn.BatchNorm2d):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = []
if avg_down:
if dilation == 1:
downsample.append(
nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
)
else:
downsample.append(
nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
)
downsample.extend([
nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
kernel_size=1, stride=1, bias=False),
norm_layer(planes * block.expansion)
])
downsample = nn.Sequential(*downsample)
else:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
norm_layer(planes * block.expansion)
)
layers = []
if dilation in (1, 2):
layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
previous_dilation=dilation, norm_layer=norm_layer))
elif dilation == 4:
layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
previous_dilation=dilation, norm_layer=norm_layer))
else:
raise RuntimeError("=> unknown dilation size: {}".format(dilation))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, dilation=dilation,
previous_dilation=dilation, norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
if self.drop is not None:
x = self.drop(x)
x = self.fc(x)
return x
def _safe_state_dict_filtering(orig_dict, model_dict_keys):
filtered_orig_dict = {}
for k, v in orig_dict.items():
if k in model_dict_keys:
filtered_orig_dict[k] = v
else:
print(f"[ERROR] Failed to load <{k}> in backbone")
return filtered_orig_dict
def resnet34_v1b(pretrained=False, **kwargs):
model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
if pretrained:
model_dict = model.state_dict()
filtered_orig_dict = _safe_state_dict_filtering(
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
model_dict.keys()
)
model_dict.update(filtered_orig_dict)
model.load_state_dict(model_dict)
return model
def resnet50_v1s(pretrained=False, **kwargs):
model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
if pretrained:
model_dict = model.state_dict()
filtered_orig_dict = _safe_state_dict_filtering(
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
model_dict.keys()
)
model_dict.update(filtered_orig_dict)
model.load_state_dict(model_dict)
return model
def resnet101_v1s(pretrained=False, **kwargs):
model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
if pretrained:
model_dict = model.state_dict()
filtered_orig_dict = _safe_state_dict_filtering(
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
model_dict.keys()
)
model_dict.update(filtered_orig_dict)
model.load_state_dict(model_dict)
return model
def resnet152_v1s(pretrained=False, **kwargs):
model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
if pretrained:
model_dict = model.state_dict()
filtered_orig_dict = _safe_state_dict_filtering(
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
model_dict.keys()
)
model_dict.update(filtered_orig_dict)
model.load_state_dict(model_dict)
return model
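
The _v1s constructors enable the deep stem (three 3x3 convolutions instead of a single 7x7), which is the main difference from the plain _v1b variant; a quick way to see it, with pretrained=False so nothing is downloaded:

from XMem.inference.interact.fbrs.model.modeling.resnetv1b import (
    resnet34_v1b, resnet50_v1s)

m34 = resnet34_v1b(pretrained=False)
m50 = resnet50_v1s(pretrained=False)
print(type(m34.conv1).__name__)  # Conv2d: plain 7x7 stem
print(type(m50.conv1).__name__)  # Sequential: deep 3x3 stem (deep_stem=True)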
================================================
FILE: XMem/inference/interact/fbrs/model/ops.py
================================================
import torch
from torch import nn as nn
import numpy as np
from . import initializer as initializer
from ..utils.cython import get_dist_maps
def select_activation_function(activation):
if isinstance(activation, str):
if activation.lower() == 'relu':
return nn.ReLU
elif activation.lower() == 'softplus':
return nn.Softplus
else:
raise ValueError(f"Unknown activation type {activation}")
elif isinstance(activation, nn.Module):
return activation
else:
raise ValueError(f"Unknown activation type {activation}")
class BilinearConvTranspose2d(nn.ConvTranspose2d):
def __init__(self, in_channels, out_channels, scale, groups=1):
kernel_size = 2 * scale - scale % 2
self.scale = scale
super().__init__(
in_channels, out_channels,
kernel_size=kernel_size,
stride=scale,
padding=1,
groups=groups,
bias=False)
self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups))
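# Kernel-size note: kernel_size = 2*scale - scale % 2 with stride=scale is the
# classic transposed-convolution setup for bilinear upsampling; the Bilinear
# initializer fills the weights accordingly. With the hardcoded padding=1 this
# yields an exact scale-fold size increase for scale=2 (kernel 4, stride 2),
# e.g. a hypothetical
#     up = BilinearConvTranspose2d(in_channels=64, out_channels=64, scale=2)
# maps an (N, 64, H, W) tensor to (N, 64, 2H, 2W).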
class DistMaps(nn.Module):
def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False):
super(DistMaps, self).__init__()
self.spatial_scale = spatial_scale
self.norm_radius = norm_radius
self.cpu_mode = cpu_mode
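    # get_coord_features encodes the user clicks as dense distance maps (one
    # channel for positive and one for negative clicks in the original fbrs
    # design), with distances normalized by spatial_scale * norm_radius. In the
    # cpu_mode branch this is delegated to the Cython helper get_dist_maps and
    # the per-sample maps are stacked back into a batch tensor.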
def get_coord_features(self, points, batchsize, rows, cols):
if self.cpu_mode:
coords = []
for i in range(batchsize):
norm_delimeter = self.spatial_scale * self.norm_radius
coords.append(get_dist_maps(points[i].cpu().float().numpy(), rows, cols,
norm_delimeter))
            coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
SYMBOL INDEX (1459 symbols across 185 files)
FILE: XMem/dataset/reseed.py
function reseed (line 4) | def reseed(seed):
FILE: XMem/dataset/static_dataset.py
class StaticTransformDataset (line 16) | class StaticTransformDataset(Dataset):
method __init__ (line 24) | def __init__(self, parameters, num_frames=3, max_num_obj=1):
method _get_sample (line 90) | def _get_sample(self, idx):
method __getitem__ (line 128) | def __getitem__(self, idx):
method __len__ (line 178) | def __len__(self):
FILE: XMem/dataset/tps.py
function pick_random_points (line 8) | def pick_random_points(h, w, n_samples):
function warp_dual_cv (line 14) | def warp_dual_cv(img, mask, c_src, c_dst):
function random_tps_warp (line 22) | def random_tps_warp(img, mask, scale, n_ctrl_pts=12):
FILE: XMem/dataset/util.py
function all_to_onehot (line 4) | def all_to_onehot(masks, labels):
FILE: XMem/dataset/vos_dataset.py
class VOSDataset (line 15) | class VOSDataset(Dataset):
method __init__ (line 25) | def __init__(self, im_root, gt_root, max_jump, is_bl, subset=None, num...
method __getitem__ (line 97) | def __getitem__(self, idx):
method __len__ (line 215) | def __len__(self):
FILE: XMem/eval_batch.py
function run_eval (line 16) | def run_eval(meta_expression, temp_xmem_anno, final_xmem_anno, img_dir, ...
function main (line 23) | def main():
FILE: XMem/generate_xmem_data_single.py
function generate (line 20) | def generate(obj, temp_xmem_anno, final_xmem_anno):
function main (line 47) | def main():
FILE: XMem/inference/data/mask_mapper.py
class MaskMapper (line 7) | class MaskMapper:
method __init__ (line 19) | def __init__(self):
method convert_mask (line 26) | def convert_mask(self, mask, exhaustive=False):
method remap_index_mask (line 57) | def remap_index_mask(self, mask):
FILE: XMem/inference/data/test_datasets.py
class LongTestDataset (line 9) | class LongTestDataset:
method __init__ (line 11) | def __init__(self, meta_expression, data_root, size=-1, img_dir = '', ...
method get_datasets (line 34) | def get_datasets(self):
method __len__ (line 46) | def __len__(self):
class DAVISTestDataset (line 50) | class DAVISTestDataset:
method __init__ (line 51) | def __init__(self, data_root, imset='2017/val.txt', size=-1):
method get_datasets (line 69) | def get_datasets(self):
method __len__ (line 78) | def __len__(self):
class YouTubeVOSTestDataset (line 82) | class YouTubeVOSTestDataset:
method __init__ (line 83) | def __init__(self, data_root, split, size=480):
method get_datasets (line 104) | def get_datasets(self):
method __len__ (line 114) | def __len__(self):
FILE: XMem/inference/data/video_reader.py
class VideoReader (line 14) | class VideoReader(Dataset):
method __init__ (line 18) | def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=Non...
method __getitem__ (line 59) | def __getitem__(self, idx):
method resize_mask (line 92) | def resize_mask(self, mask):
method get_palette (line 99) | def get_palette(self):
method __len__ (line 102) | def __len__(self):
FILE: XMem/inference/inference_core.py
class InferenceCore (line 8) | class InferenceCore:
method __init__ (line 9) | def __init__(self, network:XMem, config):
method clear_memory (line 22) | def clear_memory(self):
method update_config (line 29) | def update_config(self, config):
method set_all_labels (line 38) | def set_all_labels(self, all_labels):
method step (line 42) | def step(self, image, mask=None, valid_labels=None, end=False):
FILE: XMem/inference/interact/fbrs/controller.py
class InteractiveController (line 11) | class InteractiveController:
method __init__ (line 12) | def __init__(self, net, device, predictor_params, prob_thresh=0.5):
method set_image (line 27) | def set_image(self, image):
method add_click (line 33) | def add_click(self, x, y, is_positive):
method undo_click (line 52) | def undo_click(self):
method partially_finish_object (line 61) | def partially_finish_object(self):
method finish_object (line 72) | def finish_object(self):
method reset_last_object (line 82) | def reset_last_object(self):
method reset_predictor (line 88) | def reset_predictor(self, predictor_params=None):
method current_object_prob (line 97) | def current_object_prob(self):
method is_incomplete_mask (line 105) | def is_incomplete_mask(self):
method result_mask (line 109) | def result_mask(self):
FILE: XMem/inference/interact/fbrs/inference/clicker.py
class Clicker (line 10) | class Clicker(object):
method __init__ (line 11) | def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1):
method make_next_click (line 24) | def make_next_click(self, pred_mask):
method get_clicks (line 29) | def get_clicks(self, clicks_limit=None):
method _get_click (line 32) | def _get_click(self, pred_mask, padding=True):
method add_click (line 61) | def add_click(self, click):
method _remove_last_click (line 73) | def _remove_last_click(self):
method reset_clicks (line 85) | def reset_clicks(self):
method get_state (line 94) | def get_state(self):
method set_state (line 97) | def set_state(self, state):
method __len__ (line 102) | def __len__(self):
FILE: XMem/inference/interact/fbrs/inference/evaluation.py
function evaluate_dataset (line 16) | def evaluate_dataset(dataset, predictor, oracle_eval=False, **kwargs):
function evaluate_sample (line 36) | def evaluate_sample(image_nd, instances_mask, predictor, max_iou_thr,
FILE: XMem/inference/interact/fbrs/inference/predictors/__init__.py
function get_predictor (line 8) | def get_predictor(net, brs_mode, device,
FILE: XMem/inference/interact/fbrs/inference/predictors/base.py
class BasePredictor (line 7) | class BasePredictor(object):
method __init__ (line 8) | def __init__(self, net, device,
method set_input_image (line 28) | def set_input_image(self, image_nd):
method get_prediction (line 35) | def get_prediction(self, clicker):
method _get_prediction (line 56) | def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
method _get_transform_states (line 60) | def _get_transform_states(self):
method _set_transform_states (line 63) | def _set_transform_states(self, states):
method apply_transforms (line 68) | def apply_transforms(self, image_nd, clicks_lists):
method get_points_nd (line 76) | def get_points_nd(self, clicks_lists):
method get_states (line 96) | def get_states(self):
method set_states (line 99) | def set_states(self, states):
FILE: XMem/inference/interact/fbrs/inference/predictors/brs.py
class BRSBasePredictor (line 10) | class BRSBasePredictor(BasePredictor):
method __init__ (line 11) | def __init__(self, model, device, opt_functor, optimize_after_n_clicks...
method set_input_image (line 19) | def set_input_image(self, image_nd):
method _get_clicks_maps_nd (line 24) | def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
method get_states (line 46) | def get_states(self):
method set_states (line 49) | def set_states(self, states):
class FeatureBRSPredictor (line 54) | class FeatureBRSPredictor(BRSBasePredictor):
method __init__ (line 55) | def __init__(self, model, device, opt_functor, insertion_mode='after_d...
method _get_prediction (line 69) | def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
method _get_head_input (line 121) | def _get_head_input(self, image_nd, points):
class HRNetFeatureBRSPredictor (line 143) | class HRNetFeatureBRSPredictor(BRSBasePredictor):
method __init__ (line 144) | def __init__(self, model, device, opt_functor, insertion_mode='A', **k...
method _get_prediction (line 156) | def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
method _get_head_input (line 209) | def _get_head_input(self, image_nd, points):
class InputBRSPredictor (line 228) | class InputBRSPredictor(BRSBasePredictor):
method __init__ (line 229) | def __init__(self, model, device, opt_functor, optimize_target='rgb', ...
method _get_prediction (line 233) | def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
FILE: XMem/inference/interact/fbrs/inference/predictors/brs_functors.py
class BaseOptimizer (line 8) | class BaseOptimizer:
method __init__ (line 9) | def __init__(self, optimizer_params,
method init_click (line 33) | def init_click(self, get_prediction_logits, pos_mask, neg_mask, device...
method __call__ (line 41) | def __call__(self, x):
method unpack_opt_params (line 79) | def unpack_opt_params(self, opt_params):
class InputOptimizer (line 83) | class InputOptimizer(BaseOptimizer):
method unpack_opt_params (line 84) | def unpack_opt_params(self, opt_params):
class ScaleBiasOptimizer (line 94) | class ScaleBiasOptimizer(BaseOptimizer):
method __init__ (line 95) | def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwar...
method unpack_opt_params (line 100) | def unpack_opt_params(self, opt_params):
FILE: XMem/inference/interact/fbrs/inference/predictors/brs_losses.py
class BRSMaskLoss (line 6) | class BRSMaskLoss(torch.nn.Module):
method __init__ (line 7) | def __init__(self, eps=1e-5):
method forward (line 11) | def forward(self, result, pos_mask, neg_mask):
class OracleMaskLoss (line 29) | class OracleMaskLoss(torch.nn.Module):
method __init__ (line 30) | def __init__(self):
method set_gt_mask (line 37) | def set_gt_mask(self, gt_mask):
method forward (line 41) | def forward(self, result, pos_mask, neg_mask):
FILE: XMem/inference/interact/fbrs/inference/transforms/base.py
class BaseTransform (line 4) | class BaseTransform(object):
method __init__ (line 5) | def __init__(self):
method transform (line 8) | def transform(self, image_nd, clicks_lists):
method inv_transform (line 11) | def inv_transform(self, prob_map):
method reset (line 14) | def reset(self):
method get_state (line 17) | def get_state(self):
method set_state (line 20) | def set_state(self, state):
class SigmoidForPred (line 24) | class SigmoidForPred(BaseTransform):
method transform (line 25) | def transform(self, image_nd, clicks_lists):
method inv_transform (line 28) | def inv_transform(self, prob_map):
method reset (line 31) | def reset(self):
method get_state (line 34) | def get_state(self):
method set_state (line 37) | def set_state(self, state):
FILE: XMem/inference/interact/fbrs/inference/transforms/crops.py
class Crops (line 10) | class Crops(BaseTransform):
method __init__ (line 11) | def __init__(self, crop_size=(320, 480), min_overlap=0.2):
method transform (line 20) | def transform(self, image_nd, clicks_lists):
method inv_transform (line 51) | def inv_transform(self, prob_map):
method get_state (line 67) | def get_state(self):
method set_state (line 70) | def set_state(self, state):
method reset (line 73) | def reset(self):
function get_offsets (line 79) | def get_offsets(length, crop_size, min_overlap_ratio=0.2):
FILE: XMem/inference/interact/fbrs/inference/transforms/flip.py
class AddHorizontalFlip (line 7) | class AddHorizontalFlip(BaseTransform):
method transform (line 8) | def transform(self, image_nd, clicks_lists):
method inv_transform (line 23) | def inv_transform(self, prob_map):
method get_state (line 30) | def get_state(self):
method set_state (line 33) | def set_state(self, state):
method reset (line 36) | def reset(self):
FILE: XMem/inference/interact/fbrs/inference/transforms/limit_longest_side.py
class LimitLongestSide (line 4) | class LimitLongestSide(ZoomIn):
method __init__ (line 5) | def __init__(self, max_size=800):
method transform (line 8) | def transform(self, image_nd, clicks_lists):
FILE: XMem/inference/interact/fbrs/inference/transforms/zoom_in.py
class ZoomIn (line 8) | class ZoomIn(BaseTransform):
method __init__ (line 9) | def __init__(self,
method transform (line 29) | def transform(self, image_nd, clicks_lists):
method inv_transform (line 65) | def inv_transform(self, prob_map):
method check_possible_recalculation (line 85) | def check_possible_recalculation(self):
method get_state (line 98) | def get_state(self):
method set_state (line 102) | def set_state(self, state):
method reset (line 105) | def reset(self):
method _transform_clicks (line 112) | def _transform_clicks(self, clicks_list):
function get_object_roi (line 127) | def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size):
function get_roi_image_nd (line 142) | def get_roi_image_nd(image_nd, object_roi, target_size):
function check_object_roi (line 163) | def check_object_roi(object_roi, clicks_list):
FILE: XMem/inference/interact/fbrs/inference/utils.py
function get_time_metrics (line 11) | def get_time_metrics(all_ious, elapsed_time):
function load_is_model (line 21) | def load_is_model(checkpoint, device, backbone='auto', **kwargs):
function load_hrnet_is_model (line 40) | def load_hrnet_is_model(state_dict, device, backbone='auto', width=48, o...
function load_deeplab_is_model (line 67) | def load_deeplab_is_model(state_dict, device, backbone='auto', deeplab_c...
function get_iou (line 103) | def get_iou(gt_mask, pred_mask, ignore_label=-1):
function compute_noc_metric (line 113) | def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
function find_checkpoint (line 133) | def find_checkpoint(weights_folder, checkpoint_name):
function get_results_table (line 156) | def get_results_table(noc_list, over_max_list, brs_type, dataset_name, m...
FILE: XMem/inference/interact/fbrs/model/initializer.py
class Initializer (line 6) | class Initializer(object):
method __init__ (line 7) | def __init__(self, local_init=True, gamma=None):
method __call__ (line 11) | def __call__(self, m):
method _init_weight (line 31) | def _init_weight(self, data):
method _init_bias (line 34) | def _init_bias(self, data):
method _init_gamma (line 37) | def _init_gamma(self, data):
method _init_beta (line 43) | def _init_beta(self, data):
class Bilinear (line 47) | class Bilinear(Initializer):
method __init__ (line 48) | def __init__(self, scale, groups, in_channels, **kwargs):
method _init_weight (line 54) | def _init_weight(self, data):
method get_bilinear_kernel (line 67) | def get_bilinear_kernel(scale):
class XavierGluon (line 79) | class XavierGluon(Initializer):
method __init__ (line 80) | def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3,...
method _init_weight (line 87) | def _init_weight(self, arr):
FILE: XMem/inference/interact/fbrs/model/is_deeplab_model.py
function get_deeplab_model (line 9) | def get_deeplab_model(backbone='resnet50', deeplab_ch=256, aspp_dropout=...
class DistMapsModel (line 30) | class DistMapsModel(nn.Module):
method __init__ (line 31) | def __init__(self, feature_extractor, head, norm_layer=nn.BatchNorm2d,...
method forward (line 50) | def forward(self, image, points):
method load_weights (line 68) | def load_weights(self, path_to_weights):
method get_trainable_params (line 74) | def get_trainable_params(self):
FILE: XMem/inference/interact/fbrs/model/is_hrnet_model.py
function get_hrnet_model (line 8) | def get_hrnet_model(width=48, ocr_width=256, small=False, norm_radius=260,
class DistMapsHRNetModel (line 24) | class DistMapsHRNetModel(nn.Module):
method __init__ (line 25) | def __init__(self, feature_extractor, use_rgb_conv=True, with_aux_outp...
method forward (line 43) | def forward(self, image, points):
method load_weights (line 67) | def load_weights(self, path_to_weights):
method get_trainable_params (line 73) | def get_trainable_params(self):
FILE: XMem/inference/interact/fbrs/model/losses.py
class NormalizedFocalLossSigmoid (line 9) | class NormalizedFocalLossSigmoid(nn.Module):
method __init__ (line 10) | def __init__(self, axis=-1, alpha=0.25, gamma=2,
method forward (line 30) | def forward(self, pred, label, sample_weight=None):
method log_states (line 66) | def log_states(self, sw, name, global_step):
class FocalLoss (line 70) | class FocalLoss(nn.Module):
method __init__ (line 71) | def __init__(self, axis=-1, alpha=0.25, gamma=2,
method forward (line 88) | def forward(self, pred, label, sample_weight=None):
class SigmoidBinaryCrossEntropyLoss (line 113) | class SigmoidBinaryCrossEntropyLoss(nn.Module):
method __init__ (line 114) | def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, igno...
method forward (line 121) | def forward(self, pred, label):
FILE: XMem/inference/interact/fbrs/model/metrics.py
class TrainMetric (line 7) | class TrainMetric(object):
method __init__ (line 8) | def __init__(self, pred_outputs, gt_outputs):
method update (line 12) | def update(self, *args, **kwargs):
method get_epoch_value (line 15) | def get_epoch_value(self):
method reset_epoch_stats (line 18) | def reset_epoch_stats(self):
method log_states (line 21) | def log_states(self, sw, tag_prefix, global_step):
method name (line 25) | def name(self):
class AdaptiveIoU (line 29) | class AdaptiveIoU(TrainMetric):
method __init__ (line 30) | def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.9...
method update (line 44) | def update(self, pred, gt):
method get_epoch_value (line 67) | def get_epoch_value(self):
method reset_epoch_stats (line 73) | def reset_epoch_stats(self):
method log_states (line 77) | def log_states(self, sw, tag_prefix, global_step):
method iou_thresh (line 82) | def iou_thresh(self):
function _compute_iou (line 86) | def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
FILE: XMem/inference/interact/fbrs/model/modeling/basic_blocks.py
class ConvHead (line 6) | class ConvHead(nn.Module):
method __init__ (line 7) | def __init__(self, out_channels, in_channels=32, num_layers=1,
method forward (line 23) | def forward(self, *inputs):
class SepConvHead (line 27) | class SepConvHead(nn.Module):
method __init__ (line 28) | def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1,
method forward (line 51) | def forward(self, *inputs):
class SeparableConv2d (line 57) | class SeparableConv2d(nn.Module):
method __init__ (line 58) | def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, d...
method forward (line 70) | def forward(self, x):
FILE: XMem/inference/interact/fbrs/model/modeling/deeplab_v3.py
class DeepLabV3Plus (line 12) | class DeepLabV3Plus(nn.Module):
method __init__ (line 13) | def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
method load_pretrained_weights (line 51) | def load_pretrained_weights(self):
method set_prediction_mode (line 64) | def set_prediction_mode(self):
method forward (line 68) | def forward(self, x):
class _SkipProject (line 84) | class _SkipProject(nn.Module):
method __init__ (line 85) | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
method forward (line 95) | def forward(self, x):
class _DeepLabHead (line 99) | class _DeepLabHead(nn.Module):
method __init__ (line 100) | def __init__(self, out_channels, in_channels, mid_channels=256, norm_l...
method forward (line 111) | def forward(self, x):
class _ASPP (line 115) | class _ASPP(nn.Module):
method __init__ (line 116) | def __init__(self, in_channels, atrous_rates, out_channels=256,
method forward (line 144) | def forward(self, x):
class _AsppPooling (line 150) | class _AsppPooling(nn.Module):
method __init__ (line 151) | def __init__(self, in_channels, out_channels, norm_layer):
method forward (line 162) | def forward(self, x):
function _ASPPConv (line 167) | def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
FILE: XMem/inference/interact/fbrs/model/modeling/hrnet_ocr.py
class HighResolutionModule (line 13) | class HighResolutionModule(nn.Module):
method __init__ (line 14) | def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
method _check_branches (line 33) | def _check_branches(self, num_branches, num_blocks, num_inchannels, nu...
method _make_one_branch (line 49) | def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
method _make_branches (line 74) | def _make_branches(self, num_branches, block, num_blocks, num_channels):
method _make_fuse_layers (line 83) | def _make_fuse_layers(self):
method get_num_inchannels (line 125) | def get_num_inchannels(self):
method forward (line 128) | def forward(self, x):
class HighResolutionNet (line 155) | class HighResolutionNet(nn.Module):
method __init__ (line 156) | def __init__(self, width, num_classes, ocr_width=256, small=False,
method _make_transition_layer (line 239) | def _make_transition_layer(
method _make_layer (line 274) | def _make_layer(self, block, inplanes, planes, blocks, stride=1):
method _make_stage (line 292) | def _make_stage(self, block, num_inchannels,
method forward (line 318) | def forward(self, x):
method compute_hrnet_feats (line 329) | def compute_hrnet_feats(self, x):
method load_pretrained_weights (line 379) | def load_pretrained_weights(self, pretrained_path=''):
FILE: XMem/inference/interact/fbrs/model/modeling/ocr.py
class SpatialGather_Module (line 7) | class SpatialGather_Module(nn.Module):
method __init__ (line 14) | def __init__(self, cls_num=0, scale=1):
method forward (line 19) | def forward(self, feats, probs):
class SpatialOCR_Module (line 30) | class SpatialOCR_Module(nn.Module):
method __init__ (line 36) | def __init__(self,
method forward (line 55) | def forward(self, feats, proxy_feats):
class ObjectAttentionBlock2D (line 63) | class ObjectAttentionBlock2D(nn.Module):
method __init__ (line 77) | def __init__(self,
method forward (line 117) | def forward(self, x, proxy):
FILE: XMem/inference/interact/fbrs/model/modeling/resnet.py
class ResNetBackbone (line 5) | class ResNetBackbone(torch.nn.Module):
method __init__ (line 6) | def __init__(self, backbone='resnet50', pretrained_base=True, dilated=...
method forward (line 29) | def forward(self, x):
FILE: XMem/inference/interact/fbrs/model/modeling/resnetv1b.py
class BasicBlockV1b (line 6) | class BasicBlockV1b(nn.Module):
method __init__ (line 9) | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=...
method forward (line 23) | def forward(self, x):
class BottleneckV1b (line 42) | class BottleneckV1b(nn.Module):
method __init__ (line 45) | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=...
method forward (line 62) | def forward(self, x):
class ResNetV1b (line 85) | class ResNetV1b(nn.Module):
method __init__ (line 114) | def __init__(self, block, layers, classes=1000, dilated=True, deep_ste...
method _make_layer (line 153) | def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
method forward (line 197) | def forward(self, x):
function _safe_state_dict_filtering (line 217) | def _safe_state_dict_filtering(orig_dict, model_dict_keys):
function resnet34_v1b (line 227) | def resnet34_v1b(pretrained=False, **kwargs):
function resnet50_v1s (line 240) | def resnet50_v1s(pretrained=False, **kwargs):
function resnet101_v1s (line 253) | def resnet101_v1s(pretrained=False, **kwargs):
function resnet152_v1s (line 266) | def resnet152_v1s(pretrained=False, **kwargs):
FILE: XMem/inference/interact/fbrs/model/ops.py
function select_activation_function (line 9) | def select_activation_function(activation):
class BilinearConvTranspose2d (line 23) | class BilinearConvTranspose2d(nn.ConvTranspose2d):
method __init__ (line 24) | def __init__(self, in_channels, out_channels, scale, groups=1):
class DistMaps (line 39) | class DistMaps(nn.Module):
method __init__ (line 40) | def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False):
method get_coord_features (line 46) | def get_coord_features(self, points, batchsize, rows, cols):
method forward (line 82) | def forward(self, x, coords):
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py
function _load_C_extensions (line 27) | def _load_C_extensions():
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h
method Pair (line 29) | __device__ Pair() {}
method Pair (line 30) | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
method Pair (line 31) | __device__ Pair(T v) : v1(v), v2(v) {}
method Pair (line 32) | __device__ Pair(int v) : v1(v), v2(v) {}
function getMSB (line 54) | int getMSB(int val) { return 31 - __clz(val); }
function getNumThreads (line 56) | static int getNumThreads(int nElem) {
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp
function PYBIND11_MODULE (line 3) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/functional/syncbn.py
function _count_samples (line 20) | def _count_samples(x):
class BatchNorm2dSyncFunc (line 28) | class BatchNorm2dSyncFunc(Function):
method forward (line 31) | def forward(ctx, x, weight, bias, running_mean, running_var,
method backward (line 95) | def backward(ctx, dz):
FILE: XMem/inference/interact/fbrs/model/syncbn/modules/nn/syncbn.py
class _BatchNorm (line 26) | class _BatchNorm(nn.Module):
method __init__ (line 32) | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
method reset_parameters (line 55) | def reset_parameters(self):
method _check_input_dim (line 63) | def _check_input_dim(self, input):
method forward (line 66) | def forward(self, input):
method extra_repr (line 77) | def extra_repr(self):
class BatchNorm2dNoSync (line 84) | class BatchNorm2dNoSync(_BatchNorm):
method _check_input_dim (line 89) | def _check_input_dim(self, input):
class BatchNorm2dSync (line 95) | class BatchNorm2dSync(BatchNorm2dNoSync):
method __init__ (line 100) | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
method forward (line 113) | def forward(self, x):
method __repr__ (line 139) | def __repr__(self):
FILE: XMem/inference/interact/fbrs/utils/misc.py
function get_dims_with_exclusion (line 7) | def get_dims_with_exclusion(dim, exclude=None):
function get_unique_labels (line 15) | def get_unique_labels(mask):
function get_bbox_from_mask (line 19) | def get_bbox_from_mask(mask):
function expand_bbox (line 28) | def expand_bbox(bbox, expand_ratio, min_crop_size=None):
function clamp_bbox (line 46) | def clamp_bbox(bbox, rmin, rmax, cmin, cmax):
function get_bbox_iou (line 51) | def get_bbox_iou(b1, b2):
function get_segments_iou (line 57) | def get_segments_iou(s1, s2):
FILE: XMem/inference/interact/fbrs/utils/vis.py
function visualize_instances (line 7) | def visualize_instances(imask, bg_color=255,
function get_palette (line 26) | def get_palette(num_cls):
function visualize_mask (line 43) | def visualize_mask(mask, num_cls):
function visualize_proposals (line 50) | def visualize_proposals(proposals_info, point_color=(255, 0, 0), point_r...
function draw_probmap (line 60) | def draw_probmap(x):
function draw_points (line 64) | def draw_points(image, points, color, radius=3):
function draw_instance_map (line 72) | def draw_instance_map(x, palette=None):
function blend_mask (line 80) | def blend_mask(image, mask, alpha=0.6):
function get_boundaries (line 89) | def get_boundaries(instances_masks, boundaries_width=1):
function draw_with_blend_and_clicks (line 105) | def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=No...
FILE: XMem/inference/interact/fbrs_controller.py
class FBRSController (line 6) | class FBRSController:
method __init__ (line 7) | def __init__(self, checkpoint_path, device='cuda:0', max_size=800):
method unanchor (line 33) | def unanchor(self):
method interact (line 36) | def interact(self, image, x, y, is_positive):
method undo (line 48) | def undo(self):
FILE: XMem/inference/interact/gui.py
class App (line 48) | class App(QWidget):
method __init__ (line 49) | def __init__(self, net: XMem,
method resizeEvent (line 387) | def resizeEvent(self, event):
method console_push_text (line 390) | def console_push_text(self, text):
method interaction_radio_clicked (line 394) | def interaction_radio_clicked(self, event):
method load_current_image_mask (line 413) | def load_current_image_mask(self, no_mask=False):
method load_current_torch_image_mask (line 425) | def load_current_torch_image_mask(self, no_mask=False):
method compose_current_im (line 432) | def compose_current_im(self):
method update_interact_vis (line 436) | def update_interact_vis(self):
method update_minimap (line 457) | def update_minimap(self):
method update_current_image_fast (line 471) | def update_current_image_fast(self):
method show_current_frame (line 485) | def show_current_frame(self, fast=False):
method pixel_pos_to_image_pos (line 497) | def pixel_pos_to_image_pos(self, x, y):
method is_pos_out_of_bound (line 517) | def is_pos_out_of_bound(self, x, y):
method get_scaled_pos (line 529) | def get_scaled_pos(self, x, y):
method clear_visualization (line 537) | def clear_visualization(self):
method reset_this_interaction (line 541) | def reset_this_interaction(self):
method set_viz_mode (line 548) | def set_viz_mode(self):
method save_current_mask (line 552) | def save_current_mask(self):
method tl_slide (line 556) | def tl_slide(self):
method brush_slide (line 569) | def brush_slide(self):
method on_forward_propagation (line 579) | def on_forward_propagation(self):
method on_backward_propagation (line 589) | def on_backward_propagation(self):
method on_pause (line 599) | def on_pause(self):
method on_propagation (line 608) | def on_propagation(self):
method pause_propagation (line 647) | def pause_propagation(self):
method on_commit (line 650) | def on_commit(self):
method on_prev_frame (line 654) | def on_prev_frame(self):
method on_next_frame (line 659) | def on_next_frame(self):
method on_play_video_timer (line 664) | def on_play_video_timer(self):
method on_play_video (line 670) | def on_play_video(self):
method on_export_visualization (line 678) | def on_export_visualization(self):
method on_object_dial_change (line 697) | def on_object_dial_change(self):
method on_reset_mask (line 701) | def on_reset_mask(self):
method on_zoom_plus (line 710) | def on_zoom_plus(self):
method on_zoom_minus (line 715) | def on_zoom_minus(self):
method set_navi_enable (line 720) | def set_navi_enable(self, boolean):
method hit_number_key (line 729) | def hit_number_key(self, number):
method clear_brush (line 742) | def clear_brush(self):
method vis_brush (line 746) | def vis_brush(self, ex, ey):
method on_mouse_press (line 752) | def on_mouse_press(self, event):
method on_mouse_motion (line 803) | def on_mouse_motion(self, event):
method update_interacted_mask (line 818) | def update_interacted_mask(self):
method complete_interaction (line 825) | def complete_interaction(self):
method on_mouse_release (line 830) | def on_mouse_release(self, event):
method wheelEvent (line 856) | def wheelEvent(self, event):
method update_gpu_usage (line 865) | def update_gpu_usage(self):
method on_gpu_timer (line 884) | def on_gpu_timer(self):
method update_memory_size (line 887) | def update_memory_size(self):
method on_work_min_change (line 907) | def on_work_min_change(self):
method on_work_max_change (line 912) | def on_work_max_change(self):
method update_config (line 917) | def update_config(self):
method on_clear_memory (line 927) | def on_clear_memory(self):
method _open_file (line 936) | def _open_file(self, prompt):
method on_import_mask (line 941) | def on_import_mask(self):
method on_import_layer (line 969) | def on_import_layer(self):
method _try_load_layer (line 976) | def _try_load_layer(self, file_name):
method on_save_visualization_toggle (line 1000) | def on_save_visualization_toggle(self):
FILE: XMem/inference/interact/gui_utils.py
function create_parameter_box (line 5) | def create_parameter_box(min_val, max_val, text, step=1, callback=None):
function create_gauge (line 26) | def create_gauge(text):
function apply_to_all_children_widget (line 43) | def apply_to_all_children_widget(layout, func):
FILE: XMem/inference/interact/interaction.py
function aggregate_sbg (line 18) | def aggregate_sbg(prob, keep_bg=False, hard=False):
function aggregate_wbg (line 36) | def aggregate_wbg(prob, keep_bg=False, hard=False):
class Interaction (line 53) | class Interaction:
method __init__ (line 54) | def __init__(self, image, prev_mask, true_size, controller):
method predict (line 65) | def predict(self):
class FreeInteraction (line 69) | class FreeInteraction(Interaction):
method __init__ (line 70) | def __init__(self, image, prev_mask, true_size, num_objects):
method set_size (line 83) | def set_size(self, size):
method push_point (line 90) | def push_point(self, x, y, k, vis=None):
method end_path (line 123) | def end_path(self):
method predict (line 127) | def predict(self):
class ScribbleInteraction (line 134) | class ScribbleInteraction(Interaction):
method __init__ (line 135) | def __init__(self, image, prev_mask, true_size, controller, num_objects):
method push_point (line 153) | def push_point(self, x, y, k, vis=None):
method end_path (line 187) | def end_path(self):
method predict (line 191) | def predict(self):
class ClickInteraction (line 197) | class ClickInteraction(Interaction):
method __init__ (line 198) | def __init__(self, image, prev_mask, true_size, controller, tar_obj):
method push_point (line 215) | def push_point(self, x, y, neg, vis=None):
method predict (line 245) | def predict(self):
FILE: XMem/inference/interact/interactive_utils.py
function image_to_torch (line 10) | def image_to_torch(frame: np.ndarray, device='cuda'):
function torch_prob_to_numpy_mask (line 17) | def torch_prob_to_numpy_mask(prob):
function index_numpy_to_one_hot_torch (line 22) | def index_numpy_to_one_hot_torch(mask, num_classes):
function get_visualization (line 48) | def get_visualization(mode, image, mask, layer, target_object):
function get_visualization_torch (line 66) | def get_visualization_torch(mode, image, prob, layer, target_object):
function overlay_davis (line 84) | def overlay_davis(image, mask, alpha=0.5, fade=False):
function overlay_popup (line 97) | def overlay_popup(image, mask, target_object):
function overlay_layer (line 106) | def overlay_layer(image, mask, layer, target_object):
function overlay_davis_torch (line 118) | def overlay_davis_torch(image, mask, alpha=0.5, fade=False):
function overlay_popup_torch (line 138) | def overlay_popup_torch(image, mask, target_object):
function overlay_layer_torch (line 159) | def overlay_layer_torch(image, prob, layer, target_object):
FILE: XMem/inference/interact/resource_manager.py
class LRU (line 18) | class LRU:
method __init__ (line 19) | def __init__(self, func, maxsize=128):
method __call__ (line 24) | def __call__(self, *args):
method invalidate (line 35) | def invalidate(self, key):
class ResourceManager (line 39) | class ResourceManager:
method __init__ (line 40) | def __init__(self, config):
method _extract_frames (line 103) | def _extract_frames(self, video):
method _copy_resize_frames (line 124) | def _copy_resize_frames(self, images):
method save_mask (line 141) | def save_mask(self, ti, mask):
method save_visualization (line 151) | def save_visualization(self, ti, image):
method _get_image_unbuffered (line 163) | def _get_image_unbuffered(self, ti):
method _get_mask_unbuffered (line 171) | def _get_mask_unbuffered(self, ti):
method read_external_image (line 183) | def read_external_image(self, file_name, size=None):
method invalidate (line 193) | def invalidate(self, ti):
method __len__ (line 197) | def __len__(self):
method h (line 201) | def h(self):
method w (line 205) | def w(self):
FILE: XMem/inference/interact/s2m/_deeplab.py
class DeepLabV3 (line 13) | class DeepLabV3(_SimpleSegmentationModel):
class DeepLabHeadV3Plus (line 30) | class DeepLabHeadV3Plus(nn.Module):
method __init__ (line 31) | def __init__(self, in_channels, low_level_channels, num_classes, aspp_...
method forward (line 49) | def forward(self, feature):
method _init_weight (line 55) | def _init_weight(self):
class DeepLabHead (line 63) | class DeepLabHead(nn.Module):
method __init__ (line 64) | def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]):
method forward (line 76) | def forward(self, feature):
method _init_weight (line 79) | def _init_weight(self):
class AtrousSeparableConvolution (line 87) | class AtrousSeparableConvolution(nn.Module):
method __init__ (line 90) | def __init__(self, in_channels, out_channels, kernel_size,
method forward (line 102) | def forward(self, x):
method _init_weight (line 105) | def _init_weight(self):
class ASPPConv (line 113) | class ASPPConv(nn.Sequential):
method __init__ (line 114) | def __init__(self, in_channels, out_channels, dilation):
class ASPPPooling (line 122) | class ASPPPooling(nn.Sequential):
method __init__ (line 123) | def __init__(self, in_channels, out_channels):
method forward (line 130) | def forward(self, x):
class ASPP (line 135) | class ASPP(nn.Module):
method __init__ (line 136) | def __init__(self, in_channels, atrous_rates):
method forward (line 159) | def forward(self, x):
function convert_to_separable_conv (line 168) | def convert_to_separable_conv(module):
FILE: XMem/inference/interact/s2m/s2m_network.py
function _segm_resnet (line 7) | def _segm_resnet(name, backbone_name, num_classes, output_stride, pretra...
function _load_model (line 34) | def _load_model(arch_type, backbone, num_classes, output_stride, pretrai...
function deeplabv3_resnet50 (line 44) | def deeplabv3_resnet50(num_classes=1, output_stride=16, pretrained_backb...
function deeplabv3plus_resnet50 (line 56) | def deeplabv3plus_resnet50(num_classes=1, output_stride=16, pretrained_b...
FILE: XMem/inference/interact/s2m/s2m_resnet.py
function conv3x3 (line 17) | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
function conv1x1 (line 23) | def conv1x1(in_planes, out_planes, stride=1):
class Bottleneck (line 28) | class Bottleneck(nn.Module):
method __init__ (line 31) | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
method forward (line 48) | def forward(self, x):
class ResNet (line 71) | class ResNet(nn.Module):
method __init__ (line 73) | def __init__(self, block, layers, num_classes=1000, zero_init_residual...
method _make_layer (line 122) | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
method forward (line 146) | def forward(self, x):
function _resnet (line 164) | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
function resnet50 (line 173) | def resnet50(pretrained=False, progress=True, **kwargs):
FILE: XMem/inference/interact/s2m/utils.py
class _SimpleSegmentationModel (line 9) | class _SimpleSegmentationModel(nn.Module):
method __init__ (line 10) | def __init__(self, backbone, classifier):
method forward (line 15) | def forward(self, x):
class IntermediateLayerGetter (line 23) | class IntermediateLayerGetter(nn.ModuleDict):
method __init__ (line 54) | def __init__(self, model, return_layers):
method forward (line 71) | def forward(self, x):
FILE: XMem/inference/interact/s2m_controller.py
class S2MController (line 8) | class S2MController:
method __init__ (line 15) | def __init__(self, s2m_net:S2M, num_objects, ignore_class, device='cud...
method interact (line 21) | def interact(self, image, prev_mask, scr_mask):
FILE: XMem/inference/interact/timer.py
class Timer (line 3) | class Timer:
method __init__ (line 4) | def __init__(self):
method start (line 8) | def start(self):
method pause (line 14) | def pause(self):
method count (line 19) | def count(self):
method format (line 27) | def format(self):
method __str__ (line 32) | def __str__(self):
FILE: XMem/inference/kv_memory_store.py
class KeyValueMemoryStore (line 4) | class KeyValueMemoryStore:
method __init__ (line 18) | def __init__(self, count_usage: bool):
method add (line 36) | def add(self, key, value, shrinkage, selection, objects: List[int]):
method update_usage (line 92) | def update_usage(self, usage):
method sieve_by_range (line 101) | def sieve_by_range(self, start: int, end: int, min_size: int):
method remove_obsolete_features (line 135) | def remove_obsolete_features(self, max_size: int):
method get_usage (line 158) | def get_usage(self):
method get_all_sliced (line 166) | def get_all_sliced(self, start: int, end: int):
method get_v_size (line 183) | def get_v_size(self, ni: int):
method engaged (line 186) | def engaged(self):
method size (line 190) | def size(self):
method num_groups (line 197) | def num_groups(self):
method key (line 201) | def key(self):
method value (line 205) | def value(self):
method shrinkage (line 209) | def shrinkage(self):
method selection (line 213) | def selection(self):
FILE: XMem/inference/memory_manager.py
class MemoryManager (line 8) | class MemoryManager:
method __init__ (line 12) | def __init__(self, config):
method update_config (line 38) | def update_config(self, config):
method _readout (line 53) | def _readout(self, affinity, v):
method match_memory (line 57) | def match_memory(self, query_key, selection):
method add_memory (line 152) | def add_memory(self, key, shrinkage, value, objects, selection=None):
method create_hidden_state (line 192) | def create_hidden_state(self, n, sample_key):
method set_hidden (line 205) | def set_hidden(self, hidden):
method get_hidden (line 208) | def get_hidden(self):
method compress_features (line 211) | def compress_features(self):
method consolidation (line 243) | def consolidation(self, candidate_key, candidate_shrinkage, candidate_...
FILE: XMem/merge_multi_scale.py
function search_options (line 19) | def search_options(options, name):
function process_vid (line 26) | def process_vid(vid):
FILE: XMem/merge_results.py
function merge (line 18) | def merge(obj):
FILE: XMem/model/aggregate.py
function aggregate (line 6) | def aggregate(prob, dim, return_logits=False):
FILE: XMem/model/cbam.py
class BasicConv (line 7) | class BasicConv(nn.Module):
method __init__ (line 8) | def __init__(self, in_planes, out_planes, kernel_size, stride=1, paddi...
method forward (line 13) | def forward(self, x):
class Flatten (line 17) | class Flatten(nn.Module):
method forward (line 18) | def forward(self, x):
class ChannelGate (line 21) | class ChannelGate(nn.Module):
method __init__ (line 22) | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg...
method forward (line 32) | def forward(self, x):
class ChannelPool (line 50) | class ChannelPool(nn.Module):
method forward (line 51) | def forward(self, x):
class SpatialGate (line 54) | class SpatialGate(nn.Module):
method __init__ (line 55) | def __init__(self):
method forward (line 60) | def forward(self, x):
class CBAM (line 66) | class CBAM(nn.Module):
method __init__ (line 67) | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg...
method forward (line 73) | def forward(self, x):
FILE: XMem/model/group_modules.py
function interpolate_groups (line 15) | def interpolate_groups(g, ratio, mode, align_corners):
function upsample_groups (line 22) | def upsample_groups(g, ratio=2, mode='bilinear', align_corners=False):
function downsample_groups (line 25) | def downsample_groups(g, ratio=1/2, mode='area', align_corners=None):
class GConv2D (line 29) | class GConv2D(nn.Conv2d):
method forward (line 30) | def forward(self, g):
class GroupResBlock (line 36) | class GroupResBlock(nn.Module):
method __init__ (line 37) | def __init__(self, in_dim, out_dim):
method forward (line 48) | def forward(self, g):
class MainToGroupDistributor (line 58) | class MainToGroupDistributor(nn.Module):
method __init__ (line 59) | def __init__(self, x_transform=None, method='cat', reverse_order=False):
method forward (line 66) | def forward(self, x, g):
FILE: XMem/model/losses.py
function dice_loss (line 8) | def dice_loss(input_mask, cls_gt):
class BootstrappedCE (line 23) | class BootstrappedCE(nn.Module):
method __init__ (line 24) | def __init__(self, start_warm, end_warm, top_p=0.15):
method forward (line 31) | def forward(self, input, target, it):
class LossComputer (line 46) | class LossComputer:
method __init__ (line 47) | def __init__(self, config):
method compute (line 52) | def compute(self, data, num_objects, it):
FILE: XMem/model/memory_util.py
function get_similarity (line 7) | def get_similarity(mk, ms, qk, qe):
function do_softmax (line 41) | def do_softmax(similarity, top_k: Optional[int]=None, inplace=False, ret...
function get_affinity (line 67) | def get_affinity(mk, ms, qk, qe):
function readout (line 73) | def readout(affinity, mv):
FILE: XMem/model/modules.py
class FeatureFusionBlock (line 22) | class FeatureFusionBlock(nn.Module):
method __init__ (line 23) | def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim):
method forward (line 31) | def forward(self, x, g):
class HiddenUpdater (line 44) | class HiddenUpdater(nn.Module):
method __init__ (line 46) | def __init__(self, g_dims, mid_dim, hidden_dim):
method forward (line 58) | def forward(self, g, h):
class HiddenReinforcer (line 77) | class HiddenReinforcer(nn.Module):
method __init__ (line 79) | def __init__(self, g_dim, hidden_dim):
method forward (line 86) | def forward(self, g, h):
class ValueEncoder (line 102) | class ValueEncoder(nn.Module):
method __init__ (line 103) | def __init__(self, value_dim, hidden_dim, single_object=False):
method forward (line 124) | def forward(self, image, image_feat_f16, h, masks, others, is_deep_upd...
class KeyEncoder (line 153) | class KeyEncoder(nn.Module):
method __init__ (line 154) | def __init__(self):
method forward (line 166) | def forward(self, f):
class UpsampleBlock (line 178) | class UpsampleBlock(nn.Module):
method __init__ (line 179) | def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2):
method forward (line 186) | def forward(self, skip_f, up_g):
class KeyProjection (line 194) | class KeyProjection(nn.Module):
method __init__ (line 195) | def __init__(self, in_dim, keydim):
method forward (line 207) | def forward(self, x, need_s, need_e):
class Decoder (line 214) | class Decoder(nn.Module):
method __init__ (line 215) | def __init__(self, val_dim, hidden_dim):
method forward (line 229) | def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True):
FILE: XMem/model/network.py
class XMem (line 19) | class XMem(nn.Module):
method __init__ (line 20) | def __init__(self, config, model_path=None, map_location=None):
method encode_key (line 42) | def encode_key(self, frame, need_sk=True, need_ek=True):
method encode_value (line 74) | def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_upda...
method read_memory (line 91) | def read_memory(self, query_key, query_selection, memory_key,
method segment (line 109) | def segment(self, multi_scale_features, memory_readout,
method forward (line 124) | def forward(self, mode, *args, **kwargs):
method init_hyperparameters (line 137) | def init_hyperparameters(self, config, model_path=None, map_location=N...
method load_weights (line 188) | def load_weights(self, src_dict, init_as_zero_if_needed=False):
FILE: XMem/model/resnet.py
function load_weights_add_extra_dim (line 14) | def load_weights_add_extra_dim(target, source_state, extra_dim=1):
function conv3x3 (line 41) | def conv3x3(in_planes, out_planes, stride=1, dilation=1):
class BasicBlock (line 46) | class BasicBlock(nn.Module):
method __init__ (line 49) | def __init__(self, inplanes, planes, stride=1, downsample=None, dilati...
method forward (line 59) | def forward(self, x):
class Bottleneck (line 78) | class Bottleneck(nn.Module):
method __init__ (line 81) | def __init__(self, inplanes, planes, stride=1, downsample=None, dilati...
method forward (line 94) | def forward(self, x):
class ResNet (line 117) | class ResNet(nn.Module):
method __init__ (line 118) | def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
method _make_layer (line 138) | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
function resnet18 (line 154) | def resnet18(pretrained=True, extra_dim=0):
function resnet50 (line 160) | def resnet50(pretrained=True, extra_dim=0):
FILE: XMem/model/trainer.py
class XMemTrainer (line 20) | class XMemTrainer:
method __init__ (line 21) | def __init__(self, config, logger=None, save_path=None, local_rank=0, ...
method do_pass (line 56) | def do_pass(self, data, it=0):
method save_network (line 160) | def save_network(self, it):
method save_checkpoint (line 170) | def save_checkpoint(self, it):
method load_checkpoint (line 185) | def load_checkpoint(self, path):
method load_network_in_memory (line 204) | def load_network_in_memory(self, src_dict):
method load_network (line 208) | def load_network(self, path):
method train (line 216) | def train(self):
method val (line 223) | def val(self):
method test (line 229) | def test(self):
FILE: XMem/scripts/resize_youtube.py
function resize_vid_jpeg (line 12) | def resize_vid_jpeg(inputs):
function resize_vid_anno (line 28) | def resize_vid_anno(inputs):
function resize_all (line 45) | def resize_all(in_path, out_path):
FILE: XMem/tracking.py
function run_eval (line 22) | def run_eval(meta_expression, temp_xmem_anno, final_xmem_anno, img_dir, ...
function generate (line 29) | def generate(obj, temp_xmem_anno, final_xmem_anno):
function prepare (line 56) | def prepare(args):
function inference (line 90) | def inference(args):
function main (line 112) | def main():
FILE: XMem/train.py
function worker_init_fn (line 114) | def worker_init_fn(worker_id):
function construct_loader (line 119) | def construct_loader(dataset):
function renew_vos_loader (line 125) | def renew_vos_loader(max_skip, finetune=False):
function renew_bl_loader (line 140) | def renew_bl_loader(max_skip, finetune=False):
FILE: XMem/util/configuration.py
function none_or_default (line 4) | def none_or_default(x, default):
class Configuration (line 7) | class Configuration():
method parse (line 8) | def parse(self, unknown_arg_ok=False):
method get_stage_parameters (line 113) | def get_stage_parameters(self, stage):
method __getitem__ (line 128) | def __getitem__(self, key):
method __setitem__ (line 131) | def __setitem__(self, key, value):
method __str__ (line 134) | def __str__(self):
FILE: XMem/util/image_saver.py
function tensor_to_numpy (line 8) | def tensor_to_numpy(image):
function tensor_to_np_float (line 12) | def tensor_to_np_float(image):
function detach_to_cpu (line 16) | def detach_to_cpu(x):
function transpose_np (line 19) | def transpose_np(x):
function tensor_to_gray_im (line 22) | def tensor_to_gray_im(x):
function tensor_to_im (line 28) | def tensor_to_im(x):
function get_image_array (line 46) | def get_image_array(images, grid_shape, captions={}):
function base_transform (line 81) | def base_transform(im, size):
function im_transform (line 94) | def im_transform(im, size):
function mask_transform (line 97) | def mask_transform(mask, size):
function out_transform (line 100) | def out_transform(mask, size):
function pool_pairs (line 103) | def pool_pairs(images, size, num_objects):
FILE: XMem/util/load_subset.py
function load_sub_davis (line 8) | def load_sub_davis(path='util/davis_subset.txt'):
function load_sub_yv (line 13) | def load_sub_yv(path='util/yv_subset.txt'):
FILE: XMem/util/log_integrator.py
class Integrator (line 10) | class Integrator:
method __init__ (line 11) | def __init__(self, logger, distributed=True, local_rank=0, world_size=1):
method add_tensor (line 22) | def add_tensor(self, key, tensor):
method add_dict (line 36) | def add_dict(self, tensor_dict):
method add_hook (line 40) | def add_hook(self, hook):
method reset_except_hooks (line 51) | def reset_except_hooks(self):
method finalize (line 56) | def finalize(self, prefix, it, f=None):
FILE: XMem/util/logger.py
function tensor_to_numpy (line 12) | def tensor_to_numpy(image):
function detach_to_cpu (line 16) | def detach_to_cpu(x):
function fix_width_trunc (line 19) | def fix_width_trunc(x):
class TensorboardLogger (line 22) | class TensorboardLogger:
method __init__ (line 23) | def __init__(self, short_id, id, git_info):
method log_scalar (line 47) | def log_scalar(self, tag, x, step):
method log_metrics (line 53) | def log_metrics(self, l1_tag, l2_tag, val, step, f=None):
method log_im (line 62) | def log_im(self, tag, x, step):
method log_cv2 (line 71) | def log_cv2(self, tag, x, step):
method log_seg (line 78) | def log_seg(self, tag, x, step):
method log_gray (line 87) | def log_gray(self, tag, x, step):
method log_string (line 95) | def log_string(self, tag, x):
FILE: XMem/util/tensor_util.py
function compute_tensor_iu (line 4) | def compute_tensor_iu(seg, gt):
function compute_tensor_iou (line 10) | def compute_tensor_iou(seg, gt):
function pad_divide_by (line 17) | def pad_divide_by(in_img, d):
function unpad (line 34) | def unpad(img, pad):
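
Note: the two IoU helpers above pair naturally. A plausible sketch assuming boolean mask tensors, with a small epsilon so empty masks do not divide by zero (the exact epsilon value is an assumption):

    import torch

    # Intersection and union counts over boolean mask tensors.
    def compute_tensor_iu(seg: torch.Tensor, gt: torch.Tensor):
        intersection = (seg & gt).float().sum()
        union = (seg | gt).float().sum()
        return intersection, union

    # IoU with an epsilon guard for the empty-vs-empty case.
    def compute_tensor_iou(seg: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        intersection, union = compute_tensor_iu(seg, gt)
        return (intersection + 1e-6) / (union + 1e-6)
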
FILE: merge_lora_weights_and_save_hf_model.py
function parse_args (line 24) | def parse_args(args):
function main (line 54) | def main(args):
FILE: model/VISA.py
function dice_loss (line 18) | def dice_loss(
function sigmoid_ce_loss (line 42) | def sigmoid_ce_loss(
class VisaMetaModel (line 60) | class VisaMetaModel:
method __init__ (line 61) | def __init__(
method initialize_lisa_modules (line 77) | def initialize_lisa_modules(self, config):
class VisaModel (line 102) | class VisaModel(VisaMetaModel, ChatUniViLlamaModel):
method __init__ (line 103) | def __init__(
class VISAForCausalLM (line 121) | class VISAForCausalLM(ChatUniViLlamaForCausalLM):
method __init__ (line 122) | def __init__(
method get_visual_embs (line 147) | def get_visual_embs(self, pixel_values: torch.FloatTensor):
method forward (line 152) | def forward(self, **kwargs):
method model_forward (line 157) | def model_forward(
method evaluate (line 334) | def evaluate(self, *args, **kwargs):
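
Note: dice_loss and sigmoid_ce_loss above are the standard LISA-style mask losses. A minimal sketch of the DICE term, normalized by the number of masks; the parameter names and the flattening scheme are assumptions, and the repository version may apply an extra scaling factor:

    import torch

    # DICE loss over per-mask logits: 1 - 2*|A inter B| / (|A| + |B|),
    # averaged over num_masks. Shapes assumed (N, H, W); eps guards
    # the empty-mask case.
    def dice_loss(inputs: torch.Tensor, targets: torch.Tensor,
                  num_masks: float, eps: float = 1e-6) -> torch.Tensor:
        inputs = inputs.sigmoid().flatten(1)   # (N, H*W) probabilities
        targets = targets.flatten(1)           # (N, H*W) binary targets
        numerator = 2 * (inputs * targets).sum(-1)
        denominator = inputs.sum(-1) + targets.sum(-1)
        loss = 1 - (numerator + eps) / (denominator + eps)
        return loss.sum() / (num_masks + 1e-8)
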
FILE: model/llava/conversation.py
class SeparatorStyle (line 6) | class SeparatorStyle(Enum):
class Conversation (line 17) | class Conversation:
method get_prompt (line 31) | def get_prompt(self):
method append_message (line 109) | def append_message(self, role, message):
method get_images (line 112) | def get_images(self, return_pil=False):
method to_gradio_chatbot (line 171) | def to_gradio_chatbot(self):
method copy (line 205) | def copy(self):
method dict (line 217) | def dict(self):
FILE: model/llava/mm_utils.py
function load_image_from_base64 (line 11) | def load_image_from_base64(image):
function process_images (line 15) | def process_images(images, image_processor, model_cfg):
function tokenizer_image_token (line 19) | def tokenizer_image_token(
function get_model_name_from_path (line 47) | def get_model_name_from_path(model_path):
class KeywordsStoppingCriteria (line 56) | class KeywordsStoppingCriteria(StoppingCriteria):
method __init__ (line 57) | def __init__(self, keywords, tokenizer, input_ids):
method __call__ (line 71) | def __call__(
FILE: model/llava/model/apply_delta.py
function apply_delta (line 13) | def apply_delta(base_model_path, target_model_path, delta_path):
FILE: model/llava/model/builder.py
function load_pretrained_model (line 27) | def load_pretrained_model(
FILE: model/llava/model/consolidate.py
function consolidate_ckpt (line 13) | def consolidate_ckpt(src_path, dst_path):
FILE: model/llava/model/language_model/llava_llama.py
class LlavaConfig (line 28) | class LlavaConfig(LlamaConfig):
class LlavaLlamaModel (line 32) | class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
method __init__ (line 35) | def __init__(self, config: LlamaConfig):
class LlavaLlamaForCausalLM (line 39) | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
method __init__ (line 42) | def __init__(self, config):
method get_model (line 52) | def get_model(self):
method forward (line 55) | def forward(
method prepare_inputs_for_generation (line 137) | def prepare_inputs_for_generation(
FILE: model/llava/model/language_model/llava_mpt.py
class LlavaMPTConfig (line 29) | class LlavaMPTConfig(MPTConfig):
class LlavaMPTModel (line 33) | class LlavaMPTModel(LlavaMetaModel, MPTModel):
method __init__ (line 36) | def __init__(self, config: MPTConfig):
method embed_tokens (line 40) | def embed_tokens(self, x):
class LlavaMPTForCausalLM (line 44) | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM):
method __init__ (line 48) | def __init__(self, config):
method get_model (line 66) | def get_model(self):
method _set_gradient_checkpointing (line 69) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 73) | def forward(
method prepare_inputs_for_generation (line 138) | def prepare_inputs_for_generation(
FILE: model/llava/model/language_model/mpt/adapt_tokenizer.py
function adapt_tokenizer_for_denoising (line 10) | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
class AutoTokenizerForMOD (line 30) | class AutoTokenizerForMOD(AutoTokenizer):
method from_pretrained (line 42) | def from_pretrained(cls, *args, **kwargs):
FILE: model/llava/model/language_model/mpt/attention.py
function _reset_is_causal (line 15) | def _reset_is_causal(
function scaled_multihead_dot_product_attention (line 28) | def scaled_multihead_dot_product_attention(
function check_valid_inputs (line 103) | def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bflo...
function flash_attn_fn (line 115) | def flash_attn_fn(
function triton_flash_attn_fn (line 190) | def triton_flash_attn_fn(
class MultiheadAttention (line 261) | class MultiheadAttention(nn.Module):
method __init__ (line 268) | def __init__(
method forward (line 322) | def forward(
class MultiQueryAttention (line 357) | class MultiQueryAttention(nn.Module):
method __init__ (line 364) | def __init__(
method forward (line 419) | def forward(
function attn_bias_shape (line 457) | def attn_bias_shape(
function build_attn_bias (line 474) | def build_attn_bias(
function gen_slopes (line 497) | def gen_slopes(n_heads, alibi_bias_max=8, device=None):
function build_alibi_bias (line 507) | def build_alibi_bias(
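
Note: gen_slopes and build_alibi_bias above implement ALiBi positional biasing. A sketch of the slope computation in the usual MPT style: a geometric sequence 1/2^m over a power-of-two head count, interleaved and truncated when n_heads is not a power of two (details assumed from the standard implementation):

    import math
    import torch

    # Per-head ALiBi slopes: heads get geometrically decaying distance penalties.
    def gen_slopes(n_heads, alibi_bias_max=8, device=None):
        _n_heads = 2 ** math.ceil(math.log2(n_heads))
        m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
        m = m.mul(alibi_bias_max / _n_heads)
        slopes = 1.0 / torch.pow(2, m)
        if _n_heads != n_heads:
            # Filled at the nearest power of two, then interleaved and truncated.
            slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
        return slopes.view(1, n_heads, 1, 1)
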
FILE: model/llava/model/language_model/mpt/blocks.py
class MPTMLP (line 11) | class MPTMLP(nn.Module):
method __init__ (line 12) | def __init__(
method forward (line 21) | def forward(self, x):
class MPTBlock (line 25) | class MPTBlock(nn.Module):
method __init__ (line 26) | def __init__(
method forward (line 72) | def forward(
FILE: model/llava/model/language_model/mpt/configuration_mpt.py
class MPTConfig (line 30) | class MPTConfig(PretrainedConfig):
method __init__ (line 33) | def __init__(
method _set_config_defaults (line 134) | def _set_config_defaults(self, config, config_defaults):
method _validate_config (line 140) | def _validate_config(self):
FILE: model/llava/model/language_model/mpt/custom_embedding.py
class SharedEmbedding (line 7) | class SharedEmbedding(nn.Embedding):
method forward (line 8) | def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
FILE: model/llava/model/language_model/mpt/flash_attn_triton.py
function _fwd_kernel (line 59) | def _fwd_kernel(
function _bwd_preprocess_do_o_dot (line 271) | def _bwd_preprocess_do_o_dot(
function _bwd_store_dk_dv (line 317) | def _bwd_store_dk_dv(
function _bwd_kernel_one_col_block (line 350) | def _bwd_kernel_one_col_block(
function init_to_zero (line 574) | def init_to_zero(name):
function _bwd_kernel (line 609) | def _bwd_kernel(
function _flash_attn_forward (line 751) | def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=...
function _flash_attn_backward (line 833) | def _flash_attn_backward(
class FlashAttnQKVPackedFunc (line 938) | class FlashAttnQKVPackedFunc(torch.autograd.Function):
method forward (line 940) | def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
method backward (line 962) | def backward(ctx, do):
class FlashAttnKVPackedFunc (line 989) | class FlashAttnKVPackedFunc(torch.autograd.Function):
method forward (line 991) | def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
method backward (line 1013) | def backward(ctx, do):
class FlashAttnFunc (line 1042) | class FlashAttnFunc(torch.autograd.Function):
method forward (line 1044) | def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
method backward (line 1061) | def backward(ctx, do):
FILE: model/llava/model/language_model/mpt/hf_prefixlm_converter.py
function _convert_gpt_causal_lm_to_prefix_lm (line 45) | def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUS...
function _convert_bloom_causal_lm_to_prefix_lm (line 183) | def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> Bl...
function _convert_opt_causal_lm_to_prefix_lm (line 531) | def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTFor...
function convert_hf_causal_lm_to_prefix_lm (line 661) | def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_...
function add_bidirectional_mask_if_missing (line 732) | def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
FILE: model/llava/model/language_model/mpt/meta_init_context.py
function init_empty_weights (line 8) | def init_empty_weights(include_buffers: bool = False):
function init_on_device (line 40) | def init_on_device(device: torch.device, include_buffers: bool = False):
FILE: model/llava/model/language_model/mpt/modeling_mpt.py
class MPTPreTrainedModel (line 35) | class MPTPreTrainedModel(PreTrainedModel):
class MPTModel (line 41) | class MPTModel(MPTPreTrainedModel):
method __init__ (line 42) | def __init__(self, config: MPTConfig):
method get_input_embeddings (line 109) | def get_input_embeddings(self):
method set_input_embeddings (line 112) | def set_input_embeddings(self, value):
method _attn_bias (line 116) | def _attn_bias(
method _apply_prefix_mask (line 169) | def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: tor...
method _apply_sequence_id (line 192) | def _apply_sequence_id(
method forward (line 208) | def forward(
method param_init_fn (line 361) | def param_init_fn(self, module):
method fsdp_wrap_fn (line 370) | def fsdp_wrap_fn(self, module):
method activation_checkpointing_fn (line 373) | def activation_checkpointing_fn(self, module):
class MPTForCausalLM (line 377) | class MPTForCausalLM(MPTPreTrainedModel):
method __init__ (line 378) | def __init__(self, config: MPTConfig):
method get_input_embeddings (line 401) | def get_input_embeddings(self):
method set_input_embeddings (line 404) | def set_input_embeddings(self, value):
method get_output_embeddings (line 407) | def get_output_embeddings(self):
method set_output_embeddings (line 410) | def set_output_embeddings(self, new_embeddings):
method set_decoder (line 413) | def set_decoder(self, decoder):
method get_decoder (line 416) | def get_decoder(self):
method forward (line 419) | def forward(
method param_init_fn (line 476) | def param_init_fn(self, module):
method fsdp_wrap_fn (line 485) | def fsdp_wrap_fn(self, module):
method activation_checkpointing_fn (line 488) | def activation_checkpointing_fn(self, module):
method prepare_inputs_for_generation (line 491) | def prepare_inputs_for_generation(
method _reorder_cache (line 525) | def _reorder_cache(past_key_values, beam_idx):
FILE: model/llava/model/language_model/mpt/norm.py
function _cast_if_autocast_enabled (line 4) | def _cast_if_autocast_enabled(tensor):
class LPLayerNorm (line 16) | class LPLayerNorm(torch.nn.LayerNorm):
method __init__ (line 17) | def __init__(
method forward (line 33) | def forward(self, x):
function rms_norm (line 54) | def rms_norm(x, weight=None, eps=1e-05):
class RMSNorm (line 61) | class RMSNorm(torch.nn.Module):
method __init__ (line 62) | def __init__(
method forward (line 74) | def forward(self, x):
class LPRMSNorm (line 78) | class LPRMSNorm(RMSNorm):
method __init__ (line 79) | def __init__(
method forward (line 90) | def forward(self, x):
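
Note: the rms_norm signature above matches the widely used MPT definition: normalize by the root mean square over the last dimension, with an optional learned scale:

    import torch

    # RMSNorm: x / sqrt(mean(x^2) + eps), optionally scaled by a weight vector.
    def rms_norm(x, weight=None, eps=1e-05):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        if weight is not None:
            return output * weight
        return output
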
FILE: model/llava/model/language_model/mpt/param_init_fns.py
function torch_default_param_init_fn_ (line 13) | def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **...
function fused_init_helper_ (line 21) | def fused_init_helper_(module: nn.Module, init_fn_):
function generic_param_init_fn_ (line 33) | def generic_param_init_fn_(
function _normal_init_ (line 164) | def _normal_init_(std, mean=0.0):
function _normal_param_init_fn_ (line 168) | def _normal_param_init_fn_(
function baseline_param_init_fn_ (line 195) | def baseline_param_init_fn_(
function small_param_init_fn_ (line 223) | def small_param_init_fn_(
function neox_param_init_fn_ (line 247) | def neox_param_init_fn_(
function kaiming_uniform_param_init_fn_ (line 277) | def kaiming_uniform_param_init_fn_(
function kaiming_normal_param_init_fn_ (line 314) | def kaiming_normal_param_init_fn_(
function xavier_uniform_param_init_fn_ (line 351) | def xavier_uniform_param_init_fn_(
function xavier_normal_param_init_fn_ (line 381) | def xavier_normal_param_init_fn_(
FILE: model/llava/model/llava_arch.py
class LlavaMetaModel (line 29) | class LlavaMetaModel:
method __init__ (line 30) | def __init__(self, config):
method get_vision_tower (line 37) | def get_vision_tower(self):
method initialize_vision_modules (line 43) | def initialize_vision_modules(self, model_args, fsdp=None):
class LlavaMetaForCausalLM (line 85) | class LlavaMetaForCausalLM(ABC):
method get_model (line 87) | def get_model(self):
method get_vision_tower (line 90) | def get_vision_tower(self):
method encode_images (line 93) | def encode_images(self, images):
method prepare_inputs_labels_for_multimodal (line 98) | def prepare_inputs_labels_for_multimodal(
method initialize_vision_tokenizer (line 350) | def initialize_vision_tokenizer(self, model_args, num_new_tokens):
FILE: model/llava/model/make_delta.py
function make_delta (line 13) | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_...
FILE: model/llava/model/multimodal_encoder/builder.py
function build_vision_tower (line 4) | def build_vision_tower(vision_tower_cfg, **kwargs):
FILE: model/llava/model/multimodal_encoder/clip_encoder.py
class CLIPVisionTower (line 6) | class CLIPVisionTower(nn.Module):
method __init__ (line 7) | def __init__(self, vision_tower, args, delay_load=False):
method load_model (line 21) | def load_model(self):
method feature_select (line 31) | def feature_select(self, image_forward_outs):
method forward (line 42) | def forward(self, images):
method dummy_feature (line 63) | def dummy_feature(self):
method dtype (line 67) | def dtype(self):
method device (line 71) | def device(self):
method config (line 75) | def config(self):
method hidden_size (line 82) | def hidden_size(self):
method num_patches (line 86) | def num_patches(self):
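
Note: CLIPVisionTower.feature_select above picks which CLIP hidden states feed the multimodal projector. A sketch of the usual LLaVA behavior, assuming `self.select_layer` and `self.select_feature` attributes: take one transformer layer's hidden states and optionally drop the CLS token:

    # Sketch of feature_select: pick hidden states from self.select_layer and,
    # for 'patch', drop the leading CLS token. Attribute names are assumptions.
    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]   # drop CLS
        elif self.select_feature == 'cls_patch':
            pass                                     # keep CLS + patches
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features
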
FILE: model/llava/model/utils.py
function auto_upgrade (line 4) | def auto_upgrade(config):
FILE: model/llava/train/llama_flash_attn_monkey_patch.py
function forward (line 21) | def forward(
function _prepare_decoder_attention_mask (line 109) | def _prepare_decoder_attention_mask(
function replace_llama_attn_with_flash_attn (line 116) | def replace_llama_attn_with_flash_attn():
FILE: model/llava/train/llava_trainer.py
function maybe_zero_3 (line 8) | def maybe_zero_3(param, ignore_status=False, name=None):
function get_mm_adapter_state_maybe_zero_3 (line 23) | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
class LLaVATrainer (line 36) | class LLaVATrainer(Trainer):
method _save_checkpoint (line 37) | def _save_checkpoint(self, model, trial, metrics=None):
method _save (line 63) | def _save(self, output_dir: Optional[str] = None, state_dict=None):
FILE: model/llava/train/train.py
function rank0_print (line 40) | def rank0_print(*args):
class ModelArguments (line 46) | class ModelArguments:
class DataArguments (line 62) | class DataArguments:
class TrainingArguments (line 74) | class TrainingArguments(transformers.TrainingArguments):
function maybe_zero_3 (line 107) | def maybe_zero_3(param, ignore_status=False, name=None):
function get_peft_state_maybe_zero_3 (line 125) | def get_peft_state_maybe_zero_3(named_params, bias):
function get_peft_state_non_lora_maybe_zero_3 (line 150) | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only...
function get_mm_adapter_state_maybe_zero_3 (line 160) | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
function find_all_linear_names (line 172) | def find_all_linear_names(model):
function safe_save_model_for_hf_trainer (line 185) | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output...
function smart_tokenizer_and_embedding_resize (line 227) | def smart_tokenizer_and_embedding_resize(
function _tokenize_fn (line 254) | def _tokenize_fn(
function _mask_targets (line 281) | def _mask_targets(target, tokenized_lens, speakers):
function _add_speaker_and_signal (line 292) | def _add_speaker_and_signal(header, source, get_conversation=True):
function preprocess_multimodal (line 314) | def preprocess_multimodal(sources: Sequence[str], data_args: DataArgumen...
function preprocess_llama_2 (line 344) | def preprocess_llama_2(
function preprocess_v1 (line 430) | def preprocess_v1(
function preprocess_mpt (line 516) | def preprocess_mpt(
function preprocess_plain (line 592) | def preprocess_plain(
function preprocess (line 621) | def preprocess(
class LazySupervisedDataset (line 681) | class LazySupervisedDataset(Dataset):
method __init__ (line 684) | def __init__(
method __len__ (line 698) | def __len__(self):
method __getitem__ (line 701) | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
class DataCollatorForSupervisedDataset (line 764) | class DataCollatorForSupervisedDataset(object):
method __call__ (line 769) | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
function make_supervised_data_module (line 797) | def make_supervised_data_module(
function train (line 810) | def train():
FILE: model/llava/utils.py
function build_logger (line 20) | def build_logger(logger_name, logger_filename):
class StreamToLogger (line 64) | class StreamToLogger(object):
method __init__ (line 69) | def __init__(self, logger, log_level=logging.INFO):
method __getattr__ (line 75) | def __getattr__(self, attr):
method write (line 78) | def write(self, buf):
method flush (line 92) | def flush(self):
function disable_torch_init (line 98) | def disable_torch_init():
function violates_moderation (line 108) | def violates_moderation(text):
function pretty_print_semaphore (line 131) | def pretty_print_semaphore(semaphore):
FILE: model/segment_anything/automatic_mask_generator.py
class SamAutomaticMaskGenerator (line 24) | class SamAutomaticMaskGenerator:
method __init__ (line 25) | def __init__(
method generate (line 127) | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
method _generate_masks (line 189) | def _generate_masks(self, image: np.ndarray) -> MaskData:
method _process_crop (line 217) | def _process_crop(
method _process_batch (line 260) | def _process_batch(
method postprocess_small_regions (line 324) | def postprocess_small_regions(
FILE: model/segment_anything/build_sam.py
function build_sam_vit_h (line 15) | def build_sam_vit_h(checkpoint=None):
function build_sam_vit_l (line 28) | def build_sam_vit_l(checkpoint=None):
function build_sam_vit_b (line 38) | def build_sam_vit_b(checkpoint=None):
function _build_sam (line 56) | def _build_sam(
FILE: model/segment_anything/modeling/common.py
class MLPBlock (line 13) | class MLPBlock(nn.Module):
method __init__ (line 14) | def __init__(
method forward (line 25) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class LayerNorm2d (line 31) | class LayerNorm2d(nn.Module):
method __init__ (line 32) | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
method forward (line 38) | def forward(self, x: torch.Tensor) -> torch.Tensor:
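
Note: LayerNorm2d above is SAM's channels-first LayerNorm for (N, C, H, W) feature maps; its standard body normalizes over the channel dimension only:

    import torch
    import torch.nn as nn

    # Channels-first LayerNorm: statistics over dim 1 (channels) only,
    # with per-channel affine parameters broadcast over H and W.
    class LayerNorm2d(nn.Module):
        def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.ones(num_channels))
            self.bias = nn.Parameter(torch.zeros(num_channels))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            return self.weight[:, None, None] * x + self.bias[:, None, None]
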
FILE: model/segment_anything/modeling/image_encoder.py
class ImageEncoderViT (line 17) | class ImageEncoderViT(nn.Module):
method __init__ (line 18) | def __init__(
method forward (line 110) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class Block (line 128) | class Block(nn.Module):
method __init__ (line 131) | def __init__(
method forward (line 177) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class Attention (line 196) | class Attention(nn.Module):
method __init__ (line 199) | def __init__(
method forward (line 235) | def forward(self, x: torch.Tensor) -> torch.Tensor:
function window_partition (line 263) | def window_partition(
function window_unpartition (line 291) | def window_unpartition(
function get_rel_pos (line 321) | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torc...
function add_decomposed_rel_pos (line 354) | def add_decomposed_rel_pos(
class PatchEmbed (line 395) | class PatchEmbed(nn.Module):
method __init__ (line 400) | def __init__(
method forward (line 422) | def forward(self, x: torch.Tensor) -> torch.Tensor:
FILE: model/segment_anything/modeling/mask_decoder.py
class MaskDecoder (line 16) | class MaskDecoder(nn.Module):
method __init__ (line 17) | def __init__(
method forward (line 75) | def forward(
method predict_masks (line 116) | def predict_masks(
method forward_modified_v3 (line 167) | def forward_modified_v3(
class MLP (line 209) | class MLP(nn.Module):
method __init__ (line 210) | def __init__(
method forward (line 226) | def forward(self, x):
FILE: model/segment_anything/modeling/prompt_encoder.py
class PromptEncoder (line 16) | class PromptEncoder(nn.Module):
method __init__ (line 17) | def __init__(
method get_dense_pe (line 67) | def get_dense_pe(self) -> torch.Tensor:
method _embed_points (line 78) | def _embed_points(
method _embed_boxes (line 100) | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
method _embed_masks (line 111) | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
method _get_batch_size (line 116) | def _get_batch_size(
method _get_device (line 137) | def _get_device(self) -> torch.device:
method forward (line 140) | def forward(
class PositionEmbeddingRandom (line 189) | class PositionEmbeddingRandom(nn.Module):
method __init__ (line 194) | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = N...
method _pe_encoding (line 203) | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
method forward (line 216) | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
method forward_with_coords (line 231) | def forward_with_coords(
FILE: model/segment_anything/modeling/sam.py
class Sam (line 18) | class Sam(nn.Module):
method __init__ (line 22) | def __init__(
method device (line 52) | def device(self) -> Any:
method forward (line 56) | def forward(
method postprocess_masks (line 137) | def postprocess_masks(
method preprocess (line 174) | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
FILE: model/segment_anything/modeling/transformer.py
class TwoWayTransformer (line 16) | class TwoWayTransformer(nn.Module):
method __init__ (line 17) | def __init__(
method forward (line 62) | def forward(
class TwoWayAttentionBlock (line 109) | class TwoWayAttentionBlock(nn.Module):
method __init__ (line 110) | def __init__(
method forward (line 151) | def forward(
class Attention (line 185) | class Attention(nn.Module):
method __init__ (line 191) | def __init__(
method _separate_heads (line 210) | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
method _recombine_heads (line 215) | def _recombine_heads(self, x: Tensor) -> Tensor:
method forward (line 220) | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
FILE: model/segment_anything/predictor.py
class SamPredictor (line 16) | class SamPredictor:
method __init__ (line 17) | def __init__(
method set_image (line 33) | def set_image(
method set_torch_image (line 64) | def set_torch_image(
method predict (line 93) | def predict(
method predict_torch (line 178) | def predict_torch(
method get_image_embedding (line 258) | def get_image_embedding(self) -> torch.Tensor:
method device (line 274) | def device(self) -> torch.device:
method reset_image (line 277) | def reset_image(self) -> None:
FILE: model/segment_anything/utils/amg.py
class MaskData (line 16) | class MaskData:
method __init__ (line 22) | def __init__(self, **kwargs) -> None:
method __setitem__ (line 29) | def __setitem__(self, key: str, item: Any) -> None:
method __delitem__ (line 35) | def __delitem__(self, key: str) -> None:
method __getitem__ (line 38) | def __getitem__(self, key: str) -> Any:
method items (line 41) | def items(self) -> ItemsView[str, Any]:
method filter (line 44) | def filter(self, keep: torch.Tensor) -> None:
method cat (line 59) | def cat(self, new_stats: "MaskData") -> None:
method to_numpy (line 72) | def to_numpy(self) -> None:
function is_box_near_crop_edge (line 78) | def is_box_near_crop_edge(
function box_xyxy_to_xywh (line 91) | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
function batch_iterator (line 98) | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None,...
function mask_to_rle_pytorch (line 107) | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
function rle_to_mask (line 138) | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
function area_from_rle (line 152) | def area_from_rle(rle: Dict[str, Any]) -> int:
function calculate_stability_score (line 156) | def calculate_stability_score(
function build_point_grid (line 179) | def build_point_grid(n_per_side: int) -> np.ndarray:
function build_all_layer_point_grids (line 189) | def build_all_layer_point_grids(
function generate_crop_boxes (line 200) | def generate_crop_boxes(
function uncrop_boxes_xyxy (line 237) | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch...
function uncrop_points (line 246) | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Te...
function uncrop_masks (line 255) | def uncrop_masks(
function remove_small_regions (line 267) | def remove_small_regions(
function coco_encode_rle (line 294) | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
function batched_mask_to_box (line 303) | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
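
Note: batch_iterator above slices several parallel sequences into aligned mini-batches; the standard SAM body:

    from typing import Any, Generator, List

    # Yield equal-length slices of every argument, batch_size items at a time.
    def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
        assert len(args) > 0 and all(len(a) == len(args[0]) for a in args), \
            'Batched iteration must have inputs of all the same size.'
        n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
        for b in range(n_batches):
            yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args]
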
FILE: model/segment_anything/utils/onnx.py
class SamOnnxModel (line 17) | class SamOnnxModel(nn.Module):
method __init__ (line 25) | def __init__(
method resize_longest_image_size (line 42) | def resize_longest_image_size(
method _embed_points (line 51) | def _embed_points(
method _embed_masks (line 74) | def _embed_masks(
method mask_postprocessing (line 85) | def mask_postprocessing(
method select_masks (line 105) | def select_masks(
method forward (line 121) | def forward(
FILE: model/segment_anything/utils/transforms.py
class ResizeLongestSide (line 17) | class ResizeLongestSide:
method __init__ (line 24) | def __init__(self, target_length: int) -> None:
method apply_image (line 27) | def apply_image(self, image: np.ndarray) -> np.ndarray:
method apply_coords (line 36) | def apply_coords(
method apply_boxes (line 52) | def apply_boxes(
method apply_image_torch (line 62) | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
method apply_coords_torch (line 76) | def apply_coords_torch(
method apply_boxes_torch (line 92) | def apply_boxes_torch(
method get_preprocess_shape (line 103) | def get_preprocess_shape(
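
Note: get_preprocess_shape above is the arithmetic behind ResizeLongestSide: scale so the longer side hits the target length, rounding to the nearest pixel. apply_coords then rescales point coordinates by the same ratio:

    # Output (h, w) after scaling the longer side to long_side_length.
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        return (int(newh + 0.5), int(neww + 0.5))

    # e.g. a 480x640 frame with target 1024 maps to (768, 1024).
    assert get_preprocess_shape(480, 640, 1024) == (768, 1024)
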
FILE: model/tf/modeling_outputs.py
class CausalLMOutputWithPastAndLabel (line 8) | class CausalLMOutputWithPastAndLabel(ModelOutput):
FILE: model/univi/conversation.py
class SeparatorStyle (line 6) | class SeparatorStyle(Enum):
class Conversation (line 16) | class Conversation:
method get_prompt (line 29) | def get_prompt(self):
method append_message (line 106) | def append_message(self, role, message):
method get_images (line 109) | def get_images(self, return_pil=False):
method to_gradio_chatbot (line 158) | def to_gradio_chatbot(self):
method copy (line 191) | def copy(self):
method dict (line 202) | def dict(self):
FILE: model/univi/demo.py
class Chat (line 13) | class Chat:
method __init__ (line 14) | def __init__(self, model_path, conv_mode="simple"):
method get_prompt (line 34) | def get_prompt(self, qs, state):
method _get_rawvideo_dec (line 39) | def _get_rawvideo_dec(self, video_path, image_processor, max_frames=MA...
method generate (line 77) | def generate(self, images_tensor: list, prompt: str, first_run: bool, ...
FILE: model/univi/eval/evaluate/evaluate_benchmark_1_correctness.py
function read_jsonl (line 10) | def read_jsonl(file):
function parse_args (line 18) | def parse_args():
function annotate (line 29) | def annotate(prediction_set, caption_files, output_dir):
function main (line 84) | def main():
FILE: model/univi/eval/evaluate/evaluate_benchmark_2_detailed_orientation.py
function read_jsonl (line 10) | def read_jsonl(file):
function parse_args (line 18) | def parse_args():
function annotate (line 29) | def annotate(prediction_set, caption_files, output_dir):
function main (line 84) | def main():
FILE: model/univi/eval/evaluate/evaluate_benchmark_3_context.py
function read_jsonl (line 10) | def read_jsonl(file):
function parse_args (line 18) | def parse_args():
function annotate (line 29) | def annotate(prediction_set, caption_files, output_dir):
function main (line 84) | def main():
FILE: model/univi/eval/evaluate/evaluate_benchmark_4_temporal.py
function read_jsonl (line 10) | def read_jsonl(file):
function parse_args (line 18) | def parse_args():
function annotate (line 29) | def annotate(prediction_set, caption_files, output_dir):
function main (line 83) | def main():
FILE: model/univi/eval/evaluate/evaluate_benchmark_5_consistency.py
function read_jsonl (line 10) | def read_jsonl(file):
function parse_args (line 18) | def parse_args():
function annotate (line 29) | def annotate(prediction_set, caption_files, output_dir):
function main (line 89) | def main():
FILE: model/univi/eval/evaluate/evaluate_gpt_review_visual.py
function get_eval (line 11) | def get_eval(content: str, max_tokens: int):
function parse_score (line 36) | def parse_score(review):
FILE: model/univi/eval/evaluate/evaluate_science_qa.py
function get_args (line 9) | def get_args():
function convert_caps (line 20) | def convert_caps(results):
function get_pred_idx (line 29) | def get_pred_idx(prediction, choices, options):
FILE: model/univi/eval/evaluate/evaluate_video_qa.py
function read_jsonl (line 10) | def read_jsonl(file):
function parse_args (line 18) | def parse_args():
function annotate (line 29) | def annotate(prediction_set, caption_files, output_dir):
function main (line 83) | def main():
FILE: model/univi/eval/evaluate/summarize_gpt_review.py
function parse_args (line 8) | def parse_args():
FILE: model/univi/eval/model_coco_vqa.py
function get_acc (line 19) | def get_acc(file):
function split_list (line 61) | def split_list(lst, n):
function get_chunk (line 67) | def get_chunk(lst, n, k):
class LogitsProcessor (line 72) | class LogitsProcessor(ABC):
method __call__ (line 74) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...
function eval_model (line 81) | def eval_model(args):
FILE: model/univi/eval/model_video_consistency.py
function split_list (line 18) | def split_list(lst, n):
function get_chunk (line 24) | def get_chunk(lst, n, k):
function _get_rawvideo_dec (line 29) | def _get_rawvideo_dec(video_path, image_processor, max_frames=MAX_IMAGE_...
function eval_model (line 90) | def eval_model(args):
FILE: model/univi/eval/model_video_general.py
function split_list (line 18) | def split_list(lst, n):
function get_chunk (line 24) | def get_chunk(lst, n, k):
function _get_rawvideo_dec (line 29) | def _get_rawvideo_dec(video_path, image_processor, max_frames=MAX_IMAGE_...
function eval_model (line 90) | def eval_model(args):
FILE: model/univi/eval/model_video_qa.py
function read_json (line 18) | def read_json(file):
function split_list (line 23) | def split_list(lst, n):
function get_chunk (line 29) | def get_chunk(lst, n, k):
function _get_rawvideo_dec (line 34) | def _get_rawvideo_dec(video_path, image_processor, max_frames=MAX_IMAGE_...
function eval_model (line 95) | def eval_model(args):
FILE: model/univi/eval/model_vqa.py
function split_list (line 16) | def split_list(lst, n):
function get_chunk (line 22) | def get_chunk(lst, n, k):
function eval_model (line 27) | def eval_model(args):
FILE: model/univi/eval/model_vqa_scienceqa.py
function split_list (line 19) | def split_list(lst, n):
function get_chunk (line 25) | def get_chunk(lst, n, k):
function eval_model (line 30) | def eval_model(args):
FILE: model/univi/mm_utils.py
function load_image_from_base64 (line 10) | def load_image_from_base64(image):
function process_images (line 14) | def process_images(images, image_processor, model_cfg):
function tokenizer_image_token (line 18) | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOK...
function get_model_name_from_path (line 41) | def get_model_name_from_path(model_path):
class KeywordsStoppingCriteria (line 50) | class KeywordsStoppingCriteria(StoppingCriteria):
method __init__ (line 51) | def __init__(self, keywords, tokenizer, input_ids):
method __call__ (line 62) | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTe...
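
Note: tokenizer_image_token above splices visual placeholders into the text token stream. A simplified sketch of the LLaVA-style approach, assuming the sentinel IMAGE_TOKEN_INDEX = -200 and '<image>' as the placeholder string; the BOS handling in the repository is likely more careful than shown here:

    IMAGE_TOKEN_INDEX = -200  # sentinel later replaced by visual embeddings

    # Tokenize around each '<image>' placeholder and splice in the sentinel id.
    def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX):
        chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
        input_ids = list(chunks[0])
        for chunk in chunks[1:]:
            input_ids.append(image_token_index)
            # Drop the BOS id the tokenizer prepends to every chunk.
            if chunk and chunk[0] == tokenizer.bos_token_id:
                chunk = chunk[1:]
            input_ids.extend(chunk)
        return input_ids
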
FILE: model/univi/model/apply_delta.py
function apply_delta (line 9) | def apply_delta(base_model_path, target_model_path, delta_path):
FILE: model/univi/model/arch.py
class MetaModel (line 10) | class MetaModel:
method __init__ (line 11) | def __init__(self, config):
method get_vision_tower (line 35) | def get_vision_tower(self):
method initialize_vision_modules (line 41) | def initialize_vision_modules(self, model_args, fsdp=None):
method initialize_cluster_modules (line 71) | def initialize_cluster_modules(self, model_args):
class ChatUniViMetaForCausalLM (line 88) | class ChatUniViMetaForCausalLM(ABC):
method get_model (line 90) | def get_model(self):
method get_vision_tower (line 93) | def get_vision_tower(self):
method encode_images (line 96) | def encode_images(self, images):
method positional_encoding (line 100) | def positional_encoding(self, x, num_features=1024, max_len=64):
method project (line 110) | def project(self, image_features, input_type="image"):
method prepare_inputs_labels_for_multimodal (line 219) | def prepare_inputs_labels_for_multimodal(
method initialize_vision_tokenizer (line 340) | def initialize_vision_tokenizer(self, model_args, tokenizer):
FILE: model/univi/model/builder.py
function load_pretrained_model (line 11) | def load_pretrained_model(model_path, model_base, model_name, load_8bit=...
FILE: model/univi/model/cluster.py
function _no_grad_trunc_normal_ (line 7) | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
function trunc_normal_ (line 43) | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
function drop_path (line 67) | def drop_path(x, drop_prob: float = 0., training: bool = False):
class DropPath (line 80) | class DropPath(nn.Module):
method __init__ (line 83) | def __init__(self, drop_prob=None):
method forward (line 87) | def forward(self, x):
function index_points (line 91) | def index_points(points, idx):
function cluster_dpc_knn (line 111) | def cluster_dpc_knn(token_dict, cluster_num, k=5, token_mask=None):
function merge_tokens (line 174) | def merge_tokens(token_dict, idx_cluster, cluster_num, token_weight=None):
class CTM (line 226) | class CTM(nn.Module):
method __init__ (line 227) | def __init__(self, sample_ratio, embed_dim, dim_out, k=5):
method forward (line 233) | def forward(self, token_dict, sample_ratio=None):
class TCBlock (line 259) | class TCBlock(nn.Module):
method __init__ (line 260) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, qk_sca...
method _init_weights (line 265) | def _init_weights(self, m):
method forward (line 280) | def forward(self, inputs):
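
Note: drop_path above is stochastic depth; the timm-style body this signature suggests zeroes whole residual branches per sample and rescales survivors so the expectation is unchanged:

    import torch

    # Stochastic depth: drop entire residual paths per sample with
    # probability drop_prob; rescale survivors by 1/keep_prob.
    def drop_path(x, drop_prob: float = 0., training: bool = False):
        if drop_prob == 0. or not training:
            return x
        keep_prob = 1 - drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # broadcast over non-batch dims
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize to 0/1
        return x.div(keep_prob) * random_tensor
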
FILE: model/univi/model/consolidate.py
function consolidate_ckpt (line 13) | def consolidate_ckpt(src_path, dst_path):
FILE: model/univi/model/dataloader.py
function _get_rawvideo_dec (line 9) | def _get_rawvideo_dec(video_path, image_processor, max_frames=64, image_...
FILE: model/univi/model/language_model/llama.py
class ChatUniViConfig (line 12) | class ChatUniViConfig(LlamaConfig):
class ChatUniViLlamaModel (line 16) | class ChatUniViLlamaModel(MetaModel, LlamaModel):
method __init__ (line 19) | def __init__(self, config: LlamaConfig):
class ChatUniViLlamaForCausalLM (line 23) | class ChatUniViLlamaForCausalLM(LlamaForCausalLM, ChatUniViMetaForCausal...
method __init__ (line 26) | def __init__(self, config):
method get_model (line 33) | def get_model(self):
method forward (line 36) | def forward(
method prepare_inputs_for_generation (line 113) | def prepare_inputs_for_generation(
FILE: model/univi/model/make_delta.py
function make_delta (line 13) | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_...
FILE: model/univi/model/multimodal_encoder/builder.py
function build_vision_tower (line 5) | def build_vision_tower(vision_tower_cfg, **kwargs):
FILE: model/univi/model/multimodal_encoder/clip_encoder.py
class CLIPVisionTower (line 7) | class CLIPVisionTower(nn.Module):
method __init__ (line 8) | def __init__(self, vision_tower, args=None, delay_load=False):
method load_model (line 26) | def load_model(self):
method feature_select (line 34) | def feature_select(self, image_forward_outs, select_feature='patch'):
method forward (line 45) | def forward(self, images, select_feature='patch'):
method dummy_feature (line 59) | def dummy_feature(self):
method dtype (line 63) | def dtype(self):
method device (line 67) | def device(self):
method config (line 71) | def config(self):
method hidden_size (line 78) | def hidden_size(self):
method num_patches (line 82) | def num_patches(self):
FILE: model/univi/model/multimodal_encoder/eva_encoder.py
class EVAVisionTower (line 7) | class EVAVisionTower(nn.Module):
method __init__ (line 8) | def __init__(self, vision_tower, args, delay_load=False):
method load_model (line 22) | def load_model(self):
method feature_select (line 32) | def feature_select(self, image_forward_outs, select_feature='patch'):
method forward (line 43) | def forward(self, images, select_feature='patch'):
method dummy_feature (line 57) | def dummy_feature(self):
method dtype (line 61) | def dtype(self):
method device (line 65) | def device(self):
method config (line 69) | def config(self):
method hidden_size (line 76) | def hidden_size(self):
method num_patches (line 80) | def num_patches(self):
FILE: model/univi/model/multimodal_encoder/eva_vit.py
function _cfg (line 21) | def _cfg(url='', **kwargs):
class DropPath (line 31) | class DropPath(nn.Module):
method __init__ (line 35) | def __init__(self, drop_prob=None):
method forward (line 39) | def forward(self, x):
method extra_repr (line 42) | def extra_repr(self) -> str:
class Mlp (line 46) | class Mlp(nn.Module):
method __init__ (line 47) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 56) | def forward(self, x):
class Attention (line 66) | class Attention(nn.Module):
method __init__ (line 67) | def __init__(
method forward (line 120) | def forward(self, x, rel_pos_bias=None):
class Block (line 153) | class Block(nn.Module):
method __init__ (line 155) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_sc...
method forward (line 175) | def forward(self, x, rel_pos_bias=None):
class PatchEmbed (line 185) | class PatchEmbed(nn.Module):
method __init__ (line 189) | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=...
method forward (line 201) | def forward(self, x, **kwargs):
class RelativePositionBias (line 210) | class RelativePositionBias(nn.Module):
method __init__ (line 212) | def __init__(self, window_size, num_heads):
method forward (line 241) | def forward(self):
class VisionTransformer (line 249) | class VisionTransformer(nn.Module):
method __init__ (line 253) | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classe...
method fix_init_weight (line 305) | def fix_init_weight(self):
method _init_weights (line 313) | def _init_weights(self, m):
method get_classifier (line 322) | def get_classifier(self):
method reset_classifier (line 325) | def reset_classifier(self, num_classes, global_pool=''):
method forward_features (line 329) | def forward_features(self, x):
method forward (line 355) | def forward(self, x):
method get_intermediate_layers (line 360) | def get_intermediate_layers(self, x):
function interpolate_pos_embed (line 379) | def interpolate_pos_embed(model, checkpoint_model):
function convert_weights_to_fp16 (line 403) | def convert_weights_to_fp16(model: nn.Module):
function create_eva_vit_g (line 421) | def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=Fa...
FILE: model/univi/model/multimodal_encoder/processor.py
class BaseProcessor (line 6) | class BaseProcessor:
method __init__ (line 7) | def __init__(self, mean=None, std=None):
class ImageTrainProcessor (line 16) | class ImageTrainProcessor(BaseProcessor):
method __init__ (line 17) | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5,...
method preprocess (line 30) | def preprocess(self, item, return_tensors):
class ImageEvalProcessor (line 34) | class ImageEvalProcessor(BaseProcessor):
method __init__ (line 35) | def __init__(self, image_size=224, mean=None, std=None):
method preprocess (line 48) | def preprocess(self, item, return_tensors):
class QWenImageProcessor (line 52) | class QWenImageProcessor(BaseProcessor):
method __init__ (line 53) | def __init__(self, image_size=224, mean=None, std=None):
method preprocess (line 67) | def preprocess(self, item, return_tensors):
FILE: model/univi/model/multimodal_encoder/utils.py
function setup_for_distributed (line 17) | def setup_for_distributed(is_master):
function is_dist_avail_and_initialized (line 33) | def is_dist_avail_and_initialized():
function get_world_size (line 41) | def get_world_size():
function get_rank (line 47) | def get_rank():
function is_main_process (line 53) | def is_main_process():
function init_distributed_mode (line 57) | def init_distributed_mode(args):
function get_dist_info (line 93) | def get_dist_info():
function main_process (line 107) | def main_process(func):
function download_cached_file (line 117) | def download_cached_file(url, check_hash=True, progress=False):
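
Note: the distributed helpers above follow the common DETR-style pattern. Sketches of the guard functions; the repository versions may differ in detail:

    import torch.distributed as dist

    # True only when torch.distributed is both available and initialized.
    def is_dist_avail_and_initialized():
        return dist.is_available() and dist.is_initialized()

    def get_world_size():
        return dist.get_world_size() if is_dist_avail_and_initialized() else 1

    def get_rank():
        return dist.get_rank() if is_dist_avail_and_initialized() else 0

    def is_main_process():
        return get_rank() == 0
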
FILE: model/univi/train/llama_flash_attn_monkey_patch.py
function forward (line 19) | def forward(
function _prepare_decoder_attention_mask (line 107) | def _prepare_decoder_attention_mask(
function replace_llama_attn_with_flash_attn (line 114) | def replace_llama_attn_with_flash_attn():
FILE: model/univi/train/train.py
function rank0_print (line 42) | def rank0_print(*args):
class ModelArguments (line 48) | class ModelArguments:
class DataArguments (line 64) | class DataArguments:
class TrainingArguments (line 73) | class TrainingArguments(transformers.TrainingArguments):
function maybe_zero_3 (line 108) | def maybe_zero_3(param, ignore_status=False, name=None):
function get_peft_state_maybe_zero_3 (line 123) | def get_peft_state_maybe_zero_3(named_params, bias):
function get_peft_state_non_lora_maybe_zero_3 (line 148) | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only...
function get_mm_adapter_state_maybe_zero_3 (line 156) | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
function find_all_linear_names (line 162) | def find_all_linear_names(model):
function safe_save_model_for_hf_trainer (line 176) | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
function smart_tokenizer_and_embedding_resize (line 214) | def smart_tokenizer_and_embedding_resize(
function _tokenize_fn (line 239) | def _tokenize_fn(strings: Sequence[str],
function _mask_targets (line 266) | def _mask_targets(target, tokenized_lens, speakers):
function _add_speaker_and_signal (line 277) | def _add_speaker_and_signal(header, source, get_conversation=True):
function preprocess_multimodal (line 298) | def preprocess_multimodal(
function preprocess_llama_2 (line 338) | def preprocess_llama_2(
function preprocess_v1 (line 426) | def preprocess_v1(
function preprocess_mpt (line 508) | def preprocess_mpt(
function preprocess_plain (line 574) | def preprocess_plain(
function preprocess (line 596) | def preprocess(
class LazySupervisedDataset (line 644) | class LazySupervisedDataset(Dataset):
method __init__ (line 647) | def __init__(self, tokenizer: transformers.PreTrainedTokenizer,
method __len__ (line 673) | def __len__(self):
method __getitem__ (line 676) | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
class DataCollatorForSupervisedDataset (line 797) | class DataCollatorForSupervisedDataset(object):
method __call__ (line 802) | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
function make_supervised_data_module (line 840) | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokeni...
function train (line 850) | def train():
FILE: model/univi/train/trainer.py
function maybe_zero_3 (line 7) | def maybe_zero_3(param, ignore_status=False, name=None):
function get_mm_adapter_state_maybe_zero_3 (line 21) | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
class ChatUniViTrainer (line 27) | class ChatUniViTrainer(Trainer):
method _save_checkpoint (line 28) | def _save_checkpoint(self, model, trial, metrics=None):
method _save (line 49) | def _save(self, output_dir: Optional[str] = None, state_dict=None):
FILE: model/univi/utils.py
function build_logger (line 17) | def build_logger(logger_name, logger_filename):
class StreamToLogger (line 60) | class StreamToLogger(object):
method __init__ (line 64) | def __init__(self, logger, log_level=logging.INFO):
method __getattr__ (line 70) | def __getattr__(self, attr):
method write (line 73) | def write(self, buf):
method flush (line 87) | def flush(self):
function disable_torch_init (line 93) | def disable_torch_init():
function violates_moderation (line 102) | def violates_moderation(text):
function pretty_print_semaphore (line 123) | def pretty_print_semaphore(semaphore):
FILE: tools/eval_davis17.py
function eval_queue (line 23) | def eval_queue(q, rank, out_dict, mevis_pred_path):
function get_meta_exp (line 58) | def get_meta_exp(mevis_exp_path, ):
FILE: tools/eval_mevis.py
function eval_queue (line 22) | def eval_queue(q, rank, out_dict, mevis_pred_path):
FILE: tools/eval_revos.py
function eval_queue (line 26) | def eval_queue(q, rank, out_dict, visa_pred_path):
FILE: tools/generate_foreground_mask.py
function get_args (line 18) | def get_args():
function merge_rle (line 24) | def merge_rle(masks_rle_list: list, height: int, width: int):
function main (line 51) | def main():
FILE: tools/metrics.py
function get_r2vos_accuracy (line 6) | def get_r2vos_accuracy(gt_masks: List[np.ndarray], pred_masks: List[np.n...
function get_r2vos_robustness (line 23) | def get_r2vos_robustness(gt_masks: List[np.ndarray], pred_masks: List[np...
function db_eval_iou (line 43) | def db_eval_iou(annotation, segmentation, void_pixels=None):
function db_eval_boundary (line 77) | def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_t...
function f_measure (line 94) | def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008):
function _seg2bmap (line 158) | def _seg2bmap(seg, width=None, height=None):
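
Note: db_eval_iou above is the DAVIS region-similarity J. A single-mask sketch: IoU over boolean masks with void pixels excluded, scoring the empty-vs-empty case as 1 (batched handling in the repository is assumed to follow the official DAVIS evaluation code):

    import numpy as np

    # DAVIS J: IoU between annotation and segmentation, ignoring void pixels;
    # an empty union (both masks empty) conventionally scores 1.
    def db_eval_iou(annotation, segmentation, void_pixels=None):
        annotation = annotation.astype(bool)
        segmentation = segmentation.astype(bool)
        valid = np.ones_like(annotation) if void_pixels is None \
            else ~void_pixels.astype(bool)
        inters = np.sum(segmentation & annotation & valid)
        union = np.sum((segmentation | annotation) & valid)
        return 1.0 if np.isclose(union, 0) else inters / union
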
FILE: tools/zip_mp_mevis.py
function zip_files (line 16) | def zip_files(path, temp_zip_file):
function main (line 24) | def main():
FILE: tools/zip_mp_refytvos.py
function zip_files (line 16) | def zip_files(path, temp_zip_file):
function main (line 24) | def main():
FILE: train_ds.py
function parse_args (line 29) | def parse_args(args):
function main (line 132) | def main(args):
function train (line 466) | def train(
function validate (line 601) | def validate(val_loader, model_engine, epoch, writer, args):
function rvos_validate (line 667) | def rvos_validate(val_loader, model_engine, epoch, writer, args):
FILE: utils/chatunivi_dataset.py
function _get_rawvideo_dec (line 24) | def _get_rawvideo_dec(video_path, image_processor, max_frames=64, image_...
function get_zero_image (line 69) | def get_zero_image(processor):
class ChatUniviDataset (line 75) | class ChatUniviDataset(torch.utils.data.Dataset):
method __init__ (line 87) | def __init__(
method load_data (line 128) | def load_data(self, dataset_name: str):
method __len__ (line 139) | def __len__(self):
method __getitem__ (line 142) | def __getitem__(self, i, max_try: int = 10):
method sample_data (line 173) | def sample_data(self, ):
FILE: utils/conversation.py
class SeparatorStyle (line 10) | class SeparatorStyle(Enum):
class Conversation (line 22) | class Conversation:
method get_prompt (line 48) | def get_prompt(self):
method append_message (line 109) | def append_message(self, role, message):
method to_gradio_chatbot (line 112) | def to_gradio_chatbot(self):
method copy (line 121) | def copy(self):
method dict (line 136) | def dict(self):
function get_default_conv_template (line 283) | def get_default_conv_template(model_name):
FILE: utils/d2_datasets/mevis_utils.py
function load_mevis_json (line 22) | def load_mevis_json(image_root, json_file, dataset_name, is_train: bool ...
FILE: utils/d2_datasets/refytvos_utils.py
function encode_anno_mask (line 32) | def encode_anno_mask(frames, vid_len, img_folder, video, obj_id, anno_id...
function load_refytvos_json (line 47) | def load_refytvos_json(img_folder: str, ann_file: str, dataset_name: str...
FILE: utils/d2_datasets/ytvis_api/ytvos.py
function _isArrayLike (line 41) | def _isArrayLike(obj):
class YTVOS (line 45) | class YTVOS:
method __init__ (line 46) | def __init__(self, annotation_file=None):
method createIndex (line 65) | def createIndex(self):
method info (line 96) | def info(self):
method getAnnIds (line 104) | def getAnnIds(self, vidIds=[], catIds=[], areaRng=[], iscrowd=None):
method getCatIds (line 132) | def getCatIds(self, catNms=[], supNms=[], catIds=[]):
method getVidIds (line 154) | def getVidIds(self, vidIds=[], catIds=[]):
method loadAnns (line 175) | def loadAnns(self, ids=[]):
method loadCats (line 186) | def loadCats(self, ids=[]):
method loadVids (line 197) | def loadVids(self, ids=[]):
method loadRes (line 209) | def loadRes(self, resFile):
method annToRLE (line 259) | def annToRLE(self, ann, frameId):
method annToMask (line 280) | def annToMask(self, ann, frameId):
FILE: utils/d2_datasets/ytvis_api/ytvoseval.py
class YTVOSeval (line 10) | class YTVOSeval:
method __init__ (line 60) | def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'):
method _prepare (line 85) | def _prepare(self):
method evaluate (line 129) | def evaluate(self):
method computeIoU (line 173) | def computeIoU(self, vidId, catId):
method computeOks (line 221) | def computeOks(self, imgId, catId):
method evaluateVid (line 264) | def evaluateVid(self, vidId, catId, aRng, maxDet):
method accumulate (line 344) | def accumulate(self, p = None):
method summarize (line 451) | def summarize(self):
method __str__ (line 524) | def __str__(self):
class Params (line 527) | class Params:
method setDetParams (line 531) | def setDetParams(self):
method setKpParams (line 544) | def setKpParams(self):
method __init__ (line 555) | def __init__(self, iouType='segm'):
FILE: utils/data_processing.py
function get_mask_from_json (line 9) | def get_mask_from_json(json_path, img):
FILE: utils/dataset.py
function collate_fn (line 36) | def collate_fn(
class HybridDataset (line 187) | class HybridDataset(torch.utils.data.Dataset):
method __init__ (line 193) | def __init__(
method __len__ (line 342) | def __len__(self):
method __getitem__ (line 345) | def __getitem__(self, idx):
class ValDataset (line 356) | class ValDataset(torch.utils.data.Dataset):
method __init__ (line 362) | def __init__(
method __len__ (line 421) | def __len__(self):
method preprocess (line 427) | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
method __getitem__ (line 439) | def __getitem__(self, idx):
FILE: utils/grefcoco.py
function load_grefcoco_json (line 25) | def load_grefcoco_json(
FILE: utils/grefer.py
class G_REFER (line 36) | class G_REFER:
method __init__ (line 37) | def __init__(self, data_root, dataset="grefcoco", splitBy="unc"):
method _toList (line 75) | def _toList(x):
method match_any (line 79) | def match_any(a, b):
method createIndex (line 84) | def createIndex(self):
method getRefIds (line 164) | def getRefIds(self, image_ids=[], cat_ids=[], split=[]):
method getAnnIds (line 186) | def getAnnIds(self, image_ids=[], ref_ids=[]):
method getImgIds (line 210) | def getImgIds(self, ref_ids=[]):
method getCatIds (line 219) | def getCatIds(self):
method loadRefs (line 222) | def loadRefs(self, ref_ids=[]):
method loadAnns (line 225) | def loadAnns(self, ann_ids=[]):
method loadImgs (line 230) | def loadImgs(self, image_ids=[]):
method loadCats (line 233) | def loadCats(self, cat_ids=[]):
method getRefBox (line 236) | def getRefBox(self, ref_id):
method showRef (line 240) | def showRef(self, ref, seg_box="seg"):
method getMask (line 302) | def getMask(self, ann):
method getMaskByRef (line 322) | def getMaskByRef(self, ref=None, ref_id=None, merge=False):
method showMask (line 348) | def showMask(self, ref):
FILE: utils/random_list.py
function lcg (line 3) | def lcg(modulus, a, c, seed):
function get_random_number (line 9) | def get_random_number(probabilities, values, generator):
function get_random_list (line 20) | def get_random_list(probabilities, values, length, seed: int = 0):
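
Note: utils/random_list.py builds reproducible pseudo-random draws from a linear congruential generator. A minimal sketch matching the lcg signature above; the lazy-generator form and the sample constants in the usage lines are assumptions:

    # Linear congruential generator: seed <- (a * seed + c) mod modulus.
    def lcg(modulus, a, c, seed):
        while True:
            seed = (a * seed + c) % modulus
            yield seed

    # Usage: a deterministic stream, here with glibc-style constants.
    gen = lcg(2 ** 31, 1103515245, 12345, seed=0)
    first_five = [next(gen) for _ in range(5)]
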
FILE: utils/reason_seg_dataset.py
class ReasonSegDataset (line 20) | class ReasonSegDataset(torch.utils.data.Dataset):
method __init__ (line 26) | def __init__(
method __len__ (line 93) | def __len__(self):
method preprocess (line 96) | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
method __getitem__ (line 108) | def __getitem__(self, idx):
FILE: utils/refer.py
class REFER (line 44) | class REFER:
method __init__ (line 45) | def __init__(self, data_root, dataset="refcoco", splitBy="unc"):
method createIndex (line 82) | def createIndex(self):
method getRefIds (line 145) | def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""):
method getAnnIds (line 180) | def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
method getImgIds (line 206) | def getImgIds(self, ref_ids=[]):
method getCatIds (line 215) | def getCatIds(self):
method loadRefs (line 218) | def loadRefs(self, ref_ids=[]):
method loadAnns (line 224) | def loadAnns(self, ann_ids=[]):
method loadImgs (line 230) | def loadImgs(self, image_ids=[]):
method loadCats (line 236) | def loadCats(self, cat_ids=[]):
method getRefBox (line 242) | def getRefBox(self, ref_id):
method showRef (line 247) | def showRef(self, ref, seg_box="seg"):
method getMask (line 309) | def getMask(self, ref):
method showMask (line 361) | def showMask(self, ref):
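REFER is the classic referring-expression API (refcoco/refcoco+/refcocog); G_REFER above generalizes it to grefcoco, where one expression may match zero or several objects. A typical retrieval flow, assuming the return shapes of the original refer toolkit (the data_root path and the "mask" key of getMask are assumptions):

    refer = REFER(data_root="datasets/refer_seg", dataset="refcoco", splitBy="unc")
    ref_ids = refer.getRefIds(split="train")    # all training referring expressions
    ref = refer.loadRefs(ref_ids[:1])[0]        # one ref: sentences + ann/image ids
    img = refer.loadImgs(ref["image_id"])[0]    # image metadata
    mask = refer.getMask(ref)["mask"]           # HxW binary mask ("mask" key assumed)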
FILE: utils/refer_seg_dataset.py
class ReferSegDataset (line 19) | class ReferSegDataset(torch.utils.data.Dataset):
method __init__ (line 25) | def __init__(
method __len__ (line 105) | def __len__(self):
method preprocess (line 108) | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
method __getitem__ (line 120) | def __getitem__(self, idx):
FILE: utils/rvos_dataset.py
function get_zero_image (line 43) | def get_zero_image(processor):
class RVOSDataset (line 48) | class RVOSDataset(torch.utils.data.Dataset):
method __init__ (line 54) | def __init__(
method __len__ (line 119) | def __len__(self):
method __getitem__ (line 122) | def __getitem__(self, idx):
method preprocess (line 179) | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
method sample_data (line 192) | def sample_data(self,):
FILE: utils/rvos_eval_dataset.py
function get_zero_image (line 46) | def get_zero_image(processor):
class RVOSEvalDataset (line 50) | class RVOSEvalDataset(torch.utils.data.Dataset):
method __init__ (line 56) | def __init__(
method __len__ (line 91) | def __len__(self):
method load_data (line 94) | def load_data(self, ):
method __getitem__ (line 137) | def __getitem__(self, idx):
method preprocess (line 221) | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
FILE: utils/sem_seg_dataset.py
function init_mapillary (line 20) | def init_mapillary(base_image_dir):
function init_ade20k (line 39) | def init_ade20k(base_image_dir):
function init_cocostuff (line 69) | def init_cocostuff(base_image_dir):
function init_paco_lvis (line 88) | def init_paco_lvis(base_image_dir):
function init_pascal_part (line 112) | def init_pascal_part(base_image_dir):
class SemSegDataset (line 127) | class SemSegDataset(torch.utils.data.Dataset):
method __init__ (line 133) | def __init__(
method __len__ (line 173) | def __len__(self):
method preprocess (line 176) | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
method __getitem__ (line 188) | def __getitem__(self, idx):
FILE: utils/utils.py
function convert2imagesplit (line 89) | def convert2imagesplit(sent: str, video_len: int) -> str:
class Summary (line 95) | class Summary(Enum):
class AverageMeter (line 102) | class AverageMeter(object):
method __init__ (line 105) | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
method reset (line 111) | def reset(self):
method update (line 117) | def update(self, val, n=1):
method all_reduce (line 123) | def all_reduce(self):
method __str__ (line 146) | def __str__(self):
method summary (line 150) | def summary(self):
function intersectionAndUnionGPU (line 166) | def intersectionAndUnionGPU(output, target, K, ignore_index=255):
class ProgressMeter (line 181) | class ProgressMeter(object):
method __init__ (line 182) | def __init__(self, num_batches, meters, prefix=""):
method display (line 187) | def display(self, batch):
method display_summary (line 192) | def display_summary(self):
method _get_batch_fmtstr (line 197) | def _get_batch_fmtstr(self, num_batches):
function dict_to_cuda (line 203) | def dict_to_cuda(input_dict):
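AverageMeter/ProgressMeter and intersectionAndUnionGPU follow the stock PyTorch-examples semantic-segmentation pattern: accumulate per-batch intersection and union counts, reduce across ranks, and compute IoU at the end. A sketch under those assumptions; the tensors, K, and the distributed guard are illustrative:

    import torch
    import torch.distributed as dist

    inter_meter = AverageMeter("Inter", fmt=":.0f", summary_type=Summary.SUM)
    union_meter = AverageMeter("Union", fmt=":.0f", summary_type=Summary.SUM)

    pred = torch.randint(0, 2, (480, 640)).float()    # placeholder prediction map
    target = torch.randint(0, 2, (480, 640)).float()  # placeholder ground truth
    inter, union, _ = intersectionAndUnionGPU(pred, target, K=2, ignore_index=255)
    inter_meter.update(inter)
    union_meter.update(union)

    if dist.is_available() and dist.is_initialized():
        inter_meter.all_reduce()                      # sum counts across GPUs
        union_meter.all_reduce()
    iou = inter_meter.sum / (union_meter.sum + 1e-10)  # per-class IoU vector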
FILE: utils/vqa_dataset.py
function preprocess_multimodal (line 16) | def preprocess_multimodal(source, mm_use_im_start_end):
class VQADataset (line 32) | class VQADataset(torch.utils.data.Dataset):
method __init__ (line 38) | def __init__(
method __len__ (line 69) | def __len__(self):
method preprocess (line 72) | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
method __getitem__ (line 84) | def __getitem__(self, idx):
FILE: utils_llamavid/llamavid_client.py
function call (line 22) | def call(video_dir: str, question: str, ):
function call_batch (line 30) | def call_batch(params_list: List[Tuple[str, str]], ):
function main (line 65) | def main():
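llamavid_client.py and llamavid_server.py (below) implement a small request/response split: the server (InferenceServer.post) wraps LLaMA-VID inference behind HTTP, and the client sends (video_dir, question) pairs, with call_batch fanning requests out in parallel. A sketch of the client side under those assumptions; the server address and JSON keys are hypothetical:

    import requests
    from concurrent.futures import ThreadPoolExecutor
    from typing import List, Tuple

    SERVER_URL = "http://localhost:8000"    # hypothetical server address

    def call(video_dir: str, question: str) -> str:
        resp = requests.post(SERVER_URL,
                             json={"video_dir": video_dir, "question": question})
        resp.raise_for_status()
        return resp.json()["answer"]        # response key is an assumption

    def call_batch(params_list: List[Tuple[str, str]]) -> List[str]:
        # Fan the (video_dir, question) pairs out to the server in parallel.
        with ThreadPoolExecutor(max_workers=8) as pool:
            return list(pool.map(lambda p: call(*p), params_list))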
FILE: utils_llamavid/llamavid_server.py
class VideoFeatureExtractor (line 32) | class VideoFeatureExtractor:
method __init__ (line 34) | def __init__(
method __call__ (line 57) | def __call__(self, video_dir: str) -> dict:
class LLaMAVIDGenerator (line 84) | class LLaMAVIDGenerator:
method __init__ (line 89) | def __init__(
method __call__ (line 116) | def __call__(
class Inferencer (line 175) | class Inferencer:
method __init__ (line 176) | def __init__(
method __call__ (line 194) | def __call__(self, video_dir: str, question: str):
class InferenceServer (line 198) | class InferenceServer(Inferencer):
method __init__ (line 200) | def __init__(self, **kwargs):
method post (line 208) | def post(self):
function parse_args (line 222) | def parse_args():
function main (line 240) | def main():
Condensed preview — 276 files, each showing path, character count, and a content snippet (full structured content: 28,094K chars).
[
{
"path": ".gitignore",
"chars": 115,
"preview": "**/__pycache__\nruns/\nmodels/\ndatasets/\ndatasets\n.vscode/\ncore*\nvis_output/\ntest_vis/\nopenai\n.DS_Store\nXMem/weights/"
},
{
"path": ".gitmodules",
"chars": 92,
"preview": "[submodule \"LLaMA-VID\"]\n\tpath = LLaMA-VID\n\turl = git@github.com:dvlab-research/LLaMA-VID.git"
},
{
"path": "README.md",
"chars": 8721,
"preview": "# VISA: Reasoning Video Object Segmentation via Large Language Model\n\n<font size=7><div align='center' >\n[![ GitHub star"
},
{
"path": "XMem/dataset/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/dataset/range_transform.py",
"chars": 377,
"preview": "import torchvision.transforms as transforms\n\nim_mean = (124, 116, 104)\n\nim_normalization = transforms.Normalize(\n "
},
{
"path": "XMem/dataset/reseed.py",
"chars": 95,
"preview": "import torch\nimport random\n\ndef reseed(seed):\n random.seed(seed)\n torch.manual_seed(seed)"
},
{
"path": "XMem/dataset/static_dataset.py",
"chars": 6505,
"preview": "import os\nfrom os import path\n\nimport torch\nfrom torch.utils.data.dataset import Dataset\nfrom torchvision import transfo"
},
{
"path": "XMem/dataset/tps.py",
"chars": 1167,
"preview": "import numpy as np\nfrom PIL import Image\nimport cv2\nimport thinplate as tps\n\ncv2.setNumThreads(0)\n\ndef pick_random_point"
},
{
"path": "XMem/dataset/util.py",
"chars": 388,
"preview": "import numpy as np\n\n\ndef all_to_onehot(masks, labels):\n if len(masks.shape) == 3:\n Ms = np.zeros((len(labels),"
},
{
"path": "XMem/dataset/vos_dataset.py",
"chars": 8461,
"preview": "import os\nfrom os import path, replace\n\nimport torch\nfrom torch.utils.data.dataset import Dataset\nfrom torchvision impor"
},
{
"path": "XMem/eval.py",
"chars": 10402,
"preview": "import os\nfrom os import path\nfrom argparse import ArgumentParser\nimport shutil\n\nimport torch\nimport torch.nn.functional"
},
{
"path": "XMem/eval_batch.py",
"chars": 1724,
"preview": "import os\nimport time\nimport torch\nimport argparse\nimport multiprocessing as mp\nfrom termcolor import colored\nfrom datet"
},
{
"path": "XMem/generate_xmem_data_single.py",
"chars": 3317,
"preview": "import sys\nimport os\nimport os.path as osp\nimport glob\nimport cv2\nimport multiprocessing\nimport json\nimport argparse\nfro"
},
{
"path": "XMem/inference/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/inference/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/inference/data/mask_mapper.py",
"chars": 2155,
"preview": "import numpy as np\nimport torch\n\nfrom dataset.util import all_to_onehot\n\n\nclass MaskMapper:\n \"\"\"\n This class is us"
},
{
"path": "XMem/inference/data/test_datasets.py",
"chars": 4278,
"preview": "import os\nfrom os import path\nimport json\nimport glob\n\nfrom inference.data.video_reader import VideoReader\n\n\nclass LongT"
},
{
"path": "XMem/inference/data/video_reader.py",
"chars": 3575,
"preview": "import os\nfrom os import path\n\nfrom torch.utils.data.dataset import Dataset\nfrom torchvision import transforms\nfrom torc"
},
{
"path": "XMem/inference/inference_core.py",
"chars": 4917,
"preview": "from inference.memory_manager import MemoryManager\nfrom model.network import XMem\nfrom model.aggregate import aggregate\n"
},
{
"path": "XMem/inference/interact/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/inference/interact/fbrs/LICENSE",
"chars": 16724,
"preview": "Mozilla Public License Version 2.0\n==================================\n\n1. Definitions\n--------------\n\n1.1. \"Contributor\""
},
{
"path": "XMem/inference/interact/fbrs/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/inference/interact/fbrs/controller.py",
"chars": 3421,
"preview": "import torch\ntry:\n from torch import mps\nexcept:\n pass\n\nfrom ..fbrs.inference import clicker\nfrom ..fbrs.inference"
},
{
"path": "XMem/inference/interact/fbrs/inference/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/inference/interact/fbrs/inference/clicker.py",
"chars": 3186,
"preview": "from collections import namedtuple\n\nimport numpy as np\nfrom copy import deepcopy\nfrom scipy.ndimage import distance_tran"
},
{
"path": "XMem/inference/interact/fbrs/inference/evaluation.py",
"chars": 1697,
"preview": "from time import time\n\nimport numpy as np\nimport torch\n\nfrom ..inference import utils\nfrom ..inference.clicker import Cl"
},
{
"path": "XMem/inference/interact/fbrs/inference/predictors/__init__.py",
"chars": 3514,
"preview": "from .base import BasePredictor\nfrom .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor\nfrom ."
},
{
"path": "XMem/inference/interact/fbrs/inference/predictors/base.py",
"chars": 3980,
"preview": "import torch\nimport torch.nn.functional as F\n\nfrom ..transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSi"
},
{
"path": "XMem/inference/interact/fbrs/inference/predictors/brs.py",
"chars": 13053,
"preview": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy.optimize import fmin_l_bfgs_b\n\nfrom .base imp"
},
{
"path": "XMem/inference/interact/fbrs/inference/predictors/brs_functors.py",
"chars": 4046,
"preview": "import torch\nimport numpy as np\n\nfrom ...model.metrics import _compute_iou\nfrom .brs_losses import BRSMaskLoss\n\n\nclass B"
},
{
"path": "XMem/inference/interact/fbrs/inference/predictors/brs_losses.py",
"chars": 1938,
"preview": "import torch\n\nfrom ...model.losses import SigmoidBinaryCrossEntropyLoss\n\n\nclass BRSMaskLoss(torch.nn.Module):\n def __"
},
{
"path": "XMem/inference/interact/fbrs/inference/transforms/__init__.py",
"chars": 171,
"preview": "from .base import SigmoidForPred\nfrom .flip import AddHorizontalFlip\nfrom .zoom_in import ZoomIn\nfrom .limit_longest_sid"
},
{
"path": "XMem/inference/interact/fbrs/inference/transforms/base.py",
"chars": 776,
"preview": "import torch\n\n\nclass BaseTransform(object):\n def __init__(self):\n self.image_changed = False\n\n def transfor"
},
{
"path": "XMem/inference/interact/fbrs/inference/transforms/crops.py",
"chars": 3261,
"preview": "import math\n\nimport torch\nimport numpy as np\n\nfrom ...inference.clicker import Click\nfrom .base import BaseTransform\n\n\nc"
},
{
"path": "XMem/inference/interact/fbrs/inference/transforms/flip.py",
"chars": 1227,
"preview": "import torch\n\nfrom ..clicker import Click\nfrom .base import BaseTransform\n\n\nclass AddHorizontalFlip(BaseTransform):\n "
},
{
"path": "XMem/inference/interact/fbrs/inference/transforms/limit_longest_side.py",
"chars": 824,
"preview": "from .zoom_in import ZoomIn, get_roi_image_nd\n\n\nclass LimitLongestSide(ZoomIn):\n def __init__(self, max_size=800):\n "
},
{
"path": "XMem/inference/interact/fbrs/inference/transforms/zoom_in.py",
"chars": 6401,
"preview": "import torch\n\nfrom ..clicker import Click\nfrom ...utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp"
},
{
"path": "XMem/inference/interact/fbrs/inference/utils.py",
"chars": 6668,
"preview": "from datetime import timedelta\nfrom pathlib import Path\n\nimport torch\nimport numpy as np\n\nfrom ..model.is_deeplab_model "
},
{
"path": "XMem/inference/interact/fbrs/model/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/inference/interact/fbrs/model/initializer.py",
"chars": 3408,
"preview": "import torch\nimport torch.nn as nn\nimport numpy as np\n\n\nclass Initializer(object):\n def __init__(self, local_init=Tru"
},
{
"path": "XMem/inference/interact/fbrs/model/is_deeplab_model.py",
"chars": 3316,
"preview": "import torch\nimport torch.nn as nn\n\nfrom .ops import DistMaps\nfrom .modeling.deeplab_v3 import DeepLabV3Plus\nfrom .model"
},
{
"path": "XMem/inference/interact/fbrs/model/is_hrnet_model.py",
"chars": 3597,
"preview": "import torch\nimport torch.nn as nn\n\nfrom .ops import DistMaps\nfrom .modeling.hrnet_ocr import HighResolutionNet\n\n\ndef ge"
},
{
"path": "XMem/inference/interact/fbrs/model/losses.py",
"chars": 5285,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom ..utils import misc\n\n\nclass "
},
{
"path": "XMem/inference/interact/fbrs/model/metrics.py",
"chars": 3480,
"preview": "import torch\nimport numpy as np\n\nfrom ..utils import misc\n\n\nclass TrainMetric(object):\n def __init__(self, pred_outpu"
},
{
"path": "XMem/inference/interact/fbrs/model/modeling/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/inference/interact/fbrs/model/modeling/basic_blocks.py",
"chars": 2612,
"preview": "import torch.nn as nn\n\nfrom ...model import ops\n\n\nclass ConvHead(nn.Module):\n def __init__(self, out_channels, in_cha"
},
{
"path": "XMem/inference/interact/fbrs/model/modeling/deeplab_v3.py",
"chars": 6313,
"preview": "from contextlib import ExitStack\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom .basic_blocks "
},
{
"path": "XMem/inference/interact/fbrs/model/modeling/hrnet_ocr.py",
"chars": 17301,
"preview": "import os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch._utils\nimport torch.nn.functional as F\nfrom"
},
{
"path": "XMem/inference/interact/fbrs/model/modeling/ocr.py",
"chars": 5727,
"preview": "import torch\nimport torch.nn as nn\nimport torch._utils\nimport torch.nn.functional as F\n\n\nclass SpatialGather_Module(nn.M"
},
{
"path": "XMem/inference/interact/fbrs/model/modeling/resnet.py",
"chars": 1460,
"preview": "import torch\nfrom .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s\n\n\nclass ResNetBackbone(torc"
},
{
"path": "XMem/inference/interact/fbrs/model/modeling/resnetv1b.py",
"chars": 10805,
"preview": "import torch\nimport torch.nn as nn\nGLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'\n\n\nclass BasicBloc"
},
{
"path": "XMem/inference/interact/fbrs/model/ops.py",
"chars": 3119,
"preview": "import torch\nfrom torch import nn as nn\nimport numpy as np\n\nfrom . import initializer as initializer\nfrom ..utils.cython"
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/LICENSE",
"chars": 1070,
"preview": "MIT License\n\nCopyright (c) 2018 Tamaki Kojima\n\nPermission is hereby granted, free of charge, to any person obtaining a c"
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/README.md",
"chars": 5334,
"preview": "# pytorch-syncbn\n\nTamaki Kojima(tamakoji@gmail.com)\n\n## Announcement\n\n**Pytorch 1.0 support**\n\n## Overview\nThis is alter"
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/modules/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/modules/functional/__init__.py",
"chars": 37,
"preview": "from .syncbn import batchnorm2d_sync\n"
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py",
"chars": 1586,
"preview": "\"\"\"\n/*****************************************************************************/\n\nExtension module loader\n\ncode refer"
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/bn.h",
"chars": 2060,
"preview": "/*****************************************************************************\n\nSyncBN\n\n********************************"
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu",
"chars": 9132,
"preview": "/*****************************************************************************\n\nCUDA SyncBN code\n\ncode referenced from :"
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h",
"chars": 3268,
"preview": "/*****************************************************************************\n\nCUDA utility funcs\n\ncode referenced from"
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h",
"chars": 1190,
"preview": "/*****************************************************************************\n\nCUDA SyncBN code\n\n**********************"
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp",
"chars": 421,
"preview": "#include \"bn.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n m.def(\"syncbn_sum_sqsum\", &syncbn_sum_sqsum, \"Sum and Sum^"
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/modules/functional/syncbn.py",
"chars": 5291,
"preview": "\"\"\"\n/*****************************************************************************/\n\nBatchNorm2dSync with multi-gpu\n\ncod"
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/modules/nn/__init__.py",
"chars": 22,
"preview": "from .syncbn import *\n"
},
{
"path": "XMem/inference/interact/fbrs/model/syncbn/modules/nn/syncbn.py",
"chars": 5187,
"preview": "\"\"\"\n/*****************************************************************************/\n\nBatchNorm2dSync with multi-gpu\n\n/**"
},
{
"path": "XMem/inference/interact/fbrs/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/inference/interact/fbrs/utils/cython/__init__.py",
"chars": 74,
"preview": "# noinspection PyUnresolvedReferences\nfrom .dist_maps import get_dist_maps"
},
{
"path": "XMem/inference/interact/fbrs/utils/cython/_get_dist_maps.pyx",
"chars": 1939,
"preview": "import numpy as np\ncimport cython\ncimport numpy as np\nfrom libc.stdlib cimport malloc, free\n\nctypedef struct qnode:\n "
},
{
"path": "XMem/inference/interact/fbrs/utils/cython/_get_dist_maps.pyxbld",
"chars": 249,
"preview": "import numpy\n\ndef make_ext(modname, pyxfilename):\n from distutils.extension import Extension\n return Extension(mod"
},
{
"path": "XMem/inference/interact/fbrs/utils/cython/dist_maps.py",
"chars": 149,
"preview": "import pyximport; pyximport.install(pyximport=True, language_level=3)\n# noinspection PyUnresolvedReferences\nfrom ._get_d"
},
{
"path": "XMem/inference/interact/fbrs/utils/misc.py",
"chars": 1608,
"preview": "from functools import partial\n\nimport torch\nimport numpy as np\n\n\ndef get_dims_with_exclusion(dim, exclude=None):\n dim"
},
{
"path": "XMem/inference/interact/fbrs/utils/vis.py",
"chars": 4040,
"preview": "from functools import lru_cache\n\nimport cv2\nimport numpy as np\n\n\ndef visualize_instances(imask, bg_color=255,\n "
},
{
"path": "XMem/inference/interact/fbrs_controller.py",
"chars": 1739,
"preview": "import torch\nfrom .fbrs.controller import InteractiveController\nfrom .fbrs.inference import utils\n\n\nclass FBRSController"
},
{
"path": "XMem/inference/interact/gui.py",
"chars": 40465,
"preview": "\"\"\"\nBased on https://github.com/hkchengrex/MiVOS/tree/MiVOS-STCN \n(which is based on https://github.com/seoungwugoh/ivs-"
},
{
"path": "XMem/inference/interact/gui_utils.py",
"chars": 1168,
"preview": "from PySide6.QtCore import Qt\nfrom PySide6.QtWidgets import (QBoxLayout, QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QPr"
},
{
"path": "XMem/inference/interact/interaction.py",
"chars": 8812,
"preview": "\"\"\"\nContains all the types of interaction related to the GUI\nNot related to automatic evaluation in the DAVIS dataset\n\nY"
},
{
"path": "XMem/inference/interact/interactive_utils.py",
"chars": 6934,
"preview": "# Modifed from https://github.com/seoungwugoh/ivs-demo\n\nimport numpy as np\n\nimport torch\nimport torch.nn.functional as F"
},
{
"path": "XMem/inference/interact/resource_manager.py",
"chars": 7000,
"preview": "import os\nfrom os import path\nimport shutil\nimport collections\n\nimport cv2\nfrom PIL import Image\nif not hasattr(Image, '"
},
{
"path": "XMem/inference/interact/s2m/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/inference/interact/s2m/_deeplab.py",
"chars": 6767,
"preview": "# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functio"
},
{
"path": "XMem/inference/interact/s2m/s2m_network.py",
"chars": 2486,
"preview": "# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch\n\nfrom .utils import IntermediateLayerGetter\nfrom ._deeplab impo"
},
{
"path": "XMem/inference/interact/s2m/s2m_resnet.py",
"chars": 6871,
"preview": "import torch\nimport torch.nn as nn\ntry:\n from torchvision.models.utils import load_state_dict_from_url\nexcept ModuleN"
},
{
"path": "XMem/inference/interact/s2m/utils.py",
"chars": 3048,
"preview": "# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport t"
},
{
"path": "XMem/inference/interact/s2m_controller.py",
"chars": 1516,
"preview": "import torch\nimport numpy as np\nfrom ..interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M\n\nfrom util.tensor_u"
},
{
"path": "XMem/inference/interact/timer.py",
"chars": 776,
"preview": "import time\n\nclass Timer:\n def __init__(self):\n self._acc_time = 0\n self._paused = True\n\n def start("
},
{
"path": "XMem/inference/kv_memory_store.py",
"chars": 8307,
"preview": "import torch\nfrom typing import List\n\nclass KeyValueMemoryStore:\n \"\"\"\n Works for key/value pairs type storage\n "
},
{
"path": "XMem/inference/memory_manager.py",
"chars": 12244,
"preview": "import torch\nimport warnings\n\nfrom inference.kv_memory_store import KeyValueMemoryStore\nfrom model.memory_util import *\n"
},
{
"path": "XMem/interactive_demo.py",
"chars": 4932,
"preview": "\"\"\"\nA simple user interface for XMem\n\"\"\"\n\nimport os\nfrom os import path\n# fix for Windows\nif 'QT_QPA_PLATFORM_PLUGIN_PAT"
},
{
"path": "XMem/merge_multi_scale.py",
"chars": 4143,
"preview": "import os\nfrom os import path\nfrom argparse import ArgumentParser\nimport glob\nfrom collections import defaultdict\n\nimpor"
},
{
"path": "XMem/merge_results.py",
"chars": 1438,
"preview": "import glob\nimport os\nfrom PIL import Image\nimport numpy as np\nimport tqdm\nimport multiprocessing\n\nmulti_dir = \"mevis_va"
},
{
"path": "XMem/model/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/model/aggregate.py",
"chars": 412,
"preview": "import torch\nimport torch.nn.functional as F\n\n\n# Soft aggregation from STM\ndef aggregate(prob, dim, return_logits=False)"
},
{
"path": "XMem/model/cbam.py",
"chars": 3042,
"preview": "# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py\n\nimport torch\nimport torch.nn as"
},
{
"path": "XMem/model/group_modules.py",
"chars": 2560,
"preview": "\"\"\"\nGroup-specific modules\nThey handle features that also depends on the mask. \nFeatures are typically of shape\n batc"
},
{
"path": "XMem/model/losses.py",
"chars": 2333,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom collections import defaultdict\n\n\ndef dice_loss("
},
{
"path": "XMem/model/memory_util.py",
"chars": 2578,
"preview": "import math\nimport numpy as np\nimport torch\nfrom typing import Optional\n\n\ndef get_similarity(mk, ms, qk, qe):\n # used"
},
{
"path": "XMem/model/modules.py",
"chars": 8817,
"preview": "\"\"\"\nmodules.py - This file stores the rather boring network blocks.\n\nx - usually means features that only depends on the"
},
{
"path": "XMem/model/network.py",
"chars": 8115,
"preview": "\"\"\"\nThis file defines XMem, the highest level nn.Module interface\nDuring training, it is used by trainer.py\nDuring evalu"
},
{
"path": "XMem/model/resnet.py",
"chars": 5538,
"preview": "\"\"\"\nresnet.py - A modified ResNet structure\nWe append extra channels to the first conv by some network surgery\n\"\"\"\n\nfrom"
},
{
"path": "XMem/model/trainer.py",
"chars": 9772,
"preview": "\"\"\"\ntrainer.py - warpper and utility functions for network training\nCompute loss, back-prop, update parameters, logging,"
},
{
"path": "XMem/requirements.txt",
"chars": 43,
"preview": "progressbar2\ngdown\nhickle\ntensorboard\nnumpy"
},
{
"path": "XMem/scripts/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/scripts/download_bl30k.py",
"chars": 1608,
"preview": "import os\nimport gdown\nimport tarfile\n\n\nLICENSE = \"\"\"\nThis dataset is a derivative of ShapeNet.\nPlease read and respect "
},
{
"path": "XMem/scripts/download_datasets.py",
"chars": 6097,
"preview": "import os\nimport gdown\nimport zipfile\nfrom scripts import resize_youtube\n\n\nLICENSE = \"\"\"\nThese are either re-distributio"
},
{
"path": "XMem/scripts/download_models.sh",
"chars": 172,
"preview": "wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth\nwget -P ./saves/ https://github.com/"
},
{
"path": "XMem/scripts/download_models_demo.sh",
"chars": 250,
"preview": "wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth\nwget -P ./saves/ https://github.com/"
},
{
"path": "XMem/scripts/expand_long_vid.py",
"chars": 1238,
"preview": "import sys\nimport os\nfrom os import path\nfrom shutil import copy2\n\ninput_path = sys.argv[1]\noutput_path = sys.argv[2]\nmu"
},
{
"path": "XMem/scripts/resize_youtube.py",
"chars": 2289,
"preview": "import sys\nimport os\nfrom os import path\n\nfrom PIL import Image\nimport numpy as np\nfrom progressbar import progressbar\nf"
},
{
"path": "XMem/tracking.py",
"chars": 5006,
"preview": "import sys\nsys.path.insert(0, './XMem')\n\nimport os\nimport os.path as osp\nimport glob\nimport cv2\nimport json\nimport argpa"
},
{
"path": "XMem/train.py",
"chars": 9873,
"preview": "import datetime\nfrom os import path\nimport math\nimport git\n\nimport random\nimport numpy as np\nimport torch\nfrom torch.uti"
},
{
"path": "XMem/util/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "XMem/util/configuration.py",
"chars": 7043,
"preview": "from argparse import ArgumentParser\n\n\ndef none_or_default(x, default):\n return x if x is not None else default\n\nclass"
},
{
"path": "XMem/util/davis_subset.txt",
"chars": 592,
"preview": "bear\nbmx-bumps\nboat\nboxing-fisheye\nbreakdance-flare\nbus\ncar-turn\ncat-girl\nclassic-car\ncolor-run\ncrossing\ndance-jump\ndanc"
},
{
"path": "XMem/util/image_saver.py",
"chars": 4340,
"preview": "import cv2\nimport numpy as np\n\nimport torch\nfrom dataset.range_transform import inv_im_trans\nfrom collections import def"
},
{
"path": "XMem/util/load_subset.py",
"chars": 457,
"preview": "\"\"\"\nload_subset.py - Presents a subset of data\nDAVIS - only the training set\nYouTubeVOS - I manually filtered some erron"
},
{
"path": "XMem/util/log_integrator.py",
"chars": 2408,
"preview": "\"\"\"\nIntegrate numerical values for some iterations\nTypically used for loss computation / logging to tensorboard\nCall fin"
},
{
"path": "XMem/util/logger.py",
"chars": 2937,
"preview": "\"\"\"\nDumps things to tensorboard and console\n\"\"\"\n\nimport os\nimport warnings\n\nimport torchvision.transforms as transforms\n"
},
{
"path": "XMem/util/palette.py",
"chars": 2500,
"preview": "davis_palette = b'\\x00\\x00\\x00\\x80\\x00\\x00\\x00\\x80\\x00\\x80\\x80\\x00\\x00\\x00\\x80\\x80\\x00\\x80\\x00\\x80\\x80\\x80\\x80\\x80@\\x00\\"
},
{
"path": "XMem/util/tensor_util.py",
"chars": 1233,
"preview": "import torch.nn.functional as F\n\n\ndef compute_tensor_iu(seg, gt):\n intersection = (seg & gt).float().sum()\n union "
},
{
"path": "XMem/util/yv_subset.txt",
"chars": 38104,
"preview": "003234408d\n0043f083b5\n0044fa5fba\n005a527edd\n0065b171f9\n00917dcfc4\n00a23ccf53\n00ad5016a4\n01082ae388\n011ac0a06f\n013099c098"
},
{
"path": "merge_lora_weights_and_save_hf_model.py",
"chars": 5524,
"preview": "import argparse\nimport glob\nimport os\nimport sys\n\nimport cv2\nimport numpy as np\nimport torch\nimport torch.nn.functional "
},
{
"path": "model/VISA.py",
"chars": 12243,
"preview": "from typing import List\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom transformers import Bit"
},
{
"path": "model/llava/__init__.py",
"chars": 41,
"preview": "from .model import LlavaLlamaForCausalLM\n"
},
{
"path": "model/llava/constants.py",
"chars": 293,
"preview": "CONTROLLER_HEART_BEAT_EXPIRATION = 30\nWORKER_HEART_BEAT_INTERVAL = 15\n\nLOGDIR = \".\"\n\n# Model Constants\nIGNORE_INDEX = -1"
},
{
"path": "model/llava/conversation.py",
"chars": 15685,
"preview": "import dataclasses\nfrom enum import Enum, auto\nfrom typing import List, Tuple\n\n\nclass SeparatorStyle(Enum):\n \"\"\"Diffe"
},
{
"path": "model/llava/mm_utils.py",
"chars": 2915,
"preview": "import base64\nfrom io import BytesIO\n\nimport torch\nfrom PIL import Image\nfrom transformers import StoppingCriteria\n\nfrom"
},
{
"path": "model/llava/model/__init__.py",
"chars": 149,
"preview": "from .language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM\nfrom .language_model.llava_mpt import LlavaMP"
},
{
"path": "model/llava/model/apply_delta.py",
"chars": 2053,
"preview": "\"\"\"\nUsage:\npython3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --de"
},
{
"path": "model/llava/model/builder.py",
"chars": 8680,
"preview": "# Copyright 2023 Haotian Liu\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not "
},
{
"path": "model/llava/model/consolidate.py",
"chars": 928,
"preview": "\"\"\"\nUsage:\npython3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate\n"
},
{
"path": "model/llava/model/language_model/llava_llama.py",
"chars": 5652,
"preview": "# Copyright 2023 Haotian Liu\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not "
},
{
"path": "model/llava/model/language_model/llava_mpt.py",
"chars": 6593,
"preview": "# Copyright 2023 Haotian Liu\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not "
},
{
"path": "model/llava/model/language_model/mpt/adapt_tokenizer.py",
"chars": 1785,
"preview": "from typing import Union\n\nfrom transformers import (AutoTokenizer, PreTrainedTokenizer,\n PreTra"
},
{
"path": "model/llava/model/language_model/mpt/attention.py",
"chars": 19787,
"preview": "\"\"\"Attention layers.\"\"\"\nimport math\nimport warnings\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nfrom"
},
{
"path": "model/llava/model/language_model/mpt/blocks.py",
"chars": 3072,
"preview": "\"\"\"GPT Blocks used for the GPT Model.\"\"\"\nfrom typing import Dict, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\n\nf"
},
{
"path": "model/llava/model/language_model/mpt/configuration_mpt.py",
"chars": 10004,
"preview": "\"\"\"A HuggingFace-style model configuration.\"\"\"\nfrom typing import Dict, Optional, Union\n\nfrom transformers import Pretra"
},
{
"path": "model/llava/model/language_model/mpt/custom_embedding.py",
"chars": 308,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\n\n\nclass SharedEmbedding(nn.E"
},
{
"path": "model/llava/model/language_model/mpt/flash_attn_triton.py",
"chars": 34014,
"preview": "\"\"\"\nCopied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn"
},
{
"path": "model/llava/model/language_model/mpt/hf_prefixlm_converter.py",
"chars": 31405,
"preview": "\"\"\"Converts Huggingface Causal LM to Prefix LM.\n\nConversion does lightweight surgery on a HuggingFace\nCausal LM to conve"
},
{
"path": "model/llava/model/language_model/mpt/meta_init_context.py",
"chars": 3811,
"preview": "from contextlib import contextmanager\n\nimport torch\nimport torch.nn as nn\n\n\n@contextmanager\ndef init_empty_weights(inclu"
},
{
"path": "model/llava/model/language_model/mpt/modeling_mpt.py",
"chars": 22633,
"preview": "\"\"\"A simple, flexible implementation of a GPT model.\n\nInspired by https://github.com/karpathy/minGPT/blob/master/mingpt/"
},
{
"path": "model/llava/model/language_model/mpt/norm.py",
"chars": 3057,
"preview": "import torch\n\n\ndef _cast_if_autocast_enabled(tensor):\n if torch.is_autocast_enabled():\n if tensor.device.type "
},
{
"path": "model/llava/model/language_model/mpt/param_init_fns.py",
"chars": 14271,
"preview": "import math\nimport warnings\nfrom collections.abc import Sequence\nfrom functools import partial\nfrom typing import Option"
},
{
"path": "model/llava/model/llava_arch.py",
"chars": 18088,
"preview": "# Copyright 2023 Haotian Liu\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not "
},
{
"path": "model/llava/model/make_delta.py",
"chars": 2386,
"preview": "\"\"\"\nUsage:\npython3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~"
},
{
"path": "model/llava/model/multimodal_encoder/builder.py",
"chars": 517,
"preview": "from .clip_encoder import CLIPVisionTower\n\n\ndef build_vision_tower(vision_tower_cfg, **kwargs):\n vision_tower = getat"
},
{
"path": "model/llava/model/multimodal_encoder/clip_encoder.py",
"chars": 2861,
"preview": "import torch\nimport torch.nn as nn\nfrom transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel\n\n\nclas"
},
{
"path": "model/llava/model/utils.py",
"chars": 971,
"preview": "from transformers import AutoConfig\n\n\ndef auto_upgrade(config):\n cfg = AutoConfig.from_pretrained(config)\n if \"lla"
},
{
"path": "model/llava/train/llama_flash_attn_monkey_patch.py",
"chars": 4581,
"preview": "import logging\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport transformers\nfrom einops import rearrange\nf"
},
{
"path": "model/llava/train/llava_trainer.py",
"chars": 2314,
"preview": "import os\nfrom typing import Optional\n\nimport torch\nfrom transformers import Trainer\n\n\ndef maybe_zero_3(param, ignore_st"
},
{
"path": "model/llava/train/train.py",
"chars": 37120,
"preview": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_al"
},
{
"path": "model/llava/train/train_mem.py",
"chars": 504,
"preview": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_al"
},
{
"path": "model/llava/utils.py",
"chars": 4019,
"preview": "import datetime\nimport logging\nimport logging.handlers\nimport os\nimport sys\n\nimport requests\nfrom llava.constants import"
},
{
"path": "model/segment_anything/__init__.py",
"chars": 428,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/automatic_mask_generator.py",
"chars": 15372,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/build_sam.py",
"chars": 2980,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/modeling/__init__.py",
"chars": 385,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/modeling/common.py",
"chars": 1479,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/modeling/image_encoder.py",
"chars": 14983,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/modeling/mask_decoder.py",
"chars": 8805,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/modeling/prompt_encoder.py",
"chars": 9229,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/modeling/sam.py",
"chars": 7364,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/modeling/transformer.py",
"chars": 8421,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/predictor.py",
"chars": 11850,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/utils/__init__.py",
"chars": 197,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/utils/amg.py",
"chars": 12712,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/utils/onnx.py",
"chars": 5946,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/segment_anything/utils/transforms.py",
"chars": 4103,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "model/tf/modeling_outputs.py",
"chars": 2635,
"preview": "import torch\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Optional, Tuple, Dict, List\nfrom trans"
},
{
"path": "model/univi/__init__.py",
"chars": 45,
"preview": "from .model import ChatUniViLlamaForCausalLM\n"
},
{
"path": "model/univi/config/__init__.py",
"chars": 295,
"preview": "from .dataset_config import *\nfrom .model_config import *\n\n\nModelConfig = {\n \"PRETUNE\": model_config_pretune,\n \"FI"
},
{
"path": "model/univi/config/dataset_config.py",
"chars": 922,
"preview": "Pretrain = {\n \"chat_path\": \"${PATH}/CC3M-595K/chat.json\",\n \"CC3M\": \"${PATH}/CC3M-595K\",\n}\n\nVIT = {\n \"chat_path\""
},
{
"path": "model/univi/config/model_config.py",
"chars": 536,
"preview": "model_config_pretune = {\n \"use_cluster\": True,\n \"freeze\": False,\n \"vision_tune\": False,\n\n \"spatial_cluster_r"
},
{
"path": "model/univi/constants.py",
"chars": 540,
"preview": "CONTROLLER_HEART_BEAT_EXPIRATION = 30\nWORKER_HEART_BEAT_INTERVAL = 15\n\nLOGDIR = \".\"\n\n# Model Constants\nMAX_IMAGE_LENGTH "
},
{
"path": "model/univi/conversation.py",
"chars": 10524,
"preview": "import dataclasses\nfrom enum import auto, Enum\nfrom typing import List, Tuple\n\n\nclass SeparatorStyle(Enum):\n \"\"\"Diffe"
},
{
"path": "model/univi/demo.py",
"chars": 4894,
"preview": "import torch\nfrom .constants import *\nfrom .conversation import conv_templates, SeparatorStyle\nfrom .model.builder impor"
},
{
"path": "model/univi/eval/evaluate/evaluate_benchmark_1_correctness.py",
"chars": 7945,
"preview": "import openai\nimport os\nimport argparse\nimport json\nimport jsonlines\nimport ast\nfrom multiprocessing.pool import Pool\n\n\n"
},
{
"path": "model/univi/eval/evaluate/evaluate_benchmark_2_detailed_orientation.py",
"chars": 8151,
"preview": "import openai\nimport os\nimport argparse\nimport json\nimport jsonlines\nimport ast\nfrom multiprocessing.pool import Pool\n\n\n"
},
{
"path": "model/univi/eval/evaluate/evaluate_benchmark_3_context.py",
"chars": 8086,
"preview": "import openai\nimport os\nimport argparse\nimport json\nimport jsonlines\nimport ast\nfrom multiprocessing.pool import Pool\n\n\n"
},
{
"path": "model/univi/eval/evaluate/evaluate_benchmark_4_temporal.py",
"chars": 8008,
"preview": "import openai\nimport os\nimport argparse\nimport json\nimport jsonlines\nimport ast\nfrom multiprocessing.pool import Pool\n\n\n"
},
{
"path": "model/univi/eval/evaluate/evaluate_benchmark_5_consistency.py",
"chars": 8692,
"preview": "import openai\nimport os\nimport argparse\nimport json\nimport jsonlines\nimport ast\nfrom multiprocessing.pool import Pool\n\n\n"
},
{
"path": "model/univi/eval/evaluate/evaluate_gpt_review_visual.py",
"chars": 4206,
"preview": "import argparse\nimport json\nimport os\nimport requests\nimport openai\nimport time\n\nNUM_SECONDS_TO_SLEEP = 0.5\n\n\ndef get_ev"
},
{
"path": "model/univi/eval/evaluate/evaluate_science_qa.py",
"chars": 5534,
"preview": "import argparse\nimport json\nimport os\nimport re\nimport random\nimport numpy as np\n\n\ndef get_args():\n parser = argparse"
},
{
"path": "model/univi/eval/evaluate/evaluate_video_qa.py",
"chars": 8125,
"preview": "import openai\nimport os\nimport argparse\nimport json\nimport jsonlines\nimport ast\nfrom multiprocessing.pool import Pool\n\n\n"
},
{
"path": "model/univi/eval/evaluate/summarize_gpt_review.py",
"chars": 2573,
"preview": "import json\nimport os\nfrom collections import defaultdict\nimport numpy as np\nimport argparse\n\n\ndef parse_args():\n par"
},
{
"path": "model/univi/eval/model_coco_vqa.py",
"chars": 9003,
"preview": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\nfrom ChatUniVi.constants impor"
},
{
"path": "model/univi/eval/model_video_consistency.py",
"chars": 10084,
"preview": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\nfrom ChatUniVi.constants impor"
},
{
"path": "model/univi/eval/model_video_general.py",
"chars": 8464,
"preview": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\nfrom ChatUniVi.constants impor"
},
{
"path": "model/univi/eval/model_video_qa.py",
"chars": 9272,
"preview": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\nfrom ChatUniVi.constants impor"
},
{
"path": "model/univi/eval/model_vqa.py",
"chars": 5364,
"preview": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\nfrom ChatUniVi.constants impor"
},
{
"path": "model/univi/eval/model_vqa_scienceqa.py",
"chars": 6798,
"preview": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\n\nfrom ChatUniVi.constants impo"
},
{
"path": "model/univi/eval/questions/coco2014_val_qa_eval/qa90_gpt4_answer.jsonl",
"chars": 41305,
"preview": "{\"question_id\": 0, \"text\": \"The colors of the two suitcases in the image are black and brown with yellow details.\", \"cat"
},
{
"path": "model/univi/eval/questions/coco2014_val_qa_eval/qa90_questions.jsonl",
"chars": 13412,
"preview": "{\"question_id\": 0, \"image\": \"COCO_val2014_000000441147.jpg\", \"text\": \"What is the color of the two suitcases in the imag"
},
{
"path": "model/univi/eval/questions/coco_pope/coco_pope_adversarial.jsonl",
"chars": 367459,
"preview": "{\"question_id\": 1, \"image\": \"COCO_val2014_000000310196.jpg\", \"text\": \"Is there a snowboard in the image?\", \"label\": \"yes"
},
{
"path": "model/univi/eval/questions/coco_pope/coco_pope_popular.jsonl",
"chars": 367234,
"preview": "{\"question_id\": 1, \"image\": \"COCO_val2014_000000310196.jpg\", \"text\": \"Is there a snowboard in the image?\", \"label\": \"yes"
},
{
"path": "model/univi/eval/questions/coco_pope/coco_pope_random.jsonl",
"chars": 357302,
"preview": "{\"question_id\": 1, \"image\": \"COCO_val2014_000000310196.jpg\", \"text\": \"Is there a snowboard in the image?\", \"label\": \"yes"
},
{
"path": "model/univi/eval/questions/scienceqa/pid_splits.json",
"chars": 502927,
"preview": "{\n \"train\": [\n \"1\",\n \"2\",\n \"3\",\n \"9\",\n \"10\",\n \"12\",\n \"17\",\n \"19\",\n \"20\",\n \"21\",\n \"24\","
},
{
"path": "model/univi/eval/questions/scienceqa/test_QCM-LEA.json",
"chars": 5690207,
"preview": "[\n {\n \"id\": \"4\",\n \"conversations\": [\n {\n \"from\": \"human\",\n \"value\": \"Which figure of speech is"
},
{
"path": "model/univi/eval/questions/video_qa/activitynet_a_list.json",
"chars": 36960,
"preview": "[\n \"no\",\n \"yes\",\n \"day\",\n \"outdoor\",\n \"good looking\",\n \"bit dangerous\",\n \"secondary\",\n \"simple\",\n \"much simpler"
}
]
// ... and 76 more files (preview truncated)