Repository: bio-mlhui/LGRNet
Branch: main
Commit: e97ab7391f2b
Files: 117
Total size: 13.2 MB
Directory structure:
gitextract__wc8fr5u/
├── .gitignore
├── README.md
├── assets/
│   ├── DATA.md
│   ├── INSTALL.md
│   └── MODEL_ZOO.md
├── data_schedule/
│   ├── __init__.py
│   ├── registry.py
│   ├── utils/
│   │   ├── box_ops.py
│   │   ├── sampler.py
│   │   └── segmentation.py
│   └── vis/
│       ├── __init__.py
│       ├── apis.py
│       ├── evaluator_fast.py
│       ├── evaluator_utils.py
│       ├── fibroid/
│       │   ├── __init__.py
│       │   ├── evals.py
│       │   ├── fibroid_dataset.py
│       │   ├── fibroid_utils.py
│       │   └── metrics.py
│       ├── mapper.py
│       ├── mapper_utils.py
│       ├── polyp/
│       │   ├── __init__.py
│       │   ├── evals.py
│       │   ├── polyp_dataset.py
│       │   └── polyp_utils.py
│       ├── vis_aug_eval.py
│       ├── vis_aug_train.py
│       ├── vis_aug_utils.py
│       └── vis_frame_sampler.py
├── handle_vps.py
├── main.py
├── models/
│   ├── VIS/
│   │   ├── BackboneEncoderDecoder_WithScaleConsistency.py
│   │   ├── __init__.py
│   │   └── aux_mapper.py
│   ├── __init__.py
│   ├── backbone/
│   │   ├── __init__.py
│   │   ├── pvtv2.py
│   │   ├── res2net.py
│   │   └── utils.py
│   ├── decoder/
│   │   ├── __init__.py
│   │   └── mask2former_video.py
│   ├── encoder/
│   │   ├── __init__.py
│   │   ├── input_projs.py
│   │   ├── localGlobal.py
│   │   ├── neighborhood_qk.py
│   │   └── ops/
│   │       ├── MultiScaleDeformableAttention.egg-info/
│   │       │   └── PKG-INFO
│   │       ├── attention.py
│   │       ├── build/
│   │       │   ├── lib.linux-x86_64-cpython-311/
│   │       │   │   ├── functions/
│   │       │   │   │   ├── __init__.py
│   │       │   │   │   └── ms_deform_attn_func.py
│   │       │   │   └── modules/
│   │       │   │       ├── __init__.py
│   │       │   │       └── ms_deform_attn.py
│   │       │   ├── lib.linux-x86_64-cpython-38/
│   │       │   │   ├── functions/
│   │       │   │   │   ├── __init__.py
│   │       │   │   │   └── ms_deform_attn_func.py
│   │       │   │   └── modules/
│   │       │   │       ├── __init__.py
│   │       │   │       └── ms_deform_attn.py
│   │       │   ├── temp.linux-x86_64-cpython-311/
│   │       │   │   ├── .ninja_deps
│   │       │   │   ├── .ninja_log
│   │       │   │   ├── build.ninja
│   │       │   │   └── home/
│   │       │   │       └── xhh/
│   │       │   │           └── workspace/
│   │       │   │               └── rvos_encoder/
│   │       │   │                   └── models/
│   │       │   │                       └── ops/
│   │       │   │                           └── src/
│   │       │   │                               ├── cpu/
│   │       │   │                               │   └── ms_deform_attn_cpu.o
│   │       │   │                               ├── cuda/
│   │       │   │                               │   └── ms_deform_attn_cuda.o
│   │       │   │                               └── vision.o
│   │       │   └── temp.linux-x86_64-cpython-38/
│   │       │       └── home/
│   │       │           └── xhh/
│   │       │               └── workspace/
│   │       │                   └── ReferFormer/
│   │       │                       └── models/
│   │       │                           └── ops/
│   │       │                               └── src/
│   │       │                                   ├── cpu/
│   │       │                                   │   └── ms_deform_attn_cpu.o
│   │       │                                   ├── cuda/
│   │       │                                   │   └── ms_deform_attn_cuda.o
│   │       │                                   └── vision.o
│   │       ├── dist/
│   │       │   ├── MultiScaleDeformableAttention-1.0-py3.11-linux-x86_64.egg
│   │       │   └── MultiScaleDeformableAttention-1.0-py3.8-linux-x86_64.egg
│   │       ├── functions/
│   │       │   ├── __init__.py
│   │       │   └── ms_deform_attn_func.py
│   │       ├── make.sh
│   │       ├── modules/
│   │       │   ├── __init__.py
│   │       │   ├── frame_query_ss2d.py
│   │       │   └── ms_deform_attn.py
│   │       ├── setup.py
│   │       ├── src/
│   │       │   ├── cpu/
│   │       │   │   ├── ms_deform_attn_cpu.cpp
│   │       │   │   └── ms_deform_attn_cpu.h
│   │       │   ├── cuda/
│   │       │   │   ├── ms_deform_attn_cuda.cu
│   │       │   │   ├── ms_deform_attn_cuda.h
│   │       │   │   └── ms_deform_im2col_cuda.cuh
│   │       │   ├── ms_deform_attn.h
│   │       │   └── vision.cpp
│   │       └── test.py
│   ├── layers/
│   │   ├── anyc_trans.py
│   │   ├── decoder_layers.py
│   │   ├── gilbert/
│   │   │   ├── demo/
│   │   │   │   ├── index.html
│   │   │   │   ├── normalize.css
│   │   │   │   ├── script.js
│   │   │   │   ├── skeleton.css
│   │   │   │   └── two.js
│   │   │   ├── gilbert2d.py
│   │   │   ├── gilbert3d.py
│   │   │   ├── gilbert_d2xy.py
│   │   │   ├── gilbert_d2xyz.py
│   │   │   ├── gilbert_xy2d.py
│   │   │   ├── gilbert_xyz2d.py
│   │   │   ├── plotpath.m
│   │   │   ├── ports/
│   │   │   │   ├── Makefile
│   │   │   │   ├── gilbert.c
│   │   │   │   └── gilbert.js
│   │   │   ├── test.py
│   │   │   └── tests/
│   │   │       └── runtests.sh
│   │   ├── matching.py
│   │   ├── position_encoding.py
│   │   └── utils.py
│   ├── modality_input_mappers/
│   │   ├── __init__.py
│   │   └── hilbert_curve.py
│   ├── optimization/
│   │   ├── optimizer.py
│   │   └── scheduler.py
│   └── registry.py
├── output/
│   └── VIS/
│       ├── cvc/
│       │   └── pvt.py
│       ├── fibroid/
│       │   └── pvt.py
│       └── sunseg/
│           ├── pvt/
│           │   └── pvt.py
│           └── res/
│               └── res.py
├── reorganize_sunseg.py
├── trainers/
│   ├── Trainer.py
│   └── __init__.py
└── utils/
    ├── __init__.py
    └── misc.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
**/__pycache__/**
**/wandb/**
**/.vscode/**
*.out
*.err
*.zip
*.tar
*.pth
stdout_train.txt
stdout_eval.txt
stdout_visualize.txt
hostfile
**/olds/**
**/old/**
*.pt
*.txt
miccai_generate_Dir
generate.py
visualzie.py
================================================
FILE: README.md
================================================
## LGRNet: Local-Global Reciprocal Network for Video Polyp Segmentation [`Paper`](https://arxiv.org/abs/2407.05703) | [`BibTeX`](#citing) | [`Huggingface(UFUV Dataset)`](https://huggingface.co/datasets/huihuixu/uterine_fibroid_ultrasound_video_segmentation)
Huihui Xu, Yijun Yang, Angelica Aviles-Rivero, Guang Yang, Jing Qin, and Lei Zhu
This is the official implementation of LGRNet (MICCAI'24 Early Accept). LGRNet combines local **[Cyclic Neighborhood Propagation](https://github.com/bio-mlhui/LGRNet/blob/main/models/encoder/neighborhood_qk.py#L57)** with global **[Hilbert Selective Scan](https://github.com/bio-mlhui/LGRNet/blob/main/models/encoder/ops/modules/frame_query_ss2d.py#L531)**. Together with **[Frame Bottleneck Queries](https://github.com/bio-mlhui/LGRNet/blob/main/models/encoder/localGlobal.py#L185)**, these components let LGRNet aggregate local and global temporal context both efficiently and effectively. LGRNet achieves *state-of-the-art* results on the public [Video Polyp Segmentation (VPS)](https://paperswithcode.com/task/video-polyp-segmentation) benchmark.
<div align="justify">Ultrasound video is a good example: a single frame is too noisy to support an accurate lesion diagnosis. In practice, doctors check neighboring frames (local) and collect all visual clues in the video (global). They use this context to locate possible lesion regions and to filter out irrelevant surrounding tissue. </div>
</br>
<div align="center" style="padding: 0 100pt">
<img src="assets/images/pipeline.png">
</div>
</br>
<div align="justify"> In CNP, each token uses the neighboring tokens (defined by a kernel) in the cyclic frame as attention keys, so local (cyclic) temporal information is aggregated into that token. In Hilbert Selective Scan, a set of frame bottleneck queries first aggregates spatial information from each frame. Hilbert Selective Scan then parses the global temporal context from these bottleneck queries efficiently. A Distribute layer propagates this global temporal context back to the feature maps. The decoder is built on Mask2Former and outputs a set of mask predictions, each with a confidence score, which also supports a more comprehensive diagnosis.</div>
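A selective scan consumes tokens as a 1D sequence, so a 2D token grid must be flattened in some order. A Hilbert curve keeps consecutive tokens spatially adjacent. A row-major raster scan, by contrast, jumps from the end of one row to the start of the next.

The sketch below illustrates only that locality property. It makes several assumptions:

- It uses the classic power-of-two Hilbert `d2xy` conversion, not the generalized [Gilbert](https://github.com/jakubcerveny/gilbert) curve that the repository relies on for arbitrary grid sizes.
- The function names are hypothetical.
- It does not reproduce how LGRNet actually orders its bottleneck queries.

```python
def hilbert_d2xy(n, d):
    """Map position d along a Hilbert curve to (x, y) on an n x n grid (n a power of two)."""
    x = y = 0
    s = 1
    t = d
    while s < n:
        rx = 1 & (t // 2)
        ry = 1 & (t ^ rx)
        # Rotate/flip the current sub-square so consecutive sub-curves join end to end.
        if ry == 0:
            if rx == 1:
                x, y = s - 1 - x, s - 1 - y
            x, y = y, x
        x += s * rx
        y += s * ry
        t //= 4
        s *= 2
    return x, y


def hilbert_order(n):
    """All cells of an n x n grid in Hilbert-curve order."""
    return [hilbert_d2xy(n, d) for d in range(n * n)]


def raster_order(n):
    """All cells of an n x n grid in row-major (raster) order, for comparison."""
    return [(x, y) for y in range(n) for x in range(n)]


def max_step(order):
    """Largest Manhattan distance between two consecutive cells of a scan order."""
    return max(abs(ax - bx) + abs(ay - by)
               for (ax, ay), (bx, by) in zip(order, order[1:]))


if __name__ == "__main__":
    print("hilbert max step:", max_step(hilbert_order(8)))  # 1
    print("raster  max step:", max_step(raster_order(8)))   # 8
```

Every Hilbert step moves to an adjacent cell. The raster scan's worst step grows with the grid width, which is the locality argument for preferring a space-filling curve when flattening tokens for a sequential scan.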
## Items
1. Installation: Please refer to [INSTALL.md](assets/INSTALL.md) for more details.
2. Data preparation: Please refer to [DATA.md](assets/DATA.md) for more details.
3. Training:
Set PORT_NUM to a free port for DDP, and make sure $CURRENT_TASK is set to 'VIS':
```
export CURRENT_TASK=VIS
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=PORT_NUM
```
Make sure $PT_PATH and $DATASET_PATH were set correctly during installation and data preparation.
Training on SUN-SEG was conducted on two RTX 4090 (24 GB) GPUs:
```
CUDA_VISIBLE_DEVICES=0,1 TORCH_NUM_WORKERS=8 python main.py --config_file output/VIS/sunseg/pvt/pvt.py --trainer_mode train_attmpt
```
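`MASTER_PORT` only needs to be a TCP port that is currently unused on the machine. One common way to pick such a port is to bind a socket to port 0 and let the operating system choose. This is a general technique, not something the repository provides.

```python
import socket


def find_free_port(host="127.0.0.1"):
    """Ask the OS for an unused TCP port by binding a socket to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))
        return s.getsockname()[1]


if __name__ == "__main__":
    print(find_free_port())
```

If saved as a script (the filename `find_free_port.py` is hypothetical), it can feed the export directly: `export MASTER_PORT=$(python find_free_port.py)`. The port is released when the socket closes, so another process could in principle claim it before training starts.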
4. Logs, checkpoints, and predictions:
| Backbone| Dataset | Dice | mIoU | log | ckpt | predictions |
| :----: | :----: | :----: | :----: | :----: | :----: |:----: |
| PVTv2-B2 | SUN-SEG-Train | -- | -- | [log](https://drive.google.com/file/d/17MTOYW73RLbvZS3BLFBZEphY_0JzN6er/view?usp=sharing) | [ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing) | --
| PVTv2-B2 | SUN-SEG-Hard-Testing | 0.876 | 0.805 | [log](https://drive.google.com/file/d/1wdVMWMknSlURaBROWbMax4iS9V1Tbn9-/view?usp=sharing) |[ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing) | [mask predictions](https://drive.google.com/file/d/1V8CDMC87o7t4eyts4BVEwflDUrFpAOVX/view?usp=sharing)
| PVTv2-B2 | SUN-SEG-Easy-Testing | 0.875 | 0.810 | [log](https://drive.google.com/file/d/1wdVMWMknSlURaBROWbMax4iS9V1Tbn9-/view?usp=sharing) |[ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing) | [mask predictions](https://drive.google.com/file/d/1V8CDMC87o7t4eyts4BVEwflDUrFpAOVX/view?usp=sharing)
| PVTv2-B2 | SUN-SEG-Hard-Unseen-Testing | 0.865 | 0.792 | [log](https://drive.google.com/file/d/1obt_qvWCvslhRY-e4SrTJNS0r6Diad4e/view?usp=sharing) | [ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing) | [mask predictions](https://drive.google.com/file/d/1V8CDMC87o7t4eyts4BVEwflDUrFpAOVX/view?usp=sharing)
| PVTv2-B2 | SUN-SEG-Easy-Unseen-Testing | 0.853 | 0.783 | [log](https://drive.google.com/file/d/1obt_qvWCvslhRY-e4SrTJNS0r6Diad4e/view?usp=sharing) | [ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing)| [mask predictions](https://drive.google.com/file/d/1V8CDMC87o7t4eyts4BVEwflDUrFpAOVX/view?usp=sharing)
| Res2Net-50 | SUN-SEG-Hard-Testing | 0.841 | 0.765 | [log](https://drive.google.com/file/d/17pUxFMuHpPD_In5RVrJUsPFZGOgNFzb6/view?usp=sharing) |
| Res2Net-50 | SUN-SEG-Easy-Testing | 0.843 | 0.774 | [log](https://drive.google.com/file/d/17pUxFMuHpPD_In5RVrJUsPFZGOgNFzb6/view?usp=sharing) |
| PVTv2-B2 | CVC612V | 0.933 | 0.877 | [log](https://drive.google.com/file/d/1m36mJL0Fu3T9F73TqFGnFsWaCGh3JDeJ/view?usp=drive_link) |
| PVTv2-B2 | CVC300TV | 0.916 | 0.852 | [log](https://drive.google.com/file/d/1m36mJL0Fu3T9F73TqFGnFsWaCGh3JDeJ/view?usp=drive_link) |
| PVTv2-B2 | CVC612T | 0.875 | 0.814 | [log](https://drive.google.com/file/d/1m36mJL0Fu3T9F73TqFGnFsWaCGh3JDeJ/view?usp=drive_link) |
5. Evaluate:
Evaluate on SUN-SEG-Easy and SUN-SEG-Hard with one RTX 4090 (24 GB) GPU (**replace ckpt_path with the absolute path to the checkpoint**):
```
CUDA_VISIBLE_DEVICES=0 TORCH_NUM_WORKERS=8 python main.py --config_file output/VIS/sunseg/pvt/pvt.py --trainer_mode eval --eval_path ckpt_path
```
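The Dice and mIoU columns in the results table are overlap scores between predicted and ground-truth binary masks. Below is a minimal pure-Python sketch of the standard per-mask definitions. It is meant only to illustrate the formulas, under these assumptions:

- The repository's evaluators (for example `data_schedule/vis/polyp/evals.py`) may threshold, aggregate per frame or per video, and average differently.
- Scoring two empty masks as a perfect match is a convention chosen for this sketch.
- `mean_iou` is a hypothetical helper that simply averages over the pairs it is given.

```python
def overlap_counts(pred, gt):
    """Return (|pred AND gt|, |pred|, |gt|) for two same-shaped binary masks given as lists of rows."""
    if len(pred) != len(gt) or any(len(p) != len(g) for p, g in zip(pred, gt)):
        raise ValueError("masks must have the same shape")
    inter = p_sum = g_sum = 0
    for prow, grow in zip(pred, gt):
        for p, g in zip(prow, grow):
            p, g = bool(p), bool(g)
            inter += int(p and g)
            p_sum += int(p)
            g_sum += int(g)
    return inter, p_sum, g_sum


def dice(pred, gt):
    """Dice = 2|A n B| / (|A| + |B|)."""
    inter, p_sum, g_sum = overlap_counts(pred, gt)
    if p_sum + g_sum == 0:
        return 1.0  # both masks empty: treated as a perfect match (assumption)
    return 2 * inter / (p_sum + g_sum)


def iou(pred, gt):
    """IoU = |A n B| / |A u B|."""
    inter, p_sum, g_sum = overlap_counts(pred, gt)
    union = p_sum + g_sum - inter
    if union == 0:
        return 1.0  # both masks empty: treated as a perfect match (assumption)
    return inter / union


def mean_iou(pairs):
    """Average IoU over (pred, gt) pairs, e.g. one pair per frame."""
    scores = [iou(p, g) for p, g in pairs]
    return sum(scores) / len(scores)
```

For example, with `pred = [[1, 1], [0, 0]]` and `gt = [[1, 0], [0, 0]]`, the intersection is 1 pixel, so Dice is 2/3 and IoU is 1/2.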
## citing
```
@article{xu2024lgrnet,
title={LGRNet: Local-Global Reciprocal Network for Uterine Fibroid Segmentation in Ultrasound Videos},
author={Xu, Huihui and Yang, Yijun and Aviles-Rivero, Angelica I and Yang, Guang and Qin, Jing and Zhu, Lei},
journal={arXiv preprint arXiv:2407.05703},
year={2024}
}
```
## Acknowledgments
- Thanks to [Gilbert](https://github.com/jakubcerveny/gilbert) for the implementation of Hilbert curve generation.
- Thanks to GPT-4 for helping me shape the idea of using a Hilbert filling curve versus a zigzag curve.
================================================
FILE: assets/DATA.md
================================================
# Data Preparation
## UFUV (Private):
Please email the second author to request the UFUV dataset; I do not have the authority to distribute it.
## VPS (Public)
### CVC/Kvasir/Mayo
We follow [PNS-Net](https://github.com/GewelsJI/PNS-Net) to obtain the CVC/Kvasir/Mayo datasets. The download link is the same as PNS-Net's: [link](https://drive.google.com/file/d/1TyaRy4c4nHFDa3o2bOl4dP5Z7wes7HV2/view?usp=sharing)
Put MICCAI-VPS-dataset.zip in $DATASET_PATH, then run the following script to normalize the directory structure:
```
cd $DATASET_PATH
unzip -qq MICCAI-VPS-dataset.zip
# cd back to the LGRNet directory
# normalize the VPS data structure
python handle_vps.py
```
The resulting structure should look like this:
```
${DATASET_PATH}
-- MICCAI-VPS-dataset
-- Kvasir-SEG
-- *
-- VPS-TestSet
-- CVC-ColonDB-300
-- *
-- CVC-ClinicDB-612-Valid
-- *
-- CVC-ClinicDB-612-Test
-- *
-- VPS-TrainSet
-- ASU-Mayo_Clinic
-- Train
-- *
-- CVC-ClinicDB-612
-- Train
-- *
-- CVC-ColonDB-300
-- Train
-- *
```
where `*` denotes the following structure:
```
-- Frame
    -- vid1
        -- img file
-- GT
    -- vid1
        -- mask file
```
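Each `*` directory therefore pairs `Frame/<video>/` with `GT/<video>/`. Below is a small sketch for walking that layout and matching frames to masks. The helper name `pair_frames_and_masks` is hypothetical. The sketch also assumes that a frame and its mask share the same file stem (for example `0001.jpg` and `0001.png`), which may not hold for every dataset.

```python
import sys
from pathlib import Path


def pair_frames_and_masks(split_dir):
    """Map each video id to a sorted list of (frame_path, mask_path) pairs.

    `split_dir` is one `*` directory containing `Frame/` and `GT/`.
    Frames without a mask of the same file stem are skipped.
    """
    split_dir = Path(split_dir)
    pairs = {}
    for video_dir in sorted(p for p in (split_dir / "Frame").iterdir() if p.is_dir()):
        gt_dir = split_dir / "GT" / video_dir.name
        masks = {m.stem: m for m in gt_dir.iterdir()} if gt_dir.is_dir() else {}
        pairs[video_dir.name] = [
            (frame, masks[frame.stem])
            for frame in sorted(video_dir.iterdir())
            if frame.stem in masks
        ]
    return pairs


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python check_layout.py $DATASET_PATH/MICCAI-VPS-dataset/VPS-TestSet/CVC-ColonDB-300
    for vid, items in pair_frames_and_masks(sys.argv[1]).items():
        print(f"{vid}: {len(items)} annotated frames")
```

The sort is lexicographic, so unpadded numeric names such as `2.jpg` and `10.jpg` would need a natural-sort key to come out in temporal order.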
### SUN-SEG
Follow https://github.com/GewelsJI/VPS/blob/main/docs/DATA_PREPARATION.md to request SUN-SEG from its authors by email.
Put part 1, part 2, and the annotation archive in $DATASET_PATH/SUN-SEG, then run:
```
# normalize the directory
unzip -qq $DATASET_PATH/SUN-SEG/sundatabase_positive_part1.zip -d $DATASET_PATH/SUN-SEG/SUN-Positive
unzip -qq $DATASET_PATH/SUN-SEG/sundatabase_positive_part2.zip -d $DATASET_PATH/SUN-SEG/SUN-Positive
tar -xf $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation.tar -C $DATASET_PATH/SUN-SEG/
rm -rf $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestEasyDataset/Unseen/Frame
find $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation -name "._.DS_Store" -type f -delete
python reorganize_sunseg.py
```
The resulting structure should look like this:
```
${DATASET_PATH}
    -- SUN-SEG
        -- SUN-SEG-Annotation
            -- TrainDataset
                -- *
            -- TestEasyDataset
                -- combine
                    -- *
            -- TestHardDataset
                -- combine
                    -- *
```
================================================
FILE: assets/INSTALL.md
================================================
# Install
## Requirements
We tested the code in the following environment:
- CUDA 12.1
- Python 3.10.13
- PyTorch 2.1.1
- Torchvision 0.16.1
- detectron2 0.6
- mamba_ssm 1.2.0.post1
- natten 0.15.1
- timm 0.9.12
## Install environment for LGRNet
```
conda create --name lgrnet python=3.10
conda activate lgrnet
# make sure CUDA 12.1 is installed and activated through your environment variables
# install torch
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121
# install detectron2 (building it may take a long time)
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
# install mamba, see https://github.com/state-spaces/mamba
cd ..
git clone https://github.com/state-spaces/mamba.git
cd mamba
pip install . --no-build-isolation
cd ../LGRNet
# install natten, see https://github.com/SHI-Labs/NATTEN/blob/main/docs/install.md
pip install natten==0.15.1+torch210cu121 -f https://shi-labs.com/natten/wheels/cu121/torch2.1.0/natten-0.15.1%2Btorch210cu121-cp310-cp310-linux_x86_64.whl
# misc
pip install albumentations==1.3.1
pip install Pygments
pip install imgaug
pip install timm==0.9.12
# compile deform attention
cd models/encoder/ops/
python setup.py build install --user
# download the Res2Net/PVTv2 checkpoints; our model uses the same backbones as WeakPolyp (https://github.com/weijun88/WeakPolyp)
# -P sets the target directory; /resolve/ (not /blob/) serves the raw file from Hugging Face
wget -P $PT_PATH/pvt_v2 https://huggingface.co/huihuixu/lgrnet_ckpts/resolve/main/pvt_v2_b2.pth
wget -P $PT_PATH/res2net https://huggingface.co/huihuixu/lgrnet_ckpts/resolve/main/res2net50_v1b_26w_4s-3cf99910.pth
```
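After installation, you can compare the environment against the versions listed under Requirements using `importlib.metadata` from the standard library. The sketch below makes two assumptions:

- The pip distribution names are the ones shown (for example `mamba_ssm` and `detectron2`).
- The compiled deformable-attention extension imports as `MultiScaleDeformableAttention`, matching the egg names under `models/encoder/ops/dist/`.

Adjust either if your install differs.

```python
from importlib import metadata, util

# Tested versions from the Requirements list; distribution names are assumptions.
TESTED = {
    "torch": "2.1.1",
    "torchvision": "0.16.1",
    "detectron2": "0.6",
    "mamba_ssm": "1.2.0.post1",
    "natten": "0.15.1",
    "timm": "0.9.12",
}


def check_versions(expected):
    """Return {distribution: status} comparing installed versions with the expected ones."""
    report = {}
    for name, want in expected.items():
        try:
            have = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        # Local build tags such as "+cu121" are ignored in the comparison.
        report[name] = "ok" if have.split("+")[0] == want else f"found {have}, tested {want}"
    return report


if __name__ == "__main__":
    for name, status in check_versions(TESTED).items():
        print(f"{name:12s} {status}")
    # Extension built by `python setup.py build install --user` in models/encoder/ops/
    found = util.find_spec("MultiScaleDeformableAttention") is not None
    print(f"{'MultiScaleDeformableAttention':12s} {'ok' if found else 'missing'}")
```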
================================================
FILE: assets/MODEL_ZOO.md
================================================
================================================
FILE: data_schedule/__init__.py
================================================
import os
if os.getenv('CURRENT_TASK') == 'VIS':
    from . import vis
else:
    raise ValueError(f"Unsupported CURRENT_TASK {os.getenv('CURRENT_TASK')!r}; set CURRENT_TASK=VIS")


def build_schedule(configs, model_input_mapper, model_input_collate_fn):
    import logging
    from functools import partial
    import detectron2.utils.comm as comm
    from torch.utils.data import DataLoader, ConcatDataset
    from .registry import MAPPER_REGISTRY, EVALUATOR_REGISTRY
    from detectron2.data import DatasetCatalog, DatasetFromList, MapDataset, MetadataCatalog
    from data_schedule.utils.sampler import Evaluate_ExactSampler_Distributed, Train_InfiniteSampler_Distributed

    datasets = {'train': [], 'evaluate': []}
    meta_idx_shift = 0
    for mode in ['train', 'evaluate']:
        for dataset_name in configs['data'][mode].keys():
            dataset_assume_mode = MetadataCatalog.get(dataset_name).get('mode')
            if dataset_assume_mode != mode:
                logging.warning(f'default mode of {dataset_name} is {dataset_assume_mode} not {mode}')
            dataset_dicts = DatasetFromList(DatasetCatalog.get(dataset_name), copy=True, serialize=True)
            mapper = MAPPER_REGISTRY.get(configs['data'][mode][dataset_name]['mapper']['name'])(mode=mode,
                                                                                               dataset_name=dataset_name,
                                                                                               configs=configs,
                                                                                               meta_idx_shift=meta_idx_shift if mode == 'train' else 0)
            meta_idx_shift += len(dataset_dicts)
            dataset = MapDataset(dataset_dicts, partial(composition, mappers=[mapper,
                                                                              partial(model_input_mapper, mode=mode)]))
            if mode == 'train':
                datasets[mode].append(dataset)
            else:
                datasets[mode].append((dataset_name, dataset))

    train_dataset = ConcatDataset(datasets['train'])
    logging.debug(f'Total number of training meta: {len(train_dataset)}')

    train_loader_splits = configs['optim']['splits']
    batch_sizes = configs['optim']['batch_sizes']
    splits = list(zip(train_loader_splits[:-1], train_loader_splits[1:]))
    assert len(splits) == (len(batch_sizes))
    inf_stream_fn = partial(infinite_indices,
                            seed=configs['stream_idx_seed'],
                            batch_sizes=configs['optim']['batch_sizes'],
                            splits=configs['optim']['splits'],
                            one_batch_two_epoch=configs['optim']['one_batch_two_epoch'],
                            dataset_length=len(train_dataset),
                            shuffle=True)
    train_samplers = []
    train_loaders = []
    for btch_size, (range_start, range_end) in zip(batch_sizes, splits):
        if range_end is not None:
            assert (range_end - range_start) % btch_size == 0, 'split length must be a multiple of the batch size'
        assert btch_size % comm.get_world_size() == 0, 'batch size must be divisible by the world size'
        each_process_batch_size = int(btch_size / comm.get_world_size())
        loader_sampler = Train_InfiniteSampler_Distributed(inf_stream_fn=inf_stream_fn,
                                                           start_idx=range_start,
                                                           end_idx=range_end,)
        train_samplers.append(loader_sampler)
        train_loaders.append(DataLoader(train_dataset,
                                        batch_size=each_process_batch_size,
                                        sampler=loader_sampler,
                                        collate_fn=partial(model_input_collate_fn, mode='train'),
                                        num_workers=int(os.getenv('TORCH_NUM_WORKERS')),
                                        pin_memory=True,
                                        persistent_workers=True))

    evaluators = []
    for eval_dataset_name, eval_dataset in datasets['evaluate']:
        logging.debug(f'Number of evaluate meta in {eval_dataset_name}: {len(eval_dataset)}')
        loader = DataLoader(eval_dataset,
                            batch_size=1,
                            sampler=Evaluate_ExactSampler_Distributed(eval_dataset),
                            collate_fn=partial(model_input_collate_fn, mode='evaluate'),
                            num_workers=int(os.getenv('TORCH_NUM_WORKERS')),
                            pin_memory=True,
                            persistent_workers=True)
        evaluator = EVALUATOR_REGISTRY.get(configs['data']['evaluate'][eval_dataset_name]['evaluator']['name'])(configs=configs,
                                                                                                           dataset_name=eval_dataset_name,
                                                                                                           data_loader=loader)
        evaluators.append((eval_dataset_name, evaluator))

    return train_samplers, train_loaders, partial(evaluate_call, evaluators=evaluators)


def composition(data_dict, mappers):
    # Apply the mappers in order; any mapper may drop the sample by returning None.
    for mapper in mappers:
        data_dict = mapper(data_dict)
        if data_dict is None:
            return None
    return data_dict


def evaluate_call(evaluators, model, output_dir):
    import detectron2.utils.comm as comm
    ret = {}
    for eval_dataset_name, evaluator in evaluators:
        metric_dict = evaluator(model=model, output_dir=output_dir)
        if comm.is_main_process():
            for key, value in metric_dict.items():
                assert f'{key}_{eval_dataset_name}' not in ret
                ret[f'{key}_{eval_dataset_name}'] = value
        comm.synchronize()
    return ret


def _infinite_indices(seed, dataset_length, shuffle=True,):
    import torch
    g = torch.Generator()
    g.manual_seed(seed)
    while True:
        if shuffle:
            yield from torch.randperm(dataset_length, generator=g).tolist()
        else:
            yield from torch.arange(dataset_length).tolist()


def infinite_indices(seed,
                     dataset_length,
                     batch_sizes,
                     splits,
                     one_batch_two_epoch='just_use',  # one of 'abandon', 'just_use', 'pad'
                     shuffle=True):
    import torch
    import math
    g = torch.Generator()
    g.manual_seed(seed)

    split_ranges = list(zip(splits[:-1], splits[1:]))
    assert len(split_ranges) == (len(batch_sizes))
    stream = _infinite_indices(seed, dataset_length=dataset_length, shuffle=shuffle)

    stream_throw_cnt = 0
    cnt = 0
    for (range_start, range_end), btch_size in zip(split_ranges, batch_sizes):
        assert cnt == range_start
        if range_end is None:
            range_end = math.inf
        while cnt < range_end:
            epoch_milestone = ((stream_throw_cnt // dataset_length) + 1) * dataset_length
            # This batch would straddle an epoch boundary.
            if (stream_throw_cnt < epoch_milestone) and (stream_throw_cnt + btch_size > epoch_milestone) and (one_batch_two_epoch != 'just_use'):
                if one_batch_two_epoch == 'abandon':
                    # Drop the tail of the current epoch and start the batch in the next epoch.
                    for _ in range(epoch_milestone - stream_throw_cnt):
                        abandon = next(stream)
                        stream_throw_cnt += 1
                elif one_batch_two_epoch == 'pad':
                    # Finish the batch with random indices instead of indices from the next epoch.
                    diff = stream_throw_cnt + btch_size - epoch_milestone
                    num_throw = btch_size - diff
                    rand_idxs = torch.randperm(dataset_length, generator=g)[:diff].tolist()
                    for _ in range(num_throw):
                        cnt += 1
                        stream_throw_cnt += 1
                        yield next(stream)
                    for idx in rand_idxs:
                        cnt += 1
                        yield idx
                else:
                    raise ValueError(f'unknown one_batch_two_epoch: {one_batch_two_epoch}')
            else:
                for _ in range(btch_size):
                    cnt += 1
                    stream_throw_cnt += 1
                    yield next(stream)
        assert cnt == range_end
================================================
FILE: data_schedule/registry.py
================================================
from detectron2.utils.registry import Registry

EVALUATOR_REGISTRY = Registry('EVALUATOR')
MAPPER_REGISTRY = Registry('MAPPER')


class Mapper:
    def __init__(self,
                 meta_idx_shift,
                 dataset_meta,) -> None:
        self.meta_idx_shift = meta_idx_shift
        self.visualized_meta_idxs = dataset_meta.get('visualize_meta_idxs')

    def _call(self, data_dict):
        pass

    def __call__(self, data_dict):
        meta_idx = data_dict['meta_idx']
        ret = self._call(data_dict)
        if ret is None:
            return None
        ret['meta_idx'] = meta_idx + self.meta_idx_shift
        if meta_idx in self.visualized_meta_idxs:
            ret['visualize'] = True
        else:
            ret['visualize'] = False
        return ret
================================================
FILE: data_schedule/utils/box_ops.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    assert ((x1 - x0) >= 0).all()
    assert ((y1 - y0) >= 0).all()
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/
    The boxes should be in [x0, y0, x1, y1] format
    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area


def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks
    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
    Returns a [N, 4] tensors, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float)
    x = torch.arange(0, w, dtype=torch.float)
    y, x = torch.meshgrid(y, x)

    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
================================================
FILE: data_schedule/utils/sampler.py
================================================
import math
import torch.distributed as dist
from typing import TypeVar, Optional, Iterator

T_co = TypeVar('T_co', covariant=True)
from torch.utils.data.distributed import DistributedSampler
import detectron2.utils.comm as comm
from utils.misc import all_gather
from torch.utils.data import Sampler
import torch
import logging
import itertools


class TrainRandomSampler_ByEpoch(Sampler[int]):
    def __init__(self,
                 data_source,
                 seed,
                 ) -> None:
        self.data_source = data_source
        self.num_samples = len(self.data_source)
        self.seed = seed
        self.epoch = None

    def __iter__(self):
        seed = self.seed + self.epoch
        print(f'generating a new index permutation for this epoch using seed {seed}')
        n = len(self.data_source)
        g = torch.Generator()
        g.manual_seed(seed)
        for _ in range(self.num_samples // n):
            yield from torch.randperm(n, generator=g).tolist()
        yield from torch.randperm(n, generator=g).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch


class Train_InfiniteSampler_Distributed(Sampler[T_co]):
    def __init__(self,
                 inf_stream_fn,
                 start_idx: int = 0,
                 end_idx=None,
                 ):
        self.rank = comm.get_rank()
        self.num_replicas = comm.get_world_size()
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.inf_stream_fn = inf_stream_fn

    def set_iter_first_sample_idx(self, idx):
        self.start_idx = idx

    def set_iter_last_sample_idx(self, idx):
        self.end_idx = idx

    def __iter__(self) -> Iterator[T_co]:
        logging.debug(f'Starting at index {self.start_idx} of the infinite stream')
        # Each rank takes every num_replicas-th index of the shared stream, offset by its rank.
        yield from itertools.islice(self.inf_stream_fn(), self.start_idx + self.rank, self.end_idx, self.num_replicas)


class Evaluate_ExactSampler_Distributed(Sampler[T_co]):
    def __init__(self, dataset) -> None:
        self.dataset = dataset
        self.rank = comm.get_rank()
        self.num_replicas = comm.get_world_size()
        indices = list(range(len(self.dataset)))
        self.indices = indices[self.rank:len(self.dataset):self.num_replicas]

    def __iter__(self):
        yield from self.indices

    def __len__(self):
        return len(self.indices)


class TrainRandomSampler_ByEpoch_Distributed(Sampler[T_co]):
    def __init__(self,
                 dataset, num_replicas,
                 rank,
                 seed: int = 0) -> None:
        if rank >= num_replicas or rank < 0:
            raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = None
        self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.seed = seed

    def __iter__(self) -> Iterator[T_co]:
        seed = self.seed + self.epoch
        logging.debug(f'generating a new index permutation for this epoch using seed {seed}')
        g = torch.Generator()
        g.manual_seed(seed)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()

        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        if padding_size <= len(indices):
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        self.epoch = None
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. It must be called before every pass:
        the shuffling seed is ``seed + epoch``, and the epoch is reset to None
        after each iteration.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch


class InferenceSampler(Sampler):
    """
    Produce indices for inference across all workers.
    Inference needs to run on the __exact__ set of samples,
    therefore when the total number of samples is not divisible by the number of workers,
    this sampler produces different number of samples on different workers.
    """

    def __init__(self, size: int):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
        """
        self._size = size
        assert size > 0
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[: rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
================================================
FILE: data_schedule/utils/segmentation.py
================================================
import torch


def bounding_box_from_mask(mask):
    if not mask.any():
        return torch.zeros([4]).float()
    rows = torch.any(mask, dim=1)  # h
    cols = torch.any(mask, dim=0)  # w
    row_indexs = torch.where(rows)[0]
    rmin, rmax = row_indexs.min(), row_indexs.max()

    col_indexs = torch.where(cols)[0]
    cmin, cmax = col_indexs.min(), col_indexs.max()
    return torch.tensor([cmin, rmin, cmax, rmax]).float()  # x1y1x2y2
================================================
FILE: data_schedule/vis/__init__.py
================================================
from . import polyp
from . import mapper
from . import evaluator_fast
from . import vis_aug_eval
from . import vis_aug_train
from . import vis_frame_sampler
================================================
FILE: data_schedule/vis/apis.py
================================================
class VIS_Dataset:
    """
    """


class VIS_Aug_CallbackAPI:
    """
    """


class VIS_Evaluator_OutAPI_EvalFn_API:
    """
    """


class VIS_TrainAPI_clipped_video:
    """
    """


class VIS_EvalAPI_clipped_video_request_ann:
    """
    """


class VIS_FrameSampler_InputOutput_API:
    """
    """


class GetFrames:
    """
    """
================================================
FILE: data_schedule/vis/evaluator_fast.py
================================================
import os
from tqdm import tqdm
from functools import partial
import torch
import detectron2.utils.comm as comm
from utils.misc import to_device
from detectron2.data import MetadataCatalog
from data_schedule.registry import EVALUATOR_REGISTRY
from .evaluator_utils import vis_metric_entrypoint
from data_schedule.vis.apis import VIS_EvalAPI_clipped_video_request_ann, VIS_Aug_CallbackAPI, VIS_Evaluator_OutAPI_EvalFn_API
from collections import defaultdict


@EVALUATOR_REGISTRY.register()
class VIS_Evaluator_FrameFast:
    def __init__(self,
                 dataset_name,
                 data_loader,
                 configs) -> None:
        self.dataset_name = dataset_name
        self.loader = data_loader
        frame_metrics = configs['data']['evaluate'][dataset_name]['evaluator']['frame_metrics']
        dataset_meta = MetadataCatalog.get(dataset_name)
        self.frame_metric_fns = []
        for metric_name, metric_config in frame_metrics:
            metric_fn = vis_metric_entrypoint(metric_name)
            metric_fn = partial(metric_fn, dataset_meta=dataset_meta, **metric_config)
            self.frame_metric_fns.append(metric_fn)

        metrics_aggregator = configs['data']['evaluate'][dataset_name]['evaluator']['metrics_aggregator']
        self.eval_meta_keys = dataset_meta.get('eval_meta_keys')  # { video_id: list[fnames] }
        self.metrics_aggregator = partial(vis_metric_entrypoint(metrics_aggregator[0]),
                                          dataset_meta=dataset_meta,
                                          eval_meta_keys=self.eval_meta_keys,
                                          **metrics_aggregator[1])

    def visualize_path(self, meta_idxs, visualize, evaluator_path):
        return [os.path.join(evaluator_path, f'meta_{meta_idx}') if vis else None for (meta_idx, vis) in zip(meta_idxs, visualize)]

    @torch.no_grad()
    def __call__(self, model, output_dir):
        evaluator_path = os.path.join(output_dir, f'eval_{self.dataset_name}')
        os.makedirs(evaluator_path, exist_ok=True)
        macs, params = None, None
        metrics_by_video_id_frame = defaultdict(dict)
        for batch_dict in tqdm(self.loader):
            VIS_EvalAPI_clipped_video_request_ann  # bare reference to the API spec class; no runtime effect
            eval_metas = batch_dict.pop('metas')
            request_anns = eval_metas['request_ann'][0]  # t, bool tensor
            frame_strs = eval_metas['frames'][0]  # t', list[str]
            video_id = eval_metas['video_id'][0]  # str
            assert request_anns.int().sum() == len(frame_strs)
            callback_fns = eval_metas['callback_fns'][0]  # list[fn]
            visualize_path = self.visualize_path(meta_idxs=batch_dict['meta_idxs'], visualize=batch_dict['visualize'],
                                                 evaluator_path=os.path.join(evaluator_path, 'visualize_model'))
            batch_dict['visualize_paths'] = visualize_path
            batch_dict = to_device(batch_dict, device=model.device)
            VIS_Aug_CallbackAPI  # bare reference to the API spec class; no runtime effect
            # if macs is None:
            #     from detectron2.utils.analysis import (
            #         FlopCountAnalysis,
            #     )
            #     flops = FlopCountAnalysis(model, batch_dict, inference_func=lambda model, *inputs: model.sample(*inputs))
            #     total_flops = flops.total()
            #     # counts = flops.by_operator()
            #     logging.debug(f'macs: {total_flops/ (10**9) / len(request_anns)}')
            model_outputs = model.sample(batch_dict)
            # Keep only the frames that have a requested annotation.
            predictions = {
                'video': model_outputs['video'][0],  # t 3 h w
                'pred_masks': [haosen for idx, haosen in enumerate(model_outputs['pred_masks'][0]) if request_anns[idx]],  # list[nt h w], t'
                'pred_class': [haosen for idx, haosen in enumerate(model_outputs['pred_class'][0]) if request_anns[idx]],  # list[nt c], t',
            }
            if 'pred_boxes' in model_outputs:
                predictions.update({'pred_boxes': [haosen for idx, haosen in enumerate(model_outputs['pred_boxes'][0]) if request_anns[idx]]})  # list[nt 4], t,
            # Undo evaluation-time augmentations via the dataset's callbacks.
            for cardib in callback_fns:
                predictions = cardib(predictions)
            pred_masks = predictions['pred_masks']
            pred_class = predictions['pred_class']
            assert len(frame_strs) == len(pred_masks)
            for idx, (fname, fmk, fclass) in enumerate(zip(frame_strs, pred_masks, pred_class)):
VIS_Evaluator_OutAPI_EvalFn_API
frame_pred = {'masks': fmk, 'classes': fclass.tolist(), 'video_id': video_id, 'frame_name': fname}
if 'pred_boxes' in predictions:
frame_pred.update({'boxes': predictions['pred_boxes'][idx]})
meta_key_metrics = {}
for metric_fn in self.frame_metric_fns:
metric_values = metric_fn(frame_pred=frame_pred, output_dir=evaluator_path)
for key, value in metric_values.items():
assert key not in meta_key_metrics
meta_key_metrics[key] = value
assert fname not in metrics_by_video_id_frame[video_id]
metrics_by_video_id_frame[video_id][fname] = meta_key_metrics
metrics_by_video_id_frame = comm.gather(dict(metrics_by_video_id_frame), dst=0)
eval_metrics = {}
if comm.is_main_process():
metrics_by_video = {}
for video_id in tqdm(self.eval_meta_keys.keys(), desc='gathering different processes'):
video_id_metrics = [haosen[video_id] for haosen in metrics_by_video_id_frame if video_id in haosen]
video_id_frame_names = [list(haosen.keys()) for haosen in video_id_metrics]
merged_video_id_frame_names = [item for sublist in video_id_frame_names for item in sublist]
assert len(set(merged_video_id_frame_names)) == len(merged_video_id_frame_names),''
assert set(merged_video_id_frame_names).issubset(set(self.eval_meta_keys[video_id]))
assert set(self.eval_meta_keys[video_id]).issubset(set(merged_video_id_frame_names))
# perframe metrics frame: predictions
vid_frame_metrics = video_id_metrics[0]
for haosen in video_id_metrics[1:]:
vid_frame_metrics.update(haosen)
metrics_by_video[video_id] = vid_frame_metrics
eval_metrics = self.metrics_aggregator(metrics_by_video)
comm.synchronize()
return eval_metrics
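At the end of `__call__`, every process holds a `{video_id: {frame_name: metrics}}` dict. `comm.gather(..., dst=0)` collects these into a list on rank 0, which merges them per video and asserts that no frame was evaluated twice and that the evaluated frames match `eval_meta_keys` exactly. A minimal pure-Python sketch of that merge step follows, assuming the gathered list is already available; the function name is hypothetical. Unlike the original loop, it does not mutate the first process's dict in place, and it raises `ValueError` instead of relying on `assert`.

```python
from typing import Dict, List

FrameMetrics = Dict[str, float]
VideoMetrics = Dict[str, Dict[str, FrameMetrics]]  # {video_id: {frame_name: metrics}}


def merge_gathered_metrics(gathered: List[VideoMetrics],
                           eval_meta_keys: Dict[str, List[str]]) -> VideoMetrics:
    """Merge per-process results into one dict, checking frame coverage per video."""
    merged: VideoMetrics = {}
    for video_id, expected_frames in eval_meta_keys.items():
        per_video: Dict[str, FrameMetrics] = {}
        for process_result in gathered:
            for frame_name, metrics in process_result.get(video_id, {}).items():
                if frame_name in per_video:
                    raise ValueError(f"{video_id}/{frame_name} was evaluated by two processes")
                per_video[frame_name] = metrics
        if set(per_video) != set(expected_frames):
            raise ValueError(f"{video_id}: evaluated frames differ from eval_meta_keys")
        merged[video_id] = per_video
    return merged


# Two processes split the frames of video "v1"; video "v2" lives on process 1 only.
gathered = [
    {"v1": {"f1": {"dice": 0.9}}},
    {"v1": {"f2": {"dice": 0.8}}, "v2": {"g1": {"dice": 0.5}}},
]
eval_meta_keys = {"v1": ["f1", "f2"], "v2": ["g1"]}
print(merge_gathered_metrics(gathered, eval_meta_keys))
```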
================================================
FILE: data_schedule/vis/evaluator_utils.py
================================================
_vis_metric_entrypoints = {}
def register_vis_metric(fn):
vis_metric_name = fn.__name__
if vis_metric_name in _vis_metric_entrypoints:
raise ValueError(f'vis_metric name {vis_metric_name} has been registered')
_vis_metric_entrypoints[vis_metric_name] = fn
return fn
def vis_metric_entrypoint(vis_metric_name):
try:
return _vis_metric_entrypoints[vis_metric_name]
except KeyError as e:
        raise KeyError(f'vis_metric name {vis_metric_name} not found') from e
import numpy as np
_EPS = np.spacing(1)
_TYPE = np.float64
def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple:
gt = gt > 128
pred = pred / 255
if pred.max() != pred.min():
pred = (pred - pred.min()) / (pred.max() - pred.min())
return pred, gt
class Smeasure(object):
def __init__(self, length, alpha: float = 0.5):
self.sms = []
self.alpha = alpha
def step(self, pred: np.ndarray, gt: np.ndarray, idx):
pred, gt = _prepare_data(pred=pred, gt=gt)
sm = self.cal_sm(pred, gt)
self.sms.append(sm)
def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float:
y = np.mean(gt)
if y == 0:
sm = 1 - np.mean(pred)
elif y == 1:
sm = np.mean(pred)
else:
sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt)
sm = max(0, sm)
return sm
def object(self, pred: np.ndarray, gt: np.ndarray) -> float:
fg = pred * gt
bg = (1 - pred) * (1 - gt)
u = np.mean(gt)
object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt)
return object_score
def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float:
x = np.mean(pred[gt == 1])
sigma_x = np.std(pred[gt == 1], ddof=1)
score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS)
return score
def region(self, pred: np.ndarray, gt: np.ndarray) -> float:
x, y = self.centroid(gt)
part_info = self.divide_with_xy(pred, gt, x, y)
w1, w2, w3, w4 = part_info['weight']
pred1, pred2, pred3, pred4 = part_info['pred']
gt1, gt2, gt3, gt4 = part_info['gt']
score1 = self.ssim(pred1, gt1)
score2 = self.ssim(pred2, gt2)
score3 = self.ssim(pred3, gt3)
score4 = self.ssim(pred4, gt4)
return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4
def centroid(self, matrix: np.ndarray) -> tuple:
"""
To ensure consistency with the matlab code, one is added to the centroid coordinate,
so there is no need to use the redundant addition operation when dividing the region later,
because the sequence generated by ``1:X`` in matlab will contain ``X``.
:param matrix: a bool data array
:return: the centroid coordinate
"""
h, w = matrix.shape
area_object = np.count_nonzero(matrix)
if area_object == 0:
x = np.round(w / 2)
y = np.round(h / 2)
else:
# More details can be found at: https://www.yuque.com/lart/blog/gpbigm
y, x = np.argwhere(matrix).mean(axis=0).round()
return int(x) + 1, int(y) + 1
def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict:
h, w = gt.shape
area = h * w
gt_LT = gt[0:y, 0:x]
gt_RT = gt[0:y, x:w]
gt_LB = gt[y:h, 0:x]
gt_RB = gt[y:h, x:w]
pred_LT = pred[0:y, 0:x]
pred_RT = pred[0:y, x:w]
pred_LB = pred[y:h, 0:x]
pred_RB = pred[y:h, x:w]
w1 = x * y / area
w2 = y * (w - x) / area
w3 = (h - y) * x / area
w4 = 1 - w1 - w2 - w3
return dict(gt=(gt_LT, gt_RT, gt_LB, gt_RB),
pred=(pred_LT, pred_RT, pred_LB, pred_RB),
weight=(w1, w2, w3, w4))
def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float:
h, w = pred.shape
N = h * w
x = np.mean(pred)
y = np.mean(gt)
sigma_x = np.sum((pred - x) ** 2) / (N - 1)
sigma_y = np.sum((gt - y) ** 2) / (N - 1)
sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1)
alpha = 4 * x * y * sigma_xy
beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y)
if alpha != 0:
score = alpha / (beta + _EPS)
elif alpha == 0 and beta == 0:
score = 1
else:
score = 0
return score
def get_results(self):
sm = np.mean(np.array(self.sms, dtype=_TYPE))
return dict(Smeasure=sm)
import torch
import os
from PIL import Image
@register_vis_metric
def mask_dice_iou(frame_pred, dataset_meta, **kwargs):
video_id = frame_pred['video_id']
frame_name = frame_pred['frame_name']
masks = frame_pred['masks'] # nq h w
get_frames_gt_mask_fn = dataset_meta.get('get_frames_gt_mask_fn')
scores = torch.tensor(frame_pred['classes']) # nq c
foreground_scores = scores[:, :-1].sum(-1) # nq
max_idx = foreground_scores.argmax()
pred_mask = masks[max_idx].int() # h w
gt_mask, _ = get_frames_gt_mask_fn(video_id=video_id, frames=[frame_name]) # 1 h w
gt_mask = gt_mask[0].int() # h w
inter, union = (pred_mask*gt_mask).sum(), (pred_mask+gt_mask).sum()
dice = (2*inter+1)/(union+1)
iou = (inter+1)/(union-inter+1)
return {'dice': dice, 'iou': iou}
@register_vis_metric
def mask_dice_iou_sen_mae_smeasure(frame_pred, dataset_meta, **kwargs):
video_id = frame_pred['video_id']
frame_name = frame_pred['frame_name']
masks = frame_pred['masks'] # nq h w
get_frames_gt_mask_fn = dataset_meta.get('get_frames_gt_mask_fn')
scores = torch.tensor(frame_pred['classes']) # nq c
foreground_scores = scores[:, :-1].sum(-1) # nq
max_idx = foreground_scores.argmax()
pred_mask = masks[max_idx].int() # h w
gt_mask, _ = get_frames_gt_mask_fn(video_id=video_id, frames=[frame_name]) # 1 h w
gt_mask = gt_mask[0].int() # h w
    # inter = tp, union = 2*tp + fp + fn
    inter, union = (pred_mask*gt_mask).sum(), (pred_mask+gt_mask).sum()
    dice = (2*inter+1)/(union+1) # (2*tp + 1) / (2*tp + fp + fn + 1)
    iou = (inter+1)/(union-inter+1) # (tp + 1) / (tp + fp + fn + 1)
tp = (pred_mask * gt_mask).sum().float()
fp = (pred_mask.sum() - tp).float()
fn = (gt_mask.sum() - tp).float()
tn = (pred_mask.shape[0] * pred_mask.shape[1] - (tp + fp + fn)).float()
their_dice = tp * 2 / (tp + fp + fn + tp)
their_iou = tp / (tp + fp + fn)
# their_spe = tn / (tn + fp)
their_sen = tp / (tp + fn)
their_mae = (pred_mask.float() - gt_mask.float()).abs().mean()
Np = gt_mask.sum()
Nn = gt_mask.shape[0] * gt_mask.shape[1] - Np
null = Smeasure(length=1, alpha=0.5)
null.step(pred=(pred_mask.float() * 255 ).numpy(), gt=(gt_mask.float() * 255).numpy(), idx=None)
their_smeasure = torch.tensor(null.get_results()['Smeasure']).float()
return {'dice': dice, 'iou': iou,
'their_dice': their_dice,
'their_iou': their_iou,
'their_sen': their_sen,
'their_mae_abs': their_mae,
'their_smeasure': their_smeasure,
'tp': tp, # true positive
'fp': fp, # false positive
'fn': fn, # false negative
'tn': tn, # true negative
'Np': Np, # positive accumulation
'Nn': Nn} # negative accumulation
@register_vis_metric
def web(frame_pred, output_dir, **kwargs):
os.makedirs(os.path.join(output_dir, 'web'), exist_ok=True)
video_id = frame_pred['video_id']
frame_name = frame_pred['frame_name']
masks = frame_pred['masks'] # nq h w
scores = torch.tensor(frame_pred['classes']) # nq c
foreground_scores = scores[:, :-1].sum(-1) # nq
max_idx = foreground_scores.argmax()
pred_mask = masks[max_idx].int() # h w
mask = Image.fromarray(255 * pred_mask.int().numpy()).convert('L')
save_path = os.path.join(output_dir, 'web', video_id)
os.makedirs(save_path, exist_ok=True)
png_path = os.path.join(save_path, f'{frame_name}.png')
if os.path.exists(png_path):
os.remove(png_path)
mask.save(png_path)
return {}
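`mask_dice_iou` and `mask_dice_iou_sen_mae_smeasure` first pick the query with the highest foreground score (the class scores summed over every class except the last). They then compute smoothed overlaps from `inter = TP` and `union = pred.sum() + gt.sum() = 2·TP + FP + FN`: Dice = (2·TP + 1) / (2·TP + FP + FN + 1) and IoU = (TP + 1) / (TP + FP + FN + 1). The +1 term makes a frame where both masks are empty score 1 instead of dividing by zero, whereas the unsmoothed `their_dice` and `their_iou` evaluate to NaN in that case. The NumPy sketch below reproduces both steps on toy arrays; the function name and the `smooth` parameter are illustrative, not repository API.

```python
import numpy as np


def best_query_dice_iou(masks: np.ndarray, class_scores: np.ndarray,
                        gt_mask: np.ndarray, smooth: float = 1.0):
    """masks: (nq, h, w) binary; class_scores: (nq, c) with the last column
    treated as background; gt_mask: (h, w) binary."""
    foreground = class_scores[:, :-1].sum(-1)            # (nq,)
    pred = masks[foreground.argmax()].astype(np.int64)   # (h, w)
    gt = gt_mask.astype(np.int64)
    inter = (pred * gt).sum()                            # TP
    union = pred.sum() + gt.sum()                        # 2*TP + FP + FN
    dice = (2 * inter + smooth) / (union + smooth)
    iou = (inter + smooth) / (union - inter + smooth)
    return float(dice), float(iou)


masks = np.array([[[1, 1], [0, 0]],    # query 0
                  [[0, 0], [0, 1]]])   # query 1
scores = np.array([[0.9, 0.1],         # query 0: mostly foreground
                   [0.2, 0.8]])
gt = np.array([[1, 0], [0, 0]])
print(best_query_dice_iou(masks, scores, gt))              # (0.75, 0.6666666666666666)
print(best_query_dice_iou(masks, scores, gt, smooth=0.0))  # (0.6666666666666666, 0.5)
```

On full-resolution masks the +1 barely moves the score; it only matters for tiny or empty regions, as the 2×2 toy example exaggerates.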
================================================
FILE: data_schedule/vis/fibroid/__init__.py
================================================
# Register the fibroid datasets
from . import fibroid_dataset
# Register the fibroid evaluation metrics
from . import evals
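These imports exist for their side effects. Importing `fibroid_dataset` registers the splits with detectron2's `DatasetCatalog` and `MetadataCatalog`, and importing `evals` runs the `@register_vis_metric` decorators so that `vis_metric_entrypoint` can later resolve metric names from the config. The sketch below reproduces that name-keyed decorator registry in isolation with hypothetical names, and makes an unknown name raise a `KeyError` that lists the registered names.

```python
from typing import Callable, Dict

_METRICS: Dict[str, Callable] = {}


def register_metric(fn: Callable) -> Callable:
    """Decorator: store fn under its __name__, refusing duplicates."""
    name = fn.__name__
    if name in _METRICS:
        raise ValueError(f"metric {name!r} has already been registered")
    _METRICS[name] = fn
    return fn


def metric_entrypoint(name: str) -> Callable:
    """Look a metric up by name, failing loudly if it was never registered."""
    try:
        return _METRICS[name]
    except KeyError:
        raise KeyError(f"metric {name!r} not found; registered: {sorted(_METRICS)}") from None


@register_metric
def pixel_accuracy(pred, gt):
    # Hypothetical example metric: fraction of matching pixels in flat 0/1 lists.
    return sum(p == g for p, g in zip(pred, gt)) / len(gt)


print(metric_entrypoint("pixel_accuracy")([1, 0, 1, 1], [1, 0, 0, 1]))  # 0.75
```

Because registration happens at import time, a metric defined in a module that is never imported is simply absent from the registry, which is why the package `__init__` imports `evals` explicitly.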
================================================
FILE: data_schedule/vis/fibroid/evals.py
================================================
from data_schedule.vis.evaluator_utils import register_vis_metric
import os
from glob import glob
from tqdm import tqdm
import shutil
from functools import partial
from PIL import Image
import numpy as np
import torch
import detectron2.utils.comm as comm
import logging
import pycocotools.mask as mask_util
from pycocotools.mask import decode as decode_rle
import data_schedule.vis.fibroid.metrics as metrics
@register_vis_metric
def fibroid_other_medi(model_preds,
dataset_meta,
**kwargs):
assert comm.is_main_process()
iou_by_test_sample = []
dice_by_test_sample = []
preds_by_test_sample = []
gt_by_test_sample = []
get_frames_gt_mask_fn = dataset_meta.get('get_frames_gt_mask_fn')
for pred in model_preds:
video_id = pred['video_id'] # str
        frame_name = pred['frame_name'] # str
        masks = pred['masks'] # list[rle], nq
scores = pred['scores'] # nq
max_idx = torch.tensor(scores).argmax()
pred_mask = masks[max_idx] # rle
pred_mask = decode_rle(pred_mask)
pred_mask = torch.as_tensor(pred_mask, dtype=torch.uint8).contiguous() # h w
gt_mask, _ = get_frames_gt_mask_fn(video_id=video_id, frames=[frame_name]) # 1 h w
gt_mask = gt_mask[0].int() # 0/1
preds_by_test_sample.append(pred_mask)
gt_by_test_sample.append(gt_mask)
tp, fp, fn, tn = metrics.get_stats(pred_mask[None, None, ...], gt_mask[None, None, ...],
mode='binary')
iou_score = metrics.iou_score(tp, fp, fn, tn, reduction='micro')
dice = metrics.dice(tp, fp, fn, tn, reduction='micro')
iou_by_test_sample.append(iou_score)
dice_by_test_sample.append(dice)
mean_iou = torch.tensor(iou_by_test_sample).mean()
mean_dice = torch.tensor(dice_by_test_sample).mean()
preds_by_test_sample = torch.stack(preds_by_test_sample, dim=0).unsqueeze(1) # N 1 h w
gt_by_test_sample = torch.stack(gt_by_test_sample, dim=0).unsqueeze(1) # N 1 h w
tp, fp, fn, tn = metrics.get_stats(preds_by_test_sample, gt_by_test_sample,
mode='binary')
overall_iou = metrics.iou_score(tp, fp, fn, tn, reduction='micro')
recall = metrics.recall(tp, fp, fn, tn, reduction='micro-imagewise')
precision = metrics.precision(tp, fp, fn, tn, reduction='micro-imagewise')
all_medi = {
'mean_iou': mean_iou,
'dice': mean_dice,
'overall_iou': overall_iou, # J/overallIoU
'recall': recall,
'precision': precision,
'F': 2 * precision * recall / (precision + recall)
}
return all_medi
from collections import defaultdict
# by_vid, by_frame
iou_dict = defaultdict(dict)
@register_vis_metric
def fibroid_mask_dice_iou(frame_pred, dataset_meta, **kwargs):
video_id = frame_pred['video_id']
frame_name = frame_pred['frame_name']
masks = frame_pred['masks'] # nq h w
get_frames_gt_mask_fn = dataset_meta.get('get_frames_gt_mask_fn')
    scores = torch.tensor(frame_pred['classes']) # nq c, where c must be 2
foreground_scores = scores[:, :-1].sum(-1) # nq
max_idx = foreground_scores.argmax()
pred_mask = masks[max_idx].int() # h w
gt_mask, _ = get_frames_gt_mask_fn(video_id=video_id, frames=[frame_name]) # 1 h w
gt_mask = gt_mask[0].int() # h w
inter, union = (pred_mask*gt_mask).sum(), (pred_mask+gt_mask).sum()
dice = (2*inter+1)/(union+1)
iou = (inter+1)/(union-inter+1)
iou_dict[video_id][frame_name] = iou
if iou > 0.6:
print(f'video_id: {video_id}, frame: {frame_name}: dice {dice}, iou {iou}')
return {'dice': dice, 'iou': iou}
@register_vis_metric
def fibroid_metric_aggregator(metrics_by_vid_frame, dataset_meta, eval_meta_keys, **kwargs):
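    # Input: metrics_by_vid_frame[video_id][frame_name] -> {metric_name: value}
    # Output: eval_metrics, each per-frame metric averaged over all evaluated frames
    eval_metrics = {}
    # Metric names are read from the first frame of the first video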
metric_names = metrics_by_vid_frame[list(eval_meta_keys.keys())[0]][eval_meta_keys[list(eval_meta_keys.keys())[0]][0]]
for taylor_swift in metric_names:
eval_metrics[taylor_swift] = torch.tensor([metrics_by_vid_frame[video][frame][taylor_swift] for video in eval_meta_keys.keys() for frame in eval_meta_keys[video]]).mean()
# metrics by each video
mean_iou_by_each_video = {}
mean_dice_by_each_video = {}
for video in eval_meta_keys:
mean_iou_by_each_video[video] = torch.tensor([metrics_by_vid_frame[video][fname]['iou'] for fname in eval_meta_keys[video]]).mean()
mean_dice_by_each_video[video] = torch.tensor([metrics_by_vid_frame[video][fname]['dice'] for fname in eval_meta_keys[video]]).mean()
mean_iou_by_each_video = dict(sorted(mean_iou_by_each_video.items(), key=lambda x: x[1]))
mean_dice_by_each_video = dict(sorted(mean_dice_by_each_video.items(), key=lambda x: x[1]))
logging.debug(f'mean_iou_by_each_video: {mean_iou_by_each_video}')
logging.debug(f'mean_dice_by_each_video: {mean_dice_by_each_video}')
return eval_metrics
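`fibroid_other_medi` mixes two ways of reducing the confusion counts returned by `metrics.get_stats`, which have shape `(N, C)` (images × classes). With `reduction='micro'`, TP/FP/FN are summed over every image before the ratio is taken, so large lesions dominate `overall_iou`. With `'micro-imagewise'`, the ratio is computed per image and then averaged, so every frame counts equally; the per-sample `mean_iou` is likewise a per-image average. A NumPy sketch of the two reductions for IoU follows. The function names are illustrative, and zero-division handling (which the module's `_handle_zero_division` provides) is omitted.

```python
import numpy as np


def iou_micro(tp, fp, fn) -> float:
    """Pool counts over all images and classes, then take one ratio."""
    tp, fp, fn = (np.asarray(x, dtype=np.float64).sum() for x in (tp, fp, fn))
    return float(tp / (tp + fp + fn))


def iou_micro_imagewise(tp, fp, fn) -> float:
    """Pool counts over classes within each image, then average the per-image ratios."""
    tp, fp, fn = (np.asarray(x, dtype=np.float64).sum(axis=1) for x in (tp, fp, fn))
    return float((tp / (tp + fp + fn)).mean())


# Two binary frames (C = 1): a large, well-segmented lesion and a tiny, poorly segmented one.
tp = [[90], [1]]
fp = [[10], [3]]
fn = [[0], [0]]
print(iou_micro(tp, fp, fn))            # 91 / 104 = 0.875
print(iou_micro_imagewise(tp, fp, fn))  # (0.9 + 0.25) / 2, about 0.575
```

The gap between the two numbers is exactly the effect of letting one large lesion outweigh a small one, which is worth keeping in mind when comparing `overall_iou` with `mean_iou`.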
================================================
FILE: data_schedule/vis/fibroid/fibroid_dataset.py
================================================
from typing import Optional, Union
import json
import os
from functools import partial
import numpy as np
import torch
import logging
from tqdm import tqdm
import copy
from detectron2.data import DatasetCatalog, MetadataCatalog
from collections import defaultdict
from data_schedule.vis.apis import VIS_Dataset
from .fibroid_utils import get_frames, get_frames_mask, SET_NAME_TO_DIR,\
SET_NAME, SET_NAME_TO_NUM_VIDEOS, SET_NAME_TO_MODE, SET_NAME_TO_PREFIX, SET_NAME_TO_GT_TYPE
def fibroid_train(step_size, # none / int; 0, 6, 13, 19 ...
split_dataset_name,
video_ids,
video_to_frames):
logging.debug(f'{split_dataset_name} Generating metas...')
metas = []
for vid_id in tqdm(video_ids):
all_frames = sorted(video_to_frames[vid_id])
if step_size is None:
metas.append({
'video_id': vid_id,
'all_frames' : all_frames,
'meta_idx': len(metas),
                'all_objs': {1: {'class_label': 0,}} # semantic segmentation
})
else:
for frame_idx in range(0, len(all_frames), step_size):
metas.append({
'video_id': vid_id,
'frame_idx': frame_idx,
'all_frames': all_frames,
'all_objs': {1: {'class_label': 0,}},
'meta_idx': len(metas)
})
logging.debug(f'{split_dataset_name} Total metas: [{len(metas)}]')
return metas
def fibroid_evaluate(eval_video_ids,
split_dataset_name,
step_size,
video_to_frames):
if (step_size is not None) and (step_size > 1):
        raise ValueError(f'step_size must be None or 1 during evaluation, got {step_size}')
metas = []
for video_id in eval_video_ids:
VIS_Dataset
all_frames = sorted(video_to_frames[video_id])
        if step_size is None:
metas.append({
'video_id': video_id,
'all_frames': all_frames,
'meta_idx': len(metas)
})
else:
for frame_idx in range(0, len(all_frames), step_size):
metas.append({
'video_id': video_id,
'frame_idx': frame_idx,
'all_frames': all_frames,
'meta_idx': len(metas)
})
logging.debug(f'{split_dataset_name} Total metas: [{len(metas)}]')
return metas
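_root = os.getenv('DATASET_PATH')
if _root is None:
    raise EnvironmentError('Set the DATASET_PATH environment variable to the dataset root directory')
root = os.path.join(_root, 'uterus_myoma/Dataset')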
visualize_meta_idxs = defaultdict(list)
visualize_meta_idxs['fibroid_train_step[6]'] = []
visualize_meta_idxs['fibroid_train'] = []
visualize_meta_idxs['fibroid_train_step[1]'] = []
visualize_meta_idxs['fibroid_validate'] = []
visualize_meta_idxs['fibroid_validate_step[1]'] = []
visualize_meta_idxs['weakPolyP_fibroid_validate_step[1]'] = []
fibroid_meta = {
    'thing_classes': ['tumor', 'not tumor'],
'thing_colors': [(255., 140., 0.), (0., 255., 0.)],
}
for name in SET_NAME:
set_dir = SET_NAME_TO_DIR[name]
set_dir = os.path.join(root, set_dir)
num_videos = SET_NAME_TO_NUM_VIDEOS[name]
video_ids = os.listdir(os.path.join(set_dir, 'Frame'))
assert len(video_ids) == num_videos
video_to_frames = {
vid: sorted([png[:-4] for png in os.listdir(os.path.join(set_dir, 'Frame', vid)) if png.endswith('.png')])\
for vid in video_ids
}
mode = SET_NAME_TO_MODE[name]
prefix = SET_NAME_TO_PREFIX[name]
if mode == 'train':
train_meta = copy.deepcopy(fibroid_meta)
gt_type = SET_NAME_TO_GT_TYPE[name]
train_meta.update({
'mode': 'train',
'get_frames_fn': partial(get_frames, frames_path=os.path.join(set_dir, 'Frame')),
'get_frames_mask_fn': partial(get_frames_mask, mask_path=os.path.join(set_dir, gt_type),),
            'get_frames_gt_mask_fn': partial(get_frames_mask, mask_path=os.path.join(set_dir, 'GT'),),
})
# train
for step_size in [1, 6, None]:
            step_identifier = '' if step_size is None else f'_step[{step_size}]'
            split_name = f'{prefix}{step_identifier}'
train_meta.update({'name': split_name})
DatasetCatalog.register(split_name, partial(fibroid_train,
video_ids=video_ids,
split_dataset_name=split_name,
step_size=step_size,
video_to_frames=video_to_frames,))
MetadataCatalog.get(split_name).set(**train_meta,
step_size=step_size,
visualize_meta_idxs=visualize_meta_idxs[split_name])
elif mode == 'evaluate':
        validate_meta = copy.deepcopy(fibroid_meta)
        validate_meta.update({
            'mode': 'evaluate',
            'get_frames_fn': partial(get_frames, frames_path=os.path.join(set_dir, 'Frame')),
            'eval_set_name': SET_NAME_TO_DIR[name],
            'get_frames_gt_mask_fn': partial(get_frames_mask, mask_path=os.path.join(set_dir, 'GT'),),
            'eval_meta_keys': video_to_frames
        })
# validate
for step_size in [1, None,]:
            step_identifier = '' if step_size is None else f'_step[{step_size}]'
            split_name = f'{prefix}{step_identifier}'
validate_meta.update({'name': split_name})
DatasetCatalog.register(split_name, partial(fibroid_evaluate,
eval_video_ids=video_ids,
split_dataset_name=split_name,
step_size=step_size,
video_to_frames=video_to_frames))
MetadataCatalog.get(split_name).set(**validate_meta, step_size=step_size,
visualize_meta_idxs=visualize_meta_idxs[split_name])
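Each split is registered once per `step_size`. `None` yields one meta per video (the whole clip), while an integer yields one meta for every `step_size`-th frame index, and the split name gains a `_step[k]` suffix. A self-contained sketch of that bookkeeping follows, modelled on `fibroid_evaluate` (training metas additionally carry `all_objs`). The helper names are hypothetical, and sorting the video ids is a choice made here for determinism, whereas the repository keeps the order returned by `os.listdir`.

```python
from typing import Dict, List, Optional


def split_name(prefix: str, step_size: Optional[int]) -> str:
    """Registered split name, e.g. 'fibroid_train' or 'fibroid_train_step[6]'."""
    return prefix if step_size is None else f"{prefix}_step[{step_size}]"


def build_metas(video_to_frames: Dict[str, List[str]],
                step_size: Optional[int] = None) -> List[dict]:
    metas = []
    for video_id in sorted(video_to_frames):
        all_frames = sorted(video_to_frames[video_id])
        if step_size is None:
            # One sample covering the whole video.
            metas.append({"video_id": video_id, "all_frames": all_frames,
                          "meta_idx": len(metas)})
        else:
            # One sample anchored at every step_size-th frame index.
            for frame_idx in range(0, len(all_frames), step_size):
                metas.append({"video_id": video_id, "frame_idx": frame_idx,
                              "all_frames": all_frames, "meta_idx": len(metas)})
    return metas


videos = {"case01": [f"{i:05d}" for i in range(13)], "case02": ["00000", "00001"]}
print(split_name("fibroid_train", 6))                              # fibroid_train_step[6]
print([m["frame_idx"] for m in build_metas(videos, step_size=6)])  # [0, 6, 12, 0]
print(len(build_metas(videos)))                                    # 2
```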
================================================
FILE: data_schedule/vis/fibroid/fibroid_utils.py
================================================
import wandb
import plotly.express as px
import logging
import os
import numpy as np
import torch
import json
from joblib import Parallel, delayed
import multiprocessing
import torch.distributed as dist
import detectron2.utils.comm as comm
import pycocotools.mask as mask_util
from pycocotools.mask import encode, area
from data_schedule.utils.segmentation import bounding_box_from_mask
from data_schedule.utils.video_clips import generate_windows_of_video
from glob import glob
from PIL import Image
def get_frames(frames_path, video_id, frames):
return [Image.open(os.path.join(frames_path, video_id, f'{f}.png'),).convert('RGB') for f in frames]
# Returns (masks, has_ann): masks is t' h w (int, 1 = foreground), has_ann is t' (bool, all True)
def get_frames_mask(mask_path, video_id, frames):
masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.png')).convert('L') for f in frames]
masks = [np.array(mk) for mk in masks]
masks = torch.stack([torch.from_numpy(mk) for mk in masks], dim=0) # t h w
masks = (masks > 0).int()
return masks, torch.ones(len(frames)).bool()
SET_NAME = [
'fibroid_train',
'fibroid_validate',
'weakpolyp_train',
'fibroid_validate_temp7',
'fibroid_train_temp7',
# 'weakpolyp_fibroid_train_temp7',
'fibroid_validate_temp8',
'fibroid_train_temp8',
'weakpolyp_fibroid_train_temp8'
]
SET_NAME_TO_DIR = {
'fibroid_train': 'temp/train',
'fibroid_validate': 'temp/test',
'weakpolyp_train': 'temp/uterus_myoma_WeakPolyP_temp/train',
'fibroid_validate_temp7': 'temp7/test',
'fibroid_train_temp7': 'temp7/train',
'weakpolyp_fibroid_train_temp7': 'temp7/uterus_myoma_WeakPolyP_temp7/train',
'fibroid_validate_temp8': 'temp8/test',
'fibroid_train_temp8': 'temp8/train',
'weakpolyp_fibroid_train_temp8': 'temp8/uterus_myoma_WeakPolyP_temp8/train',
}
SET_NAME_TO_NUM_VIDEOS = {
'fibroid_train': 80,
'fibroid_validate': 20,
'weakpolyp_train': 80,
'fibroid_train_temp7': 85,
'fibroid_validate_temp7': 15,
'weakpolyp_fibroid_train_temp7': 85 ,
'fibroid_train_temp8': 83,
'fibroid_validate_temp8': 17,
'weakpolyp_fibroid_train_temp8': 83
}
SET_NAME_TO_MODE = {
'fibroid_train': 'train',
'fibroid_validate': 'evaluate',
'weakpolyp_train': 'train',
'fibroid_train_temp7': 'train',
'fibroid_validate_temp7': 'evaluate',
'weakpolyp_fibroid_train_temp7': 'train',
'fibroid_train_temp8': 'train',
'fibroid_validate_temp8': 'evaluate',
'weakpolyp_fibroid_train_temp8': 'train'
}
SET_NAME_TO_PREFIX = {
'fibroid_train': 'fibroid_train',
'fibroid_validate': 'fibroid_validate',
'weakpolyp_train': 'weakpolyp_fibroid_train',
'fibroid_train_temp7': 'fibroid_train_temp7',
'fibroid_validate_temp7': 'fibroid_validate_temp7',
'weakpolyp_fibroid_train_temp7': 'weakpolyp_fibroid_train_temp7' ,
'fibroid_train_temp8': 'fibroid_train_temp8',
'fibroid_validate_temp8': 'fibroid_validate_temp8',
'weakpolyp_fibroid_train_temp8': 'weakpolyp_fibroid_train_temp8'
}
SET_NAME_TO_GT_TYPE = {
'fibroid_train': 'GT',
'fibroid_validate': 'GT',
'weakpolyp_train': 'Box',
'fibroid_train_temp7': 'GT',
'fibroid_validate_temp7': 'GT',
'weakpolyp_fibroid_train_temp7': 'Box',
'fibroid_train_temp8': 'GT',
'fibroid_validate_temp8': 'GT',
'weakpolyp_fibroid_train_temp8': 'Box'
}
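`fibroid_dataset.py` looks every entry of `SET_NAME` up in five parallel dictionaries (`SET_NAME_TO_DIR`, `SET_NAME_TO_NUM_VIDEOS`, `SET_NAME_TO_MODE`, `SET_NAME_TO_PREFIX`, `SET_NAME_TO_GT_TYPE`), so a name missing from any one of them only fails at import time with a bare `KeyError`. Note that `weakpolyp_fibroid_train_temp7` is commented out of `SET_NAME` but still present in the tables, which is harmless because only listed names are looked up. A small check like the hypothetical one below, run over the real tables, would report every gap at once; the toy tables in the example are made up.

```python
from typing import Dict, Iterable, List, Mapping


def missing_table_entries(set_names: Iterable[str],
                          tables: Mapping[str, Mapping[str, object]]) -> Dict[str, List[str]]:
    """Return {set_name: [names of tables that lack it]} for every incomplete set."""
    problems: Dict[str, List[str]] = {}
    for name in set_names:
        absent = [table_name for table_name, table in tables.items() if name not in table]
        if absent:
            problems[name] = absent
    return problems


# Toy tables (made up for illustration): 'demo_validate' lacks a mode entry.
tables = {
    "DIR": {"demo_train": "temp/train", "demo_validate": "temp/test"},
    "MODE": {"demo_train": "train"},
}
print(missing_table_entries(["demo_train", "demo_validate"], tables))  # {'demo_validate': ['MODE']}
```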
================================================
FILE: data_schedule/vis/fibroid/metrics.py
================================================
import warnings
from typing import Optional, List, Tuple, Union
import torch
"""Various metrics based on Type I and Type II errors.
References:
https://en.wikipedia.org/wiki/Confusion_matrix
Example:
.. code-block:: python
import segmentation_models_pytorch as smp
        # let's assume we have multilabel prediction for 3 classes
output = torch.rand([10, 3, 256, 256])
target = torch.rand([10, 3, 256, 256]).round().long()
# first compute statistics for true positives, false positives, false negative and
# true negative "pixels"
tp, fp, fn, tn = smp.metrics.get_stats(output, target, mode='multilabel', threshold=0.5)
# then compute metrics with required reduction (see metric docs)
iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
"""
__all__ = [
"get_stats",
"fbeta_score",
"f1_score",
"iou_score",
"accuracy",
"precision",
"recall",
"sensitivity",
"specificity",
"balanced_accuracy",
"positive_predictive_value",
"negative_predictive_value",
"false_negative_rate",
"false_positive_rate",
"false_discovery_rate",
"false_omission_rate",
"positive_likelihood_ratio",
"negative_likelihood_ratio",
]
###################################################################################################
# Statistics computation (true positives, false positives, false negatives, true negatives)
###################################################################################################
def get_stats(
output: Union[torch.LongTensor, torch.FloatTensor],
target: torch.LongTensor,
mode: str,
ignore_index: Optional[int] = None,
threshold: Optional[Union[float, List[float]]] = None,
num_classes: Optional[int] = None,
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
"""Compute true positive, false positive, false negative, true negative 'pixels'
for each image and each class.
Args:
output (Union[torch.LongTensor, torch.FloatTensor]): Model output with following
shapes and types depending on the specified ``mode``:
'binary'
shape (N, 1, ...) and ``torch.LongTensor`` or ``torch.FloatTensor``
'multilabel'
shape (N, C, ...) and ``torch.LongTensor`` or ``torch.FloatTensor``
'multiclass'
shape (N, ...) and ``torch.LongTensor``
target (torch.LongTensor): Targets with following shapes depending on the specified ``mode``:
'binary'
shape (N, 1, ...)
'multilabel'
shape (N, C, ...)
'multiclass'
shape (N, ...)
mode (str): One of ``'binary'`` | ``'multilabel'`` | ``'multiclass'``
ignore_index (Optional[int]): Label to ignore on for metric computation.
            **Not** supported for ``'binary'`` and ``'multilabel'`` modes. Defaults to None.
threshold (Optional[float, List[float]]): Binarization threshold for
``output`` in case of ``'binary'`` or ``'multilabel'`` modes. Defaults to None.
num_classes (Optional[int]): Number of classes, necessary attribute
only for ``'multiclass'`` mode. Class values should be in range 0..(num_classes - 1).
If ``ignore_index`` is specified it should be outside the classes range, e.g. ``-1`` or
``255``.
Raises:
ValueError: in case of misconfiguration.
Returns:
Tuple[torch.LongTensor]: true_positive, false_positive, false_negative,
true_negative tensors (N, C) shape each.
"""
if torch.is_floating_point(target):
raise ValueError(f"Target should be one of the integer types, got {target.dtype}.")
if torch.is_floating_point(output) and threshold is None:
raise ValueError(
            f"Output should be one of the integer types if ``threshold`` is None, got {output.dtype}."
)
if torch.is_floating_point(output) and mode == "multiclass":
raise ValueError(f"For ``multiclass`` mode ``output`` should be one of the integer types, got {output.dtype}.")
if mode not in {"binary", "multiclass", "multilabel"}:
raise ValueError(f"``mode`` should be in ['binary', 'multiclass', 'multilabel'], got mode={mode}.")
if mode == "multiclass" and threshold is not None:
        raise ValueError("``threshold`` parameter is not supported for 'multiclass' mode")
if output.shape != target.shape:
raise ValueError(
"Dimensions should match, but ``output`` shape is not equal to ``target`` "
+ f"shape, {output.shape} != {target.shape}"
)
if mode != "multiclass" and ignore_index is not None:
        raise ValueError(f"``ignore_index`` parameter is not supported for '{mode}' mode")
if mode == "multiclass" and num_classes is None:
        raise ValueError("``num_classes`` attribute should not be ``None`` for 'multiclass' mode.")
if ignore_index is not None and 0 <= ignore_index <= num_classes - 1:
raise ValueError(
f"``ignore_index`` should be outside the class values range, but got class values in range "
            f"0..{num_classes - 1} and ``ignore_index={ignore_index}``. Hint: if you have ``ignore_index = 0``, "
            f"consider subtracting ``1`` from your target and model output to make ``ignore_index = -1`` "
            f"and relevant class values start from ``0``."
)
if mode == "multiclass":
tp, fp, fn, tn = _get_stats_multiclass(output, target, num_classes, ignore_index)
else:
if threshold is not None:
output = torch.where(output >= threshold, 1, 0)
target = torch.where(target >= threshold, 1, 0)
tp, fp, fn, tn = _get_stats_multilabel(output, target)
return tp, fp, fn, tn
@torch.no_grad()
def _get_stats_multiclass(
output: torch.LongTensor,
target: torch.LongTensor,
num_classes: int,
ignore_index: Optional[int],
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
batch_size, *dims = output.shape
num_elements = torch.prod(torch.tensor(dims)).long()
if ignore_index is not None:
ignore = target == ignore_index
output = torch.where(ignore, -1, output)
target = torch.where(ignore, -1, target)
ignore_per_sample = ignore.view(batch_size, -1).sum(1)
tp_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
fp_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
fn_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
tn_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
for i in range(batch_size):
target_i = target[i]
output_i = output[i]
mask = output_i == target_i
matched = torch.where(mask, target_i, -1)
tp = torch.histc(matched.float(), bins=num_classes, min=0, max=num_classes - 1)
fp = torch.histc(output_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp
fn = torch.histc(target_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp
tn = num_elements - tp - fp - fn
if ignore_index is not None:
tn = tn - ignore_per_sample[i]
tp_count[i] = tp.long()
fp_count[i] = fp.long()
fn_count[i] = fn.long()
tn_count[i] = tn.long()
return tp_count, fp_count, fn_count, tn_count
@torch.no_grad()
def _get_stats_multilabel(
output: torch.LongTensor,
target: torch.LongTensor,
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
batch_size, num_classes, *dims = target.shape
output = output.view(batch_size, num_classes, -1)
target = target.view(batch_size, num_classes, -1)
tp = (output * target).sum(2)
fp = output.sum(2) - tp
fn = target.sum(2) - tp
tn = torch.prod(torch.tensor(dims)) - (tp + fp + fn)
return tp, fp, fn, tn
###################################################################################################
# Metrics computation
###################################################################################################
def _handle_zero_division(x, zero_division):
nans = torch.isnan(x)
if torch.any(nans) and zero_division == "warn":
warnings.warn("Zero division in metric calculation!")
value = zero_division if zero_division != "warn" else 0
value = torch.tensor(value, dtype=x.dtype).to(x.device)
x = torch.where(nans, value, x)
return x
def _compute_metric(
metric_fn,
tp,
fp,
fn,
tn,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division="warn",
**metric_kwargs,
) -> float:
if class_weights is None and reduction is not None and "weighted" in reduction:
raise ValueError(f"Class weights should be provided for `{reduction}` reduction")
class_weights = class_weights if class_weights is not None else 1.0
class_weights = torch.tensor(class_weights).to(tp.device)
class_weights = class_weights / class_weights.sum()
if reduction == "micro":
tp = tp.sum()
fp = fp.sum()
fn = fn.sum()
tn = tn.sum()
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
elif reduction == "macro":
tp = tp.sum(0)
fp = fp.sum(0)
fn = fn.sum(0)
tn = tn.sum(0)
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
score = _handle_zero_division(score, zero_division)
score = (score * class_weights).mean()
elif reduction == "weighted":
tp = tp.sum(0)
fp = fp.sum(0)
fn = fn.sum(0)
tn = tn.sum(0)
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
score = _handle_zero_division(score, zero_division)
score = (score * class_weights).sum()
elif reduction == "micro-imagewise":
tp = tp.sum(1)
fp = fp.sum(1)
fn = fn.sum(1)
tn = tn.sum(1)
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
score = _handle_zero_division(score, zero_division)
score = score.mean()
elif reduction == "macro-imagewise" or reduction == "weighted-imagewise":
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
score = _handle_zero_division(score, zero_division)
score = (score.mean(0) * class_weights).mean()
elif reduction == "none" or reduction is None:
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
score = _handle_zero_division(score, zero_division)
else:
raise ValueError(
"`reduction` should be in [micro, macro, weighted, micro-imagewise,"
+ " macro-imagewise, weighted-imagewise, none, None]"
)
return score
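The reductions differ only in where the counts are summed before the score is taken. A worked sketch with IoU on toy counts for N = 2 images and C = 2 classes shows how far apart the results can land. `reduce_iou` is a hypothetical helper covering three of the modes, without zero-division handling or class weights:

```python
from typing import List

Counts = List[List[int]]  # shape [N images][C classes]


def iou(tp: float, fp: float, fn: float) -> float:
    return tp / (tp + fp + fn)


def mean(xs: List[float]) -> float:
    return sum(xs) / len(xs)


def reduce_iou(tp: Counts, fp: Counts, fn: Counts, reduction: str) -> float:
    if reduction == "micro":  # pool every image and class, then score once
        return iou(sum(map(sum, tp)), sum(map(sum, fp)), sum(map(sum, fn)))
    if reduction == "macro":  # pool over images, score each class, average classes
        per_class = [iou(sum(t), sum(p), sum(n)) for t, p, n in zip(zip(*tp), zip(*fp), zip(*fn))]
        return mean(per_class)
    if reduction == "micro-imagewise":  # pool classes within each image, average images
        per_image = [iou(sum(t), sum(p), sum(n)) for t, p, n in zip(tp, fp, fn)]
        return mean(per_image)
    raise ValueError(f"unsupported reduction: {reduction}")


tp = [[8, 1], [2, 0]]
fp = [[2, 0], [0, 1]]
fn = [[0, 1], [0, 1]]
for mode in ("micro", "macro", "micro-imagewise"):
    print(mode, round(reduce_iou(tp, fp, fn, mode), 4))
# micro 0.6875
# macro 0.5417
# micro-imagewise 0.625
```

Micro pooling lets the large, easy class 0 dominate, while macro gives the rare class 1 an equal say.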
# Logic for metric computation, all metrics are with the same interface
def _fbeta_score(tp, fp, fn, tn, beta=1):
beta_tp = (1 + beta**2) * tp
beta_fn = (beta**2) * fn
score = beta_tp / (beta_tp + beta_fn + fp)
return score
def _iou_score(tp, fp, fn, tn):
return tp / (tp + fp + fn)
def _accuracy(tp, fp, fn, tn):
return (tp + tn) / (tp + fp + fn + tn)
def _sensitivity(tp, fp, fn, tn):
return tp / (tp + fn)
def _specificity(tp, fp, fn, tn):
return tn / (tn + fp)
def _balanced_accuracy(tp, fp, fn, tn):
return (_sensitivity(tp, fp, fn, tn) + _specificity(tp, fp, fn, tn)) / 2
def _dice(tp, fp, fn, tn):
return tp * 2 / (tp + fp + fn + tp)
def _positive_predictive_value(tp, fp, fn, tn):
return tp / (tp + fp)
def _negative_predictive_value(tp, fp, fn, tn):
return tn / (tn + fn)
def _false_negative_rate(tp, fp, fn, tn):
return fn / (fn + tp)
def _false_positive_rate(tp, fp, fn, tn):
return fp / (fp + tn)
def _false_discovery_rate(tp, fp, fn, tn):
return 1 - _positive_predictive_value(tp, fp, fn, tn)
def _false_omission_rate(tp, fp, fn, tn):
return 1 - _negative_predictive_value(tp, fp, fn, tn)
def _positive_likelihood_ratio(tp, fp, fn, tn):
return _sensitivity(tp, fp, fn, tn) / _false_positive_rate(tp, fp, fn, tn)
def _negative_likelihood_ratio(tp, fp, fn, tn):
return _false_negative_rate(tp, fp, fn, tn) / _specificity(tp, fp, fn, tn)
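Several of these formulas are algebraically linked, which makes a handy sanity check: Dice equals F1 (F-beta with beta = 1), and Dice = 2 * IoU / (1 + IoU). A small check on toy counts (tp = 6, fp = 2, fn = 4), in plain Python:

```python
def dice(tp: float, fp: float, fn: float) -> float:
    return 2 * tp / (2 * tp + fp + fn)


def iou(tp: float, fp: float, fn: float) -> float:
    return tp / (tp + fp + fn)


def fbeta(tp: float, fp: float, fn: float, beta: float = 1.0) -> float:
    # Same form as _fbeta_score: beta > 1 weights false negatives (recall) more heavily.
    b2 = beta ** 2
    return (1 + b2) * tp / ((1 + b2) * tp + b2 * fn + fp)


tp, fp, fn = 6, 2, 4
j = iou(tp, fp, fn)  # 0.5
print(dice(tp, fp, fn), fbeta(tp, fp, fn), 2 * j / (1 + j))  # all three are 2/3 (up to float rounding)
print(fbeta(tp, fp, fn, beta=2.0))  # 0.625, pulled toward recall (6 / 10 = 0.6)
```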
def fbeta_score(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
beta: float = 1.0,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""F beta score"""
return _compute_metric(
_fbeta_score,
tp,
fp,
fn,
tn,
beta=beta,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def f1_score(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""F1 score"""
return _compute_metric(
_fbeta_score,
tp,
fp,
fn,
tn,
beta=1.0,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def dice(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Dice score or Sørensen-Dice coefficient"""
return _compute_metric(
_dice,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def iou_score(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""IoU score or Jaccard index""" # noqa
return _compute_metric(
_iou_score,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def accuracy(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Accuracy"""
return _compute_metric(
_accuracy,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def sensitivity(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Sensitivity, recall, hit rate, or true positive rate (TPR)"""
return _compute_metric(
_sensitivity,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def specificity(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Specificity, selectivity or true negative rate (TNR)"""
return _compute_metric(
_specificity,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def balanced_accuracy(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Balanced accuracy"""
return _compute_metric(
_balanced_accuracy,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def positive_predictive_value(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Precision or positive predictive value (PPV)"""
return _compute_metric(
_positive_predictive_value,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def negative_predictive_value(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Negative predictive value (NPV)"""
return _compute_metric(
_negative_predictive_value,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def false_negative_rate(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Miss rate or false negative rate (FNR)"""
return _compute_metric(
_false_negative_rate,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def false_positive_rate(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Fall-out or false positive rate (FPR)"""
return _compute_metric(
_false_positive_rate,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def false_discovery_rate(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""False discovery rate (FDR)""" # noqa
return _compute_metric(
_false_discovery_rate,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def false_omission_rate(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""False omission rate (FOR)""" # noqa
return _compute_metric(
_false_omission_rate,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def positive_likelihood_ratio(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Positive likelihood ratio (LR+)"""
return _compute_metric(
_positive_likelihood_ratio,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def negative_likelihood_ratio(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Negative likelihood ratio (LR-)"""
return _compute_metric(
_negative_likelihood_ratio,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
_doc = """
Args:
tp (torch.LongTensor): tensor of shape (N, C), true positive cases
fp (torch.LongTensor): tensor of shape (N, C), false positive cases
fn (torch.LongTensor): tensor of shape (N, C), false negative cases
tn (torch.LongTensor): tensor of shape (N, C), true negative cases
reduction (Optional[str]): Defines how to aggregate the metric over classes and images:
- 'micro'
Sum true positive, false positive, false negative and true negative pixels over
all images and all classes and then compute score.
- 'macro'
Sum true positive, false positive, false negative and true negative pixels over
all images for each label, then compute the score for each label separately and average the label scores.
This does not take label imbalance into account.
- 'weighted'
Sum true positive, false positive, false negative and true negative pixels over
all images for each label, then compute the score for each label separately and take the
weighted average of the label scores.
- 'micro-imagewise'
Sum true positive, false positive, false negative and true negative pixels for **each image**,
then compute the score for **each image** and average the scores over the dataset. All images contribute equally
to the final score, while class imbalance within each image is taken into account.
- 'macro-imagewise'
Compute score for each image and for each class on that image separately, then compute average score
on each image over labels and average image scores over dataset. Does not take into account label
imbalance on each image.
- 'weighted-imagewise'
Compute score for each image and for each class on that image separately, then compute weighted average
score on each image over labels and average image scores over dataset.
- 'none' or ``None``
Same as ``'macro-imagewise'``, but without any reduction.
For ``'binary'`` case ``'micro' = 'macro' = 'weighted'`` and
``'micro-imagewise' = 'macro-imagewise' = 'weighted-imagewise'``.
Prefixes ``'micro'``, ``'macro'`` and ``'weighted'`` define how the scores for classes will be aggregated,
while postfix ``'imagewise'`` defines how scores between the images will be aggregated.
class_weights (Optional[List[float]]): list of class weights used for metric
aggregation when a `weighted*` reduction is chosen. Defaults to None.
zero_division (Union[str, float]): Sets the value to return when there is a zero division,
i.e. when all predictions and labels are negative. If set to "warn", this acts as 0,
but a warning is also raised. Defaults to 1.
Returns:
torch.Tensor: if ``'reduction'`` is not ``None`` or ``'none'`` returns scalar metric,
else returns tensor of shape (N, C)
References:
https://en.wikipedia.org/wiki/Confusion_matrix
"""
fbeta_score.__doc__ += _doc
f1_score.__doc__ += _doc
iou_score.__doc__ += _doc
dice.__doc__ += _doc
accuracy.__doc__ += _doc
sensitivity.__doc__ += _doc
specificity.__doc__ += _doc
balanced_accuracy.__doc__ += _doc
positive_predictive_value.__doc__ += _doc
negative_predictive_value.__doc__ += _doc
false_negative_rate.__doc__ += _doc
false_positive_rate.__doc__ += _doc
false_discovery_rate.__doc__ += _doc
false_omission_rate.__doc__ += _doc
positive_likelihood_ratio.__doc__ += _doc
negative_likelihood_ratio.__doc__ += _doc
precision = positive_predictive_value
recall = sensitivity
================================================
FILE: data_schedule/vis/mapper.py
================================================
import json
import os
from typing import List
import copy
from functools import partial
import random
import numpy as np
import torch
import logging
from einops import rearrange
from detectron2.data import MetadataCatalog
from data_schedule.registry import MAPPER_REGISTRY
from .mapper_utils import VIS_TrainMapper, VIS_EvalMapper
from .vis_frame_sampler import VIS_FRAMES_SAMPLER_REGISTRY
from data_schedule.vis.apis import VIS_Dataset, VIS_Aug_CallbackAPI,\
VIS_TrainAPI_clipped_video, VIS_EvalAPI_clipped_video_request_ann
@MAPPER_REGISTRY.register()
class VIS_Video_EvalMapper(VIS_EvalMapper):
def __init__(self,
configs,
dataset_name,
mode,
meta_idx_shift,
):
assert mode == 'evaluate'
dataset_meta = MetadataCatalog.get(dataset_name)
assert dataset_meta.get('step_size') is None
mapper_config = configs['data'][mode][dataset_name]['mapper']
super().__init__(meta_idx_shift=meta_idx_shift,
dataset_meta=dataset_meta,
mapper_config=mapper_config)
def _call(self, data_dict):
VIS_Dataset
video_id, all_frames = data_dict['video_id'], data_dict['all_frames']
video_frames = self.get_frames_fn(video_id=video_id, frames=all_frames)
aug_ret = {
'video': video_frames,
'callback_fns': []
}
VIS_Aug_CallbackAPI
aug_ret = self.augmentation(aug_ret)
video = aug_ret.pop('video')
callback_fns = aug_ret.pop('callback_fns')[::-1]
VIS_EvalAPI_clipped_video_request_ann
return {
'video_dict': {'video': video},
'meta': {
'video_id': video_id,
'frames': all_frames,
'request_ann': torch.ones(len(all_frames)).bool(),
'callback_fns': callback_fns
}
}
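The eval mapper records, for each augmentation, a callback that undoes it, and reverses the list so that predictions are mapped back last-in, first-out (in the eval augmentations, `RandomResize` registers a resize back to the original size and `VideoToTensor` registers `VideoToPIL`). A toy sketch of that undo-stack pattern on a dict standing in for the video; `resize` and `set_format` are hypothetical stand-ins, not the repository's classes:

```python
from typing import Any, Callable, Dict

State = Dict[str, Any]
Aug = Callable[[State], State]


def resize(size: int) -> Aug:
    def aug(state: State) -> State:
        if "callback_fns" in state:
            # Register the inverse before mutating: resize back to the current size.
            state["callback_fns"].append(resize(state["size"]))
        state["size"] = size
        state.setdefault("log", []).append(f"resize->{size}")
        return state
    return aug


def set_format(fmt: str) -> Aug:
    def aug(state: State) -> State:
        if "callback_fns" in state:
            state["callback_fns"].append(set_format(state["format"]))
        state["format"] = fmt
        state.setdefault("log", []).append(f"format->{fmt}")
        return state
    return aug


# Forward pass over the input, recording inverses as we go.
state: State = {"size": 480, "format": "pil", "callback_fns": []}
for aug in [resize(352), set_format("tensor")]:
    state = aug(state)

# Undo in reverse order, as with aug_ret.pop('callback_fns')[::-1] in the mapper.
undo = state.pop("callback_fns")[::-1]
pred: State = {"size": state["size"], "format": state["format"]}  # no callback_fns key
for callback in undo:
    pred = callback(pred)
print(pred)  # {'size': 480, 'format': 'pil', 'log': ['format->pil', 'resize->480']}
```

Because `pred` carries no `callback_fns` key, the inverse augmentations do not register further callbacks, mirroring the `'callback_fns' in ret` checks in the eval augmentations.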
@MAPPER_REGISTRY.register()
class VIS_Video_or_Step_To_Clip_TrainMapper(VIS_TrainMapper):
def __init__(self,
dataset_name,
configs,
mode,
meta_idx_shift,
):
assert mode == 'train'
dataset_meta = MetadataCatalog.get(dataset_name)
assert dataset_meta.get('name') == dataset_name
mapper_config = configs['data'][mode][dataset_name]['mapper']
super().__init__(meta_idx_shift=meta_idx_shift,
dataset_meta=dataset_meta,
mapper_config=mapper_config)
self.frames_sampler = VIS_FRAMES_SAMPLER_REGISTRY.get(\
mapper_config['frames_sampler']['name'])(sampler_configs=mapper_config['frames_sampler'],
dataset_meta=dataset_meta)
def _call(self, data_dict):
VIS_Dataset
video_id, all_frames, all_objs = data_dict['video_id'], data_dict['all_frames'], data_dict['all_objs']
frame_idx = data_dict['frame_idx'] if 'frame_idx' in data_dict else None
all_obj_ids = list(all_objs.keys()) # [1, 2, 5, 4]
assert len(list(set(all_obj_ids))) == len(all_obj_ids)
class_labels = torch.tensor([all_objs[key]['class_label'] for key in all_obj_ids]) # [8, 10, 20 34]
re_sample = True
sampled_counts = 0
while re_sample:
sampled_frames = self.frames_sampler(all_frames=all_frames, frame_idx=frame_idx, video_id=video_id)
# t' h w, int, obj_ids ; has_ann t
frames_mask, has_ann = self.get_frames_mask_fn(video_id=video_id, frames=sampled_frames)
appear_objs = frames_mask.unique() # [0, 1, 2]
assert set(appear_objs.tolist()).issubset(set([0] + all_obj_ids))
re_sample = (len(list(set(appear_objs.tolist()) & set(all_obj_ids))) == 0)
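# it is enough for at least one of the annotated objects to appear in the sampled clip
sampled_counts += 1
if re_sample and sampled_counts > 2:
logging.error('sampled too many times without any annotated object appearing')
raise RuntimeError()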
frames_mask = torch.stack([frames_mask == obj_id for obj_id in all_obj_ids], dim=0) # N t' h w, bool
video_frames = self.get_frames_fn(video_id=video_id, frames=sampled_frames)
width, height = video_frames[0].size
aug_ret = {
'video': video_frames,
'masks': frames_mask, # N t' h w
'has_ann': has_ann, # t
'classes': class_labels, # N
}
VIS_Aug_CallbackAPI
aug_ret = self.augmentation(aug_ret)
video = aug_ret.pop('video')
frame_targets = self.map_to_frame_targets(aug_ret)
if self.clip_global_targets_map_to_local_targets:
aug_ret = self.map_global_targets_to_local_targets(aug_ret)
VIS_TrainAPI_clipped_video
ret = {}
ret['video_dict'] = {'video': video}
ret['targets'] = aug_ret
ret['frame_targets'] = frame_targets
return ret
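The train mapper converts the integer id mask (`0` = background) into one boolean mask per object, and resamples the clip while none of the annotated objects is visible, giving up after a bounded number of tries. A plain-Python sketch of that logic, with frames as flat lists of pixel ids; `sample_fn` and `get_mask_fn` are hypothetical stand-ins for the frame sampler and `get_frames_mask_fn`:

```python
from typing import Callable, List, Sequence, Tuple

IdMask = List[List[int]]  # [frames][pixels], 0 = background


def split_by_object(id_mask: IdMask, obj_ids: Sequence[int]) -> List[List[List[bool]]]:
    # One boolean mask per object id, shaped [N objects][frames][pixels].
    return [[[px == obj_id for px in frame] for frame in id_mask] for obj_id in obj_ids]


def any_object_visible(id_mask: IdMask, obj_ids: Sequence[int]) -> bool:
    appeared = {px for frame in id_mask for px in frame}
    return bool(appeared & set(obj_ids))


def sample_until_visible(sample_fn: Callable[[], Tuple[str, ...]],
                         get_mask_fn: Callable[[Tuple[str, ...]], IdMask],
                         obj_ids: Sequence[int],
                         max_tries: int = 3):
    for _ in range(max_tries):
        frames = sample_fn()
        id_mask = get_mask_fn(frames)
        if any_object_visible(id_mask, obj_ids):
            return frames, split_by_object(id_mask, obj_ids)
    raise RuntimeError(f"no annotated object appeared in {max_tries} sampled clips")


masks = {("f0",): [[0, 0, 0]], ("f1",): [[0, 1, 1]]}
attempts = iter([("f0",), ("f1",)])
frames, per_object = sample_until_visible(lambda: next(attempts), masks.__getitem__, obj_ids=[1])
print(frames, per_object)  # ('f1',) [[[False, True, True]]]
```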
================================================
FILE: data_schedule/vis/mapper_utils.py
================================================
from .vis_aug_utils import VIS_EVAL_AUG_REGISTRY, VIS_TRAIN_AUG_REGISTRY
import torch
from copy import deepcopy as dcopy
from data_schedule.registry import Mapper
import copy
from data_schedule.vis.apis import VIS_TrainAPI_clipped_video
class VIS_Mapper(Mapper):
def __init__(self,
meta_idx_shift,
dataset_meta,) -> None:
super().__init__(meta_idx_shift=meta_idx_shift, dataset_meta=dataset_meta)
self.get_frames_fn = dataset_meta.get('get_frames_fn')
class VIS_TrainMapper(VIS_Mapper):
def __init__(self,
meta_idx_shift,
dataset_meta,
mapper_config) -> None:
super().__init__(meta_idx_shift, dataset_meta)
self.get_frames_mask_fn = dataset_meta.get('get_frames_mask_fn')
self.clip_global_targets_map_to_local_targets = mapper_config['clip_global_targets_map_to_local_targets']
self.augmentation = VIS_TRAIN_AUG_REGISTRY.get(mapper_config['augmentation']['name'])(mapper_config['augmentation'])
def map_to_frame_targets(self, clip_targets):
VIS_TrainAPI_clipped_video
clip_rets = copy.deepcopy(clip_targets)
masks = clip_rets['masks'].transpose(0, 1).contiguous() # t' N h w
class_labels = clip_rets['classes'] # [10, 32, 10, 4]
has_box = 'boxes' in clip_rets
if has_box:
boxes = clip_rets['boxes'].transpose(0, 1).contiguous() # t' N 4
assert len(masks) == len(boxes)
ret = []
for idx, frame_mk in enumerate(masks):
frame_targets = {
'masks': frame_mk.unsqueeze(1), # N 1 h w
'classes': class_labels, # N
}
if has_box:
frame_targets.update({'boxes': boxes[idx].unsqueeze(1)}) # N 1 4
if self.clip_global_targets_map_to_local_targets:
frame_targets = self.map_global_targets_to_local_targets(frame_targets)
frame_targets['masks'] = frame_targets['masks'].squeeze(1)
if has_box:
frame_targets['boxes'] = frame_targets['boxes'].squeeze(1)
ret.append(frame_targets)
return ret
def map_global_targets_to_local_targets(self, ret):
VIS_TrainAPI_clipped_video
masks = ret['masks'] # N t' h w
global_obj_appear = masks.flatten(1).any(-1) # N [True, False, True, False, False, False, True]
ret['masks'] = ret['masks'][global_obj_appear]
ret['classes'] = ret['classes'][global_obj_appear]
if 'boxes' in ret:
ret['boxes'] = ret['boxes'][global_obj_appear] # n t' 4
return ret
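With `clip_global_targets_map_to_local_targets` enabled, objects whose mask is empty in every frame of the sampled clip are dropped, and the same boolean index is applied to `classes` (and `boxes`) so the per-object fields stay aligned. A plain-Python sketch; `drop_absent_objects` is a hypothetical name:

```python
from typing import Any, Dict, List


def drop_absent_objects(targets: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    # targets["masks"]: [N objects][frames][pixels]; every value is indexed by object first.
    appears = [any(any(frame) for frame in obj_mask) for obj_mask in targets["masks"]]
    return {key: [v for v, keep in zip(values, appears) if keep] for key, values in targets.items()}


targets = {
    "masks": [[[0, 0], [0, 0]],   # object 0: empty in both frames
              [[0, 1], [0, 0]]],  # object 1: visible in frame 0
    "classes": [3, 7],
}
print(drop_absent_objects(targets))  # {'masks': [[[0, 1], [0, 0]]], 'classes': [7]}
```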
class VIS_EvalMapper(VIS_Mapper):
def __init__(self,
meta_idx_shift,
dataset_meta,
mapper_config) -> None:
super().__init__(meta_idx_shift, dataset_meta)
assert mapper_config['augmentation']['name'] in ['WeakPolyP_EvalAug', 'Visha_EvalAug']
self.augmentation = VIS_EVAL_AUG_REGISTRY.get(mapper_config['augmentation']['name'])(mapper_config['augmentation'])
================================================
FILE: data_schedule/vis/polyp/__init__.py
================================================
from . import polyp_dataset
from . import evals
================================================
FILE: data_schedule/vis/polyp/evals.py
================================================
from data_schedule.vis.evaluator_utils import register_vis_metric
import os
import torch
import detectron2.utils.comm as comm
import logging
import subprocess
@register_vis_metric
def polyp_metric_aggregator(metrics_by_vid_frame, dataset_meta, eval_meta_keys, **kwargs):
# metrics_by_vid_frame: {video_id: {frame_name: {metric_name: value}}}
# eval_meta_keys: {video_id: [frame_name, ...]}
# returns eval_metrics: {metric_name: mean over all evaluated frames of all videos}
eval_metrics = {}
# per-frame metrics, pooled over every frame of every video
first_video = list(eval_meta_keys.keys())[0]
metric_names = metrics_by_vid_frame[first_video][eval_meta_keys[first_video][0]]
for metric_name in metric_names:
eval_metrics[metric_name] = torch.tensor([metrics_by_vid_frame[video][frame][metric_name] for video in eval_meta_keys.keys() for frame in eval_meta_keys[video]]).mean()
# per-video mean iou / dice (logged at debug level only)
mean_iou_by_each_video = {}
mean_dice_by_each_video = {}
for video_id in eval_meta_keys:
mean_iou_by_each_video[video_id] = torch.tensor([metrics_by_vid_frame[video_id][fname]['iou'] for fname in eval_meta_keys[video_id]]).mean()
mean_dice_by_each_video[video_id] = torch.tensor([metrics_by_vid_frame[video_id][fname]['dice'] for fname in eval_meta_keys[video_id]]).mean()
mean_iou_by_each_video = dict(sorted(mean_iou_by_each_video.items(), key=lambda x: x[1]))
mean_dice_by_each_video = dict(sorted(mean_dice_by_each_video.items(), key=lambda x: x[1]))
logging.debug(f'mean_iou_by_each_video: {mean_iou_by_each_video}')
logging.debug(f'mean_dice_by_each_video: {mean_dice_by_each_video}')
return eval_metrics
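The headline numbers are averaged over all frames pooled together, so longer videos weigh more than they would in a mean of per-video means; the per-video dictionaries are only logged at debug level, sorted from worst to best. A plain-Python sketch of both aggregations on toy scores (`aggregate` is a hypothetical helper):

```python
from statistics import mean
from typing import Dict, List

Metrics = Dict[str, Dict[str, Dict[str, float]]]  # video -> frame -> metric -> value


def aggregate(metrics: Metrics, eval_meta_keys: Dict[str, List[str]], key: str = "iou"):
    # Frame-weighted mean: every frame of every video counts once.
    pooled = mean(metrics[v][f][key] for v in eval_meta_keys for f in eval_meta_keys[v])
    # Per-video means, sorted ascending so the hardest videos come first.
    per_video = {v: mean(metrics[v][f][key] for f in frames) for v, frames in eval_meta_keys.items()}
    return pooled, dict(sorted(per_video.items(), key=lambda kv: kv[1]))


eval_meta_keys = {"a": ["0", "1", "2"], "b": ["0"]}
metrics = {"a": {f: {"iou": 1.0} for f in ["0", "1", "2"]}, "b": {"0": {"iou": 0.0}}}
pooled, per_video = aggregate(metrics, eval_meta_keys)
print(pooled, per_video, mean(per_video.values()))  # 0.75 {'b': 0.0, 'a': 1.0} 0.5
```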
================================================
FILE: data_schedule/vis/polyp/polyp_dataset.py
================================================
from typing import Optional, Union
import json
import os
from functools import partial
import numpy as np
import torch
import logging
from tqdm import tqdm
import copy
from detectron2.data import DatasetCatalog, MetadataCatalog
from collections import defaultdict
from .polyp_utils import get_frames, get_frames_mask, SET_NAME_TO_DIR, SET_NAME, SET_NAME_TO_NUM_VIDEOS, SET_NAME_TO_MODE, SET_NAME_TO_PREFIX
def polyp_train(step_size,
split_dataset_name,
video_ids,
video_to_frames,
root_path):
logging.debug(f'{split_dataset_name} Generating metas...')
metas = []
for vid_id in tqdm(video_ids):
all_frames = sorted(video_to_frames[vid_id])
poly_class = 0
if step_size is None:
metas.append({
'video_id': vid_id,
'all_frames' : all_frames,
'all_objs': { 1: {'class_label': poly_class} },
'meta_idx': len(metas)
})
else:
for frame_idx in range(0, len(all_frames), step_size):
metas.append({
'video_id': vid_id,
'frame_idx': frame_idx,
'all_frames': all_frames,
'all_objs': { 1: {'class_label': poly_class} },
'meta_idx': len(metas)
})
logging.debug(f'{split_dataset_name} Total metas: [{len(metas)}]')
return metas
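With `step_size=None` each video yields a single meta without a `frame_idx`; with an integer step, one meta is emitted per start index `0, step, 2*step, ...`, so a video of `T` frames contributes `ceil(T / step)` metas and longer videos contribute proportionally more training items. A sketch of that bookkeeping (function name hypothetical, class label fixed to 0 as in `polyp_train`):

```python
import math
from typing import Dict, List, Optional


def build_train_metas(video_to_frames: Dict[str, List[str]], step_size: Optional[int] = None) -> List[dict]:
    metas: List[dict] = []
    for vid, frames in video_to_frames.items():
        all_frames = sorted(frames)
        starts = [None] if step_size is None else range(0, len(all_frames), step_size)
        for frame_idx in starts:
            meta = {"video_id": vid, "all_frames": all_frames,
                    "all_objs": {1: {"class_label": 0}}, "meta_idx": len(metas)}
            if frame_idx is not None:
                meta["frame_idx"] = frame_idx
            metas.append(meta)
    return metas


videos = {"v1": [f"{i:03d}" for i in range(7)], "v2": ["000", "001", "002"]}
metas = build_train_metas(videos, step_size=3)
print([(m["video_id"], m["frame_idx"]) for m in metas])
# [('v1', 0), ('v1', 3), ('v1', 6), ('v2', 0)]
print(len(metas) == sum(math.ceil(len(f) / 3) for f in videos.values()))  # True
```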
def polyp_evaluate(eval_video_ids,
split_dataset_name,
step_size,
video_to_frames):
if (step_size is not None) and (step_size > 1):
logging.warning('step_size > 1 is not supported for evaluation')
raise ValueError(f'unsupported evaluation step_size: {step_size}')
metas = []
for video_id in eval_video_ids:
all_frames = sorted(video_to_frames[video_id])
if step_size is None:
metas.append({
'video_id': video_id,
'all_frames': all_frames,
'meta_idx': len(metas)
})
else:
for frame_idx in range(0, len(all_frames), step_size):
metas.append({
'video_id': video_id,
'frame_idx': frame_idx,
'all_frames': all_frames,
'meta_idx': len(metas)
})
logging.debug(f'{split_dataset_name} Total metas: [{len(metas)}]')
return metas
_root = os.getenv('DATASET_PATH')
root = os.path.join(_root, 'SUN/SUN-SEG2')
visualize_meta_idxs = defaultdict(list)
visualize_meta_idxs['polyp_train_step[6]'] = []
visualize_meta_idxs['polyp_train'] = []
visualize_meta_idxs['polyp_hard_unseen'] = []
visualize_meta_idxs['polyp_hard_seen'] = []
visualize_meta_idxs['polyp_easy_unseen'] = []
visualize_meta_idxs['polyp_easy_seen'] = []
polyp_meta = {
'thing_classes': ['polyp', 'not polyp'],
'thing_colors': [(255., 140., 0.), (0., 255., 0.)],
'root': root
}
for name in SET_NAME:
set_dir = SET_NAME_TO_DIR[name]
set_dir = os.path.join(root, set_dir)
num_videos = SET_NAME_TO_NUM_VIDEOS[name]
video_ids = os.listdir(os.path.join(set_dir, 'Frame'))
assert len(video_ids) == num_videos
video_to_frames = {
vid: sorted([png[:-4] for png in os.listdir(os.path.join(set_dir, 'Frame', vid)) if png.endswith('.jpg')])\
for vid in video_ids
}
mode = SET_NAME_TO_MODE[name]
if mode == 'train':
prefix = SET_NAME_TO_PREFIX[name]
train_meta = copy.deepcopy(polyp_meta)
train_meta.update({
'mode': 'train',
'get_frames_fn': partial(get_frames, frames_path=os.path.join(set_dir, 'Frame')),
'get_frames_mask_fn': partial(get_frames_mask, mask_path=os.path.join(set_dir, 'GT'),),
})
# train
for step_size in [1, 3, 6, 9, 12, None]:
step_identifier = '' if step_size is None else f'_step[{step_size}]'
split_name = f'{prefix}{step_identifier}'
train_meta.update({'name': split_name})
DatasetCatalog.register(split_name, partial(polyp_train,
video_ids=video_ids,
split_dataset_name=split_name,
step_size=step_size,
video_to_frames=video_to_frames,
root_path=set_dir))
MetadataCatalog.get(split_name).set(**train_meta,
step_size=step_size,
visualize_meta_idxs=visualize_meta_idxs[split_name])
elif mode == 'evaluate':
prefix = SET_NAME_TO_PREFIX[name]
validate_meta = copy.deepcopy(polyp_meta)
validate_meta.update({
'mode': 'evaluate',
'get_frames_fn': partial(get_frames, frames_path=os.path.join(set_dir, 'Frame')),
'eval_set_name': SET_NAME_TO_DIR[name],
'get_frames_gt_mask_fn': partial(get_frames_mask, mask_path=os.path.join(set_dir, 'GT'),),
'eval_meta_keys': video_to_frames
})
for step_size in [1, None,]:
step_identifier = '' if step_size is None else f'_step[{step_size}]'
split_name = f'{prefix}{step_identifier}'
validate_meta.update({'name': split_name})
DatasetCatalog.register(split_name, partial(polyp_evaluate,
eval_video_ids=video_ids,
split_dataset_name=split_name,
step_size=step_size,
video_to_frames=video_to_frames))
MetadataCatalog.get(split_name).set(**validate_meta, step_size=step_size,
visualize_meta_idxs=visualize_meta_idxs[split_name])
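The registration loops derive dataset names from the prefix and the step size, and these names are what a config has to reference. A sketch reproducing the naming rule for the step sizes registered above (`split_names` is a hypothetical helper):

```python
from typing import List, Optional, Sequence


def split_names(prefix: str, step_sizes: Sequence[Optional[int]]) -> List[str]:
    # None means "whole video per meta" and adds no suffix.
    return [prefix if step is None else f"{prefix}_step[{step}]" for step in step_sizes]


print(split_names("polyp_train", [1, 3, 6, 9, 12, None]))
# ['polyp_train_step[1]', 'polyp_train_step[3]', 'polyp_train_step[6]',
#  'polyp_train_step[9]', 'polyp_train_step[12]', 'polyp_train']
print(split_names("polyp_hard_unseen_validate", [1, None]))
# ['polyp_hard_unseen_validate_step[1]', 'polyp_hard_unseen_validate']
```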
================================================
FILE: data_schedule/vis/polyp/polyp_utils.py
================================================
import os
import numpy as np
import torch
from PIL import Image
SET_NAME = ['polyp_train',
'polyp_hard_seen_validate',
'polyp_hard_unseen_validate',
'polyp_easy_seen_validate',
'polyp_easy_unseen_validate',
'polyp_hard_validate',
'polyp_easy_validate',
'Kvasir-train',
'Mayo-train',
'300-train',
'612-train',
'300-tv',
'612-test',
'612-val'
]
SET_NAME_TO_DIR = {
'polyp_train': 'TrainDataset',
'polyp_hard_seen_validate': 'TestHardDataset/Seen',
'polyp_hard_unseen_validate': 'TestHardDataset/Unseen',
'polyp_easy_seen_validate': 'TestEasyDataset/Seen',
'polyp_easy_unseen_validate': 'TestEasyDataset/Unseen',
'polyp_hard_validate': 'TestHardDataset/Combine',
'polyp_easy_validate': 'TestEasyDataset/Combine',
'Kvasir-train': 'MICCAI-VPS-dataset/Kvasir-SEG',
'Mayo-train': 'MICCAI-VPS-dataset/VPS-TrainSet/ASU-Mayo_Clinic/Train',
'300-train': 'MICCAI-VPS-dataset/VPS-TrainSet/CVC-ColonDB-300/Train',
'612-train': 'MICCAI-VPS-dataset/VPS-TrainSet/CVC-ClinicDB-612/Train',
'300-tv': 'MICCAI-VPS-dataset/VPS-TestSet/CVC-ColonDB-300',
'612-test': 'MICCAI-VPS-dataset/VPS-TestSet/CVC-ClinicDB-612-Test',
'612-val': 'MICCAI-VPS-dataset/VPS-TestSet/CVC-ClinicDB-612-Valid'
}
SET_NAME_TO_NUM_VIDEOS = {
'polyp_train': 112,
'polyp_hard_seen_validate': 17,
'polyp_hard_unseen_validate': 37,
'polyp_easy_seen_validate': 33,
'polyp_easy_unseen_validate': 86,
'polyp_hard_validate': 54,
'polyp_easy_validate': 119,
'Kvasir-train': 1,
'Mayo-train': 10,
'300-train': 6,
'612-train': 18,
'300-tv': 6,
'612-test': 5,
'612-val': 5
}
SET_NAME_TO_MODE = {
'polyp_train': 'train',
'polyp_hard_seen_validate': 'evaluate',
'polyp_hard_unseen_validate': 'evaluate',
'polyp_easy_seen_validate': 'evaluate',
'polyp_easy_unseen_validate': 'evaluate',
'polyp_hard_validate': 'evaluate',
'polyp_easy_validate': 'evaluate',
'Kvasir-train': 'train',
'Mayo-train': 'train',
'300-train': 'train',
'612-train': 'train',
'300-tv': 'evaluate',
'612-test': 'evaluate',
'612-val': 'evaluate'
}
SET_NAME_TO_PREFIX = {
'polyp_train': 'polyp_train',
'polyp_hard_seen_validate': 'polyp_hard_seen_validate',
'polyp_hard_unseen_validate': 'polyp_hard_unseen_validate',
'polyp_easy_seen_validate': 'polyp_easy_seen_validate',
'polyp_easy_unseen_validate': 'polyp_easy_unseen_validate',
'polyp_hard_validate': 'polyp_hard_validate',
'polyp_easy_validate': 'polyp_easy_validate',
'Kvasir-train': 'Kvasir-train',
'Mayo-train': 'Mayo-train',
'300-train': '300-train',
'612-train': '612-train',
'300-tv': '300-tv',
'612-test': '612-test',
'612-val': '612-val'
}
CLASS_TO_ID = {
'high_grade_adenoma':0,
'hyperplastic_polyp':1,
'invasive_cancer':2,
'low_grade_adenoma':3,
'sessile_serrated_lesion':4,
'traditional_serrated_adenoma':5
}
def get_frames(frames_path, video_id, frames):
return [Image.open(os.path.join(frames_path, video_id, f'{f}.jpg')).convert('RGB') for f in frames]
def get_frames_mask(mask_path, video_id, frames):
# masks = [cv2.imread(os.path.join(mask_path, video_id, f'{f}.jpg')) for f in frames]
if os.path.exists(os.path.join(mask_path, video_id, f'{frames[0]}.png')):
masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.png')).convert('L') for f in frames]
elif os.path.exists(os.path.join(mask_path, video_id, f'{frames[0]}.jpg')):
masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.jpg')).convert('L') for f in frames]
else:
raise ValueError(f'no .png or .jpg mask found for frame {frames[0]} under {os.path.join(mask_path, video_id)}')
masks = [np.array(mk) for mk in masks]
masks = torch.stack([torch.from_numpy(mk) for mk in masks], dim=0) # t h w
# assert set(masks.unique().tolist()) == set([0, 255]), f'{masks.unique().tolist()}'
masks = (masks > 0).int()
return masks, torch.ones(len(frames)).bool()
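`get_frames_mask` treats any non-zero grayscale pixel as foreground (`masks > 0`). That matches clean 0/255 PNG masks, but masks stored as JPEG can contain intermediate values from compression around object edges, and this rule counts all of them as foreground. A sketch with an optional threshold; `threshold=0` reproduces the repository's rule, and the higher value is only an illustrative alternative:

```python
from typing import List


def binarize(mask: List[List[int]], threshold: int = 0) -> List[List[int]]:
    # 1 where the grayscale value exceeds the threshold, else 0 (like (masks > 0).int()).
    return [[int(px > threshold) for px in row] for row in mask]


noisy = [[0, 3, 255],
         [128, 0, 1]]
print(binarize(noisy))                 # [[0, 1, 1], [1, 0, 1]]
print(binarize(noisy, threshold=127))  # [[0, 0, 1], [1, 0, 0]]
```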
================================================
FILE: data_schedule/vis/vis_aug_eval.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import random
import torch
import torchvision.transforms.functional as F
from data_schedule.vis.apis import VIS_Aug_CallbackAPI
from .vis_aug_utils import get_tgt_size
from .vis_aug_utils import VIS_EVAL_AUG_REGISTRY
class RandomResize:
def __init__(self, sizes, max_size=None):
assert isinstance(sizes, (list, tuple))
self.sizes = sizes
self.max_size = max_size
def __call__(self, ret):
video = ret['video']
orig_size = video[0].size # w h
tgt_size = get_tgt_size(video[0].size, random.choice(self.sizes), self.max_size) # h w
resized_video = [F.resize(frame, tgt_size) for frame in video]
ratio_width, ratio_height = tuple(float(s) / float(s_orig) for s, s_orig in zip(tgt_size[::-1], orig_size))
ret['video'] = resized_video
if 'callback_fns' in ret:
VIS_Aug_CallbackAPI
ret['callback_fns'].append(RandomResize(sizes=[orig_size], max_size=None))
if "pred_masks" in ret:
assert (len(self.sizes) == 1) and (self.max_size is None)
VIS_Aug_CallbackAPI
pred_masks = ret['pred_masks'] # list[nt h w], t
pred_masks = [torch.nn.functional.interpolate(mk.unsqueeze(0).float(), tgt_size, mode='nearest')[0].bool()
for mk in pred_masks]
ret['pred_masks'] = pred_masks # list[nt h w], t
if "pred_boxes" in ret:
VIS_Aug_CallbackAPI
pred_boxes = ret["pred_boxes"] # list[nt 4], t
scaled_boxes = [bx * (torch.tensor([ratio_width, ratio_height, ratio_width, ratio_height])[None, :])
for bx in pred_boxes]
ret["pred_boxes"] = scaled_boxes
return ret
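One detail in `RandomResize` is easy to trip over: PIL reports `size` as `(w, h)` while the target size is `(h, w)` (per the inline comments), hence the `tgt_size[::-1]` when computing the width and height ratios used to rescale `pred_boxes`. A sketch of the box rescaling, assuming `(x0, y0, x1, y1)` boxes and using a hypothetical helper name:

```python
from typing import List, Sequence, Tuple


def scale_boxes(boxes: Sequence[Sequence[float]], orig_wh: Tuple[int, int],
                tgt_hw: Tuple[int, int]) -> List[List[float]]:
    # Reverse (h, w) to (w, h) so each target side is paired with the matching original side.
    ratio_w, ratio_h = (t / o for t, o in zip(tgt_hw[::-1], orig_wh))
    return [[x0 * ratio_w, y0 * ratio_h, x1 * ratio_w, y1 * ratio_h] for x0, y0, x1, y1 in boxes]


# Original frame is 800 wide and 400 high; target is 300 high and 400 wide.
print(scale_boxes([[100, 40, 300, 200]], orig_wh=(800, 400), tgt_hw=(300, 400)))
# [[50.0, 30.0, 150.0, 150.0]]
```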
class VideoToPIL:
def __call__(self, ret):
video = ret['video'] # t 3 h w ->
assert video.dtype == torch.float and (video.max() <= 1) and (video.min() >=0)
pil_video = [F.to_pil_image(frame, mode='RGB') for frame in video] # 3 h w, float, 0-1
ret['video'] = pil_video
assert 'callback_fns' not in ret
return ret
class VideoToTensor:
def __call__(self, ret):
video = ret['video']
tensor_video = torch.stack([F.to_tensor(frame) for frame in video], dim=0) # t 3 h w, float, 0-1
ret['video'] = tensor_video
if 'callback_fns' in ret:
VIS_Aug_CallbackAPI
ret['callback_fns'].append(VideoToPIL())
return ret
@VIS_EVAL_AUG_REGISTRY.register()
class WeakPolyP_EvalAug:
def __init__(self, configs) -> None:
self.resize = RandomResize(
sizes=[[352, 352]],
)
self.tensor_video = VideoToTensor()
def __call__(self, ret):
VIS_Aug_CallbackAPI
ret = self.resize(ret)
ret = self.tensor_video(ret)
return ret
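# Sketch (hypothetical, illustration only): in the eval transforms above each
# forward step appends its inverse to ret['callback_fns'] (RandomResize appends a
# resize back to the original size, VideoToTensor appends VideoToPIL). Assuming
# the consumer of callback_fns, which is not in this file, undoes the steps in
# reverse order, the pattern looks like this on a toy integer payload.
def add_const_with_inverse(c):
    def transform(ret):
        ret['value'] = ret['value'] + c
        ret['callback_fns'].append(lambda r: {**r, 'value': r['value'] - c})
        return ret
    return transform


def scale_with_inverse(k):
    def transform(ret):
        ret['value'] = ret['value'] * k
        ret['callback_fns'].append(lambda r: {**r, 'value': r['value'] / k})
        return ret
    return transform


def run_with_callbacks(ret, transforms):
    ret.setdefault('callback_fns', [])
    for t in transforms:
        ret = t(ret)
    return ret


def undo_with_callbacks(ret):
    # the last forward step is undone first (e.g. tensor -> PIL before resizing back)
    for fn in reversed(ret['callback_fns']):
        ret = fn(ret)
    return ret

# Running the inverses in forward order instead would give (8 - 1) / 2 = 3.5, not 3.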
================================================
FILE: data_schedule/vis/vis_aug_train.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import random
from PIL import Image
import torch
import torchvision.transforms.functional as F
from einops import rearrange
from copy import deepcopy as dcopy
from data_schedule.vis.apis import VIS_Aug_CallbackAPI
import albumentations as A
import numpy as np
from data_schedule.utils.segmentation import bounding_box_from_mask
from .vis_aug_utils import VIS_TRAIN_AUG_REGISTRY, pil_torch_to_numpy, numpy_to_pil_torch
import copy
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
import imgaug
from datetime import datetime
class RandomRotate90:
def __init__(self) -> None:
self.album_aug = A.ReplayCompose(
[A.RandomRotate90(0.5)]
)
def __call__(self, ret):
video = ret['video']
masks = ret['masks']
has_ann = ret['has_ann']
# list[PIL], n t' h w ->
# list[h w 3, 255rgb], t
# list[list[h w, 01uint8]] t
video, masks = pil_torch_to_numpy(video=video, masks=masks, has_ann=has_ann)
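# draw the rotation once from the first frame, then replay it on every frame and its masks
replay = self.album_aug(image=video[0], masks=[masks[0][0]])['replay']
auged_video = []
auged_mask = []
for vid, mk in zip(video, masks):
# keep the outer `ret` dict intact; the albumentations result goes into its own variable
auged_each_frame = self.album_aug.replay(replay, image=vid, masks=mk)
auged_video.append(auged_each_frame['image'])
auged_mask.append(auged_each_frame['masks'])
auged_video, auged_mask = numpy_to_pil_torch(video=auged_video, masks=auged_mask, has_ann=has_ann)
ret['video'] = auged_video
ret['masks'] = auged_mask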
return ret
class ComputeBox:
def __call__(self, ret):
W, H = ret['video'][0].size
N, T = ret['masks'].shape[:2] # n t' h w
boxes = torch.stack([bounding_box_from_mask(mask) for mask in copy.deepcopy(ret['masks']).flatten(0, 1)], dim=0) # Nt' 4
boxes = rearrange(boxes, '(N T) c -> N T c', N=N, T=T)
boxes[:, :, 0::2].clamp_(min=0, max=W)
boxes[:, :, 1::2].clamp_(min=0, max=H)
ret['boxes'] = boxes
return ret
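# Sketch (numpy, hypothetical names): ComputeBox turns each (h, w) mask into an
# xyxy box and then clamps x to [0, W] and y to [0, H]; the in-place clamp_ on the
# strided slices boxes[:, :, 0::2] works because basic slicing returns a view.
# bounding_box_from_mask is defined elsewhere, so the box convention below
# (x2/y2 one past the last foreground pixel, all zeros for an empty mask) is an
# assumption rather than that helper's documented behaviour.
import numpy as np


def sketch_box_from_mask(mask):
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return np.zeros(4, dtype=np.float32)
    return np.array([xs.min(), ys.min(), xs.max() + 1, ys.max() + 1], dtype=np.float32)


def sketch_clamp_boxes(boxes, width, height):
    # returns a clamped copy instead of modifying the input in place
    boxes = boxes.copy()
    boxes[..., 0::2] = boxes[..., 0::2].clip(0, width)
    boxes[..., 1::2] = boxes[..., 1::2].clip(0, height)
    return boxes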
class VideoToTensor:
def __call__(self, ret):
video = ret['video']
tensor_video = torch.stack([F.to_tensor(frame) for frame in video], dim=0) # t 3 h w, float, 0-1
ret['video'] = tensor_video
return ret
class Compose:
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, ret):
for t in self.transforms:
ret = t(ret)
return ret
def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string
@VIS_TRAIN_AUG_REGISTRY.register()
class WeakPolyP_TrainAug:
def __init__(self, configs) -> None:
self.transform = A.ReplayCompose([
A.Resize(352, 352),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.RandomRotate90(p=0.5),
])
self.tensor_video = VideoToTensor()
self.add_box = ComputeBox()
def __call__(self, ret):
VIS_Aug_CallbackAPI
video = ret['video']
masks = ret['masks'] # n t' h w
has_ann = ret['has_ann'] # t
# list[PIL] -> list[h w 3, 0-1float], t
# n t' h w -> list[list[h w, 01uint8]] t; the list is empty for frames without annotation
video, masks = pil_torch_to_numpy(video=video, masks=masks, has_ann=has_ann)
replay = self.transform(image=video[0], masks=[masks[0][0]])['replay']
auged_video = []
auged_mask = []
for vid, mk in zip(video, masks):
auged_each_frame = self.transform.replay(replay, image=vid, masks=mk)
auged_video.append(auged_each_frame['image'])
auged_mask.append(auged_each_frame['masks']) # list[h w, 01uint8]
auged_video, auged_mask = numpy_to_pil_torch(video=auged_video, masks=auged_mask, has_ann=has_ann)
ret['video'] = auged_video
ret['masks'] = auged_mask
ret = self.add_box(ret)
ret = self.tensor_video(ret)
return ret
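# Sketch (numpy, hypothetical names; not albumentations): WeakPolyP_TrainAug
# records the random choices made on the first frame (ReplayCompose's 'replay')
# and replays them on every frame and mask, so a clip is flipped/rotated as a
# whole. The same idea without albumentations: draw the flip and rot90 choices
# once per clip, then apply them to every frame and every instance mask.
import random
import numpy as np


def sample_clip_flip_rot(rng):
    return {
        'hflip': rng.random() < 0.5,
        'vflip': rng.random() < 0.5,
        'rot90_k': rng.randint(0, 3) if rng.random() < 0.5 else 0,
    }


def apply_clip_flip_rot(array, params):
    """Apply the shared flips/rotation to an (h, w) mask or an (h, w, c) frame."""
    if params['hflip']:
        array = array[:, ::-1]
    if params['vflip']:
        array = array[::-1]
    return np.ascontiguousarray(np.rot90(array, k=params['rot90_k'], axes=(0, 1)))


def augment_clip_consistently(frames, masks, seed=None):
    """frames: list[(h, w, c)]; masks: list[list[(h, w)]], one inner list per frame."""
    params = sample_clip_flip_rot(random.Random(seed))
    auged_frames = [apply_clip_flip_rot(f, params) for f in frames]
    auged_masks = [[apply_clip_flip_rot(m, params) for m in frame_masks] for frame_masks in masks]
    return auged_frames, auged_masks, params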
@VIS_TRAIN_AUG_REGISTRY.register()
class WeakPolyP_TrainAug_RotateImageToClip:
def __init__(self, configs) -> None:
self.ImageToSeqAugmenter = ImageToSeqAugmenter(perspective=True, affine=True, motion_blur=True,
rotation_range=(-20, 20), perspective_magnitude=0.08,
hue_saturation_range=(-5, 5), brightness_range=(-40, 40),
motion_blur_prob=0.25, motion_blur_kernel_sizes=(9, 11),
translate_range=(-0.1, 0.1))
self.num_frames = configs['num_frames']
self.transform = A.ReplayCompose([
A.Resize(352, 352),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.RandomRotate90(p=0.5),
])
self.tensor_video = VideoToTensor()
self.add_box = ComputeBox()
def apply_random_sequence_shuffle(self, images, instance_masks):
perm = list(range(self.num_frames))
random.shuffle(perm)
images = [images[i] for i in perm]
instance_masks = [instance_masks[i] for i in perm]
return images, instance_masks
def __call__(self, ret):
VIS_Aug_CallbackAPI
video = ret['video'] # list[pil], t
masks = ret['masks'] # n t' h w
has_ann = ret['has_ann'] # t
# list[PIL] -> list[h w 3, uint8], t
# n t' h w -> list[list[h w], n, uint8], t
seq_images, seq_instance_masks = pil_torch_to_numpy(video=video, masks=masks, has_ann=has_ann, float_image=False)
assert len(seq_images) == 1 and len(seq_instance_masks) == 1
static_img, static_mask = seq_images[0], seq_instance_masks[0]
for t in range(self.num_frames - 1):
im_trafo, instance_masks_trafo = self.ImageToSeqAugmenter(static_img, static_mask) # h w 3, uint8; list[h w], n, uint8
seq_images.append(np.uint8(im_trafo))
seq_instance_masks.append(instance_masks_trafo)
# list[h w 3], t ; # list[list[h w, 01uint8]] t
seq_images, seq_instance_masks = self.apply_random_sequence_shuffle(seq_images, seq_instance_masks)
has_ann = torch.ones(self.num_frames).bool() # T
seq_images = [np.float32(haosen) / 255.0 for haosen in seq_images] # list[h w 3, 0-1float], t
replay = self.transform(image=seq_images[0], masks=[seq_instance_masks[0][0]])['replay']
auged_video = []
auged_mask = []
for vid, mk in zip(seq_images, seq_instance_masks):
auged_each_frame = self.transform.replay(replay, image=vid, masks=mk)
auged_video.append(auged_each_frame['image'])
auged_mask.append(auged_each_frame['masks']) # list[h w, 01uint8]
auged_video, auged_mask = numpy_to_pil_torch(video=auged_video, masks=auged_mask, has_ann=has_ann) # n t h w
# [haosen.save(f'./test{idx}.png') for idx, haosen in enumerate(auged_video)]
# import matplotlib.pyplot as plt
# [plt.imsave( f'./mask{idx}.png', auged_mask[0][idx].float().numpy()) for idx in range(len(auged_mask[0]))]
ret['video'] = auged_video
ret['masks'] = auged_mask
ret['has_ann'] = has_ann
ret = self.add_box(ret)
ret = self.tensor_video(ret)
return ret
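# Sketch (numpy, hypothetical names): WeakPolyP_TrainAug_RotateImageToClip turns
# one annotated image into a pseudo-clip. ImageToSeqAugmenter (imgaug perspective,
# affine, motion blur and colour jitter) creates num_frames - 1 warped copies whose
# masks are warped identically, then frames and masks are shuffled with one shared
# permutation. This sketch keeps only that structure and swaps in zero-filled
# integer translations for the imgaug warps.
import random
import numpy as np


def shift_zero_fill(array, dy, dx):
    """Translate an (h, w) or (h, w, c) array by (dy, dx) pixels, filling with zeros."""
    h, w = array.shape[:2]
    assert abs(dy) < h and abs(dx) < w
    out = np.zeros_like(array)
    dst = (slice(max(0, dy), h + min(0, dy)), slice(max(0, dx), w + min(0, dx)))
    src = (slice(max(0, -dy), h - max(0, dy)), slice(max(0, -dx), w - max(0, dx)))
    out[dst] = array[src]
    return out


def sketch_image_to_clip(image, masks, num_frames, max_shift=0.1, seed=None):
    """Return num_frames frames (original plus shifted copies) and their masks, shuffled jointly."""
    assert 0 <= max_shift < 1
    rng = random.Random(seed)
    h, w = image.shape[:2]
    frames, frame_masks = [image], [list(masks)]
    for _ in range(num_frames - 1):
        dy = rng.randint(-int(max_shift * h), int(max_shift * h))
        dx = rng.randint(-int(max_shift * w), int(max_shift * w))
        frames.append(shift_zero_fill(image, dy, dx))
        frame_masks.append([shift_zero_fill(m, dy, dx) for m in masks])
    perm = list(range(num_frames))
    rng.shuffle(perm)  # one permutation shared by frames and masks
    return [frames[i] for i in perm], [frame_masks[i] for i in perm]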
class ImageToSeqAugmenter(object):
def __init__(self, perspective=True, affine=True, motion_blur=True,
brightness_range=(-50, 50), hue_saturation_range=(-15, 15), perspective_magnitude=0.12,
scale_range=1.0, translate_range={"x": (-0.15, 0.15), "y": (-0.15, 0.15)}, rotation_range=(-20, 20),
motion_blur_kernel_sizes=(7, 9), motion_blur_prob=0.5, seed=2024):
self.basic_augmenter = iaa.SomeOf((1, None), [
iaa.Add(brightness_range),
iaa.AddToHueAndSaturation(hue_saturation_range)
]
)
transforms = []
if perspective:
transforms.append(iaa.PerspectiveTransform(perspective_magnitude))
if affine:
transforms.append(iaa.Affine(scale=scale_range,
translate_percent=translate_range,
rotate=rotation_range,
order=1, # cv2.INTER_LINEAR
backend='auto'))
transforms = iaa.Sequential(transforms)
transforms = [transforms]
if motion_blur:
blur = iaa.Sometimes(motion_blur_prob, iaa.OneOf(
[
iaa.MotionBlur(ksize)
for ksize in motion_blur_kernel_sizes
]
))
transforms.append(blur)
self.frame_shift_augmenter = iaa.Sequential(transforms)
self.seed = seed
@staticmethod
def condense_masks(instance_masks):
condensed_mask = np.zeros_like(instance_masks[0], dtype=np.int8)
for instance_id, mask in enumerate(instance_masks, 1):
condensed_mask = np.where(mask, instance_id, condensed_mask)
return condensed_mask
@staticmethod
def expand_masks(condensed_mask, num_instances):
return [(condensed_mask == instance_id).astype(np.uint8) for instance_id in range(1, num_instances + 1)]
def __call__(self, image, masks=None, boxes=None): # n h w
det_augmenter = self.frame_shift_augmenter.to_deterministic()
if masks is not None:
masks_np, is_binary_mask = [], []
boxs_np = []
for mask in masks:
if isinstance(mask, np.ndarray):
masks_np.append(mask.astype(np.bool_))
is_binary_mask.append(False)
else:
raise ValueError("Invalid mask type: {}".format(type(mask)))
num_instances = len(masks_np)
masks_np = SegmentationMapsOnImage(self.condense_masks(masks_np), shape=image.shape[:2])
# boxs_np = BoundingBoxesOnImage(boxs_np, shape=image.shape[:2])
seed = int(datetime.now().strftime('%M%S%f')[-8:])
imgaug.seed(seed)
aug_image, aug_masks = det_augmenter(image=self.basic_augmenter(image=image) , segmentation_maps=masks_np)
imgaug.seed(seed)
invalid_pts_mask = det_augmenter(image=np.ones(image.shape[:2] + (1,), np.uint8)).squeeze(2)
aug_masks = self.expand_masks(aug_masks.get_arr(), num_instances)
# aug_boxes = aug_boxes.remove_out_of_image().clip_out_of_image()
aug_masks = [mask for mask, is_bm in zip(aug_masks, is_binary_mask)]
return aug_image, aug_masks #, aug_boxes.to_xyxy_array()
else:
# np.bool was removed in NumPy 1.24; a single image takes a single SegmentationMapsOnImage
masks = SegmentationMapsOnImage(np.ones(image.shape[:2], np.bool_), shape=image.shape[:2])
aug_image, invalid_pts_mask = det_augmenter(image=image, segmentation_maps=masks)
return aug_image, invalid_pts_mask.get_arr() == 0
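# Sketch (numpy, hypothetical names): condense_masks / expand_masks let imgaug warp
# all instance masks as one segmentation map. Labels are 1..n and later instances
# overwrite earlier ones, so the round trip is exact only when instances do not
# overlap. The original uses an int8 label map (at most 127 instances); int32 here
# is a choice of this sketch.
import numpy as np


def sketch_condense(instance_masks):
    label_map = np.zeros(instance_masks[0].shape, dtype=np.int32)
    for instance_id, mask in enumerate(instance_masks, 1):
        label_map = np.where(mask, instance_id, label_map)
    return label_map


def sketch_expand(label_map, num_instances):
    return [(label_map == i).astype(np.uint8) for i in range(1, num_instances + 1)]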
================================================
FILE: data_schedule/vis/vis_aug_utils.py
================================================
from detectron2.utils.registry import Registry
import torch
import numpy as np
import torchvision.transforms.functional as F
from PIL import Image
from einops import rearrange, reduce, repeat
VIS_EVAL_AUG_REGISTRY = Registry('VIS_EVAL_AUG')
VIS_TRAIN_AUG_REGISTRY = Registry('VIS_TRAIN_AUG')
def get_size_with_aspect_ratio(image_size, size, max_size=None):
w, h = image_size
if max_size is not None:
min_original_size = float(min((w, h)))
max_original_size = float(max((w, h)))
if max_original_size / min_original_size * size > max_size:
size = int(round(max_size * min_original_size / max_original_size))
if (w <= h and w == size) or (h <= w and h == size):
return (h, w)
if w < h:
ow = size
oh = int(size * h / w)
else:
oh = size
ow = int(size * w / h)
return (oh, ow)
def get_tgt_size(image_size, size, max_size=None):
if isinstance(size, (list, tuple)):
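return size[::-1] # explicit sizes are given as (w, h), e.g. a PIL size; return (h, w) like get_size_with_aspect_ratio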
else:
return get_size_with_aspect_ratio(image_size, size, max_size)
def pil_torch_to_numpy(video, masks, has_ann, float_image=True):
# n t' h w
# list[pil_image, rgb], t
# t
N, T = masks.shape[:2]
has_ann_idx = torch.where(has_ann)[0] # time_idx
# list[Image], t -> list[h w 3, 255uint8], t
masks = masks.permute(1, 0, 2, 3).contiguous().unbind(0) # list[n h w] t'
numpy_masks = [[] for _ in range(len(has_ann))] # list[list[h w, 01_uint8], n], t; one independent list per frame
assert len(has_ann_idx) == len(masks)
for fmask, taylor in zip(masks, has_ann_idx): # n h w
fnumpy_masks = []
for mk in fmask.unbind(0): # h w
fnumpy_masks.append(mk.numpy().astype(np.uint8))
numpy_masks[taylor] = fnumpy_masks
if float_image:
# list[h w 3, 0-1float], t
video = [F.to_tensor(frame).permute(1,2,0).numpy() for frame in video]
else:
# uint8
video = [np.array(frame) for frame in video]
return video, numpy_masks
def numpy_to_pil_torch(video, masks, has_ann):
# numpy, numpy -> torch, torch
# list[h w 3, 0-1float], t
T = int(has_ann.int().sum()) # number of annotated frames, as a Python int for rearrange
video = [Image.fromarray(np.uint8(aug_vid * 255), mode="RGB") for aug_vid in video]
# t'n h w
torch_masks = torch.stack([torch.from_numpy(obj_mk).bool() for frame_mk in masks for obj_mk in frame_mk], dim=0)
torch_masks = rearrange(torch_masks, '(T N) h w -> N T h w', T=T) # n t' h w
return video, torch_masks
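# Sketch (numpy stand-in for the torch code above, hypothetical names): the two
# helpers convert between the clip-level mask array of shape (n, t', h, w), which
# only stores the t' annotated frames, and per-frame lists of n (h, w) uint8 masks
# in which unannotated frames hold an empty list. Each frame gets its own list
# object rather than a shared [[]] * T placeholder.
import numpy as np


def sketch_clip_to_frame_lists(masks, has_ann):
    ann_idx = np.flatnonzero(has_ann)
    assert len(ann_idx) == masks.shape[1]
    per_frame = [[] for _ in range(len(has_ann))]
    for j, t in enumerate(ann_idx):
        per_frame[t] = [masks[i, j].astype(np.uint8) for i in range(masks.shape[0])]
    return per_frame


def sketch_frame_lists_to_clip(per_frame):
    annotated = [np.stack(frame, axis=0) for frame in per_frame if len(frame) > 0]  # each (n, h, w)
    return np.stack(annotated, axis=1).astype(bool)  # (n, t', h, w)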
================================================
FILE: data_schedule/vis/vis_frame_sampler.py
================================================
from detectron2.utils.registry import Registry
import random
import numpy as np
import torch
import logging
from detectron2.utils import comm
VIS_FRAMES_SAMPLER_REGISTRY = Registry('VIS_FRAMES_SAMPLER')
@VIS_FRAMES_SAMPLER_REGISTRY.register()
class Naive_ReferenceFrame_FrameSampler:
def __init__(self, sampler_configs, dataset_meta, **kwargs):
self.reference_frame_step_size = dataset_meta.get('step_size')
self.clip_sizes = list(sampler_configs['clip_sizes']) # list[int]
self.clip_distribute = sampler_configs['clip_distribute'] # dense, sparse, local_global
self.clip_position = sampler_configs['clip_position'] # former, center, latter
if max(self.clip_sizes) > self.reference_frame_step_size:
if comm.is_main_process():
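logging.warning(f'max clip size {max(self.clip_sizes)} exceeds the reference frame step size {self.reference_frame_step_size}')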
def __call__(self,
frame_idx=None,
all_frames=None, # list[str]
**kwargs):
random_clip_size = random.choice(self.clip_sizes)
video_len = len(all_frames)
sample_indx = [frame_idx]
if (self.clip_position == 'center') and (self.clip_distribute == 'local_global'):
if random_clip_size != 1:
sample_id_before = random.randint(1, 3)
sample_id_after = random.randint(1, 3)
local_indx = [max(0, frame_idx - sample_id_before), min(video_len - 1, frame_idx + sample_id_after)]
sample_indx.extend(local_indx)
if random_clip_size > 3:
all_inds = list(range(video_len))
global_inds = all_inds[:min(sample_indx)] + all_inds[max(sample_indx) + 1:] # exclude the already-sampled local window on both sides
global_n = random_clip_size - len(sample_indx)
if len(global_inds) > global_n:
select_id = random.sample(range(len(global_inds)), global_n)
for s_id in select_id:
sample_indx.append(global_inds[s_id])
elif video_len >= global_n:
select_id = random.sample(range(video_len), global_n)
for s_id in select_id:
sample_indx.append(all_inds[s_id])
else:
select_id = random.sample(range(video_len), global_n - video_len) + list(range(video_len))
for s_id in select_id:
sample_indx.append(all_inds[s_id])
elif (self.clip_position == 'center') and (self.clip_distribute == 'dense'):
half_size = (random_clip_size - 1) // 2
sample_indx += list(range(frame_idx - half_size, frame_idx))
sample_indx += list(range(frame_idx+1, half_size + frame_idx + 1))
if len(sample_indx) < random_clip_size:
sample_indx = [min(sample_indx)] + sample_indx
assert len(sample_indx) == random_clip_size
sample_indx = torch.tensor(sample_indx)
sample_indx = sample_indx.clamp_(min=0, max=video_len-1)
sample_indx = sample_indx.tolist()
else:
raise ValueError(f'unsupported clip_position/clip_distribute: {self.clip_position}/{self.clip_distribute}')
sample_indx.sort()
sampled_frames = [all_frames[idx] for idx in sample_indx]
return sampled_frames
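# Sketch (pure Python, hypothetical name): in the center/dense branch above the
# clip is the reference frame plus (clip_size - 1) // 2 neighbours on each side.
# For even clip sizes the leftmost index is duplicated, and near the video ends
# clamping repeats boundary frames rather than shifting the window.
def dense_center_indices(frame_idx, clip_size, video_len):
    half = (clip_size - 1) // 2
    indices = [frame_idx]
    indices += list(range(frame_idx - half, frame_idx))
    indices += list(range(frame_idx + 1, frame_idx + half + 1))
    if len(indices) < clip_size:  # even clip size: repeat the leftmost index
        indices = [min(indices)] + indices
    assert len(indices) == clip_size
    return sorted(min(max(i, 0), video_len - 1) for i in indices)

# dense_center_indices(5, 4, 10) -> [4, 4, 5, 6]; dense_center_indices(0, 5, 10) -> [0, 0, 0, 1, 2]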
================================================
FILE: handle_vps.py
================================================
import cv2
import numpy as np
import os
import shutil
from PIL import Image
import torch
from tqdm import tqdm
dataset_root = os.getenv('DATASET_PATH')
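# IVPS-TrainSet is the union of Kvasir-SEG (the 1000 files whose names start with 'K') and per-frame Mayo/CVC images;
# copy the Kvasir images and masks into a single pseudo-video folder Kvasir-SEG/{Frame,GT}/1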
all_images = os.listdir(f'{dataset_root}/MICCAI-VPS-dataset/IVPS-TrainSet/Frame')
ka_images = [b for b in all_images if b.startswith('K')]
assert len(ka_images) == 1000
all_gts = os.listdir(f'{dataset_root}/MICCAI-VPS-dataset/IVPS-TrainSet/GT')
ka_gts = [b for b in all_gts if b.startswith('K')]
assert len(ka_gts) == 1000
os.makedirs(os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/Kvasir-SEG/Frame/1'),exist_ok=True)
os.makedirs(os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/Kvasir-SEG/GT/1'),exist_ok=True)
for image_id in tqdm(ka_images):
shutil.copy(os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/IVPS-TrainSet/Frame', f'{image_id}'),
os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/Kvasir-SEG/Frame/1', f'{image_id}'),)
for image_id in tqdm(ka_gts):
shutil.copy(os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/IVPS-TrainSet/GT', f'{image_id}'),
os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/Kvasir-SEG/GT/1', f'{image_id}'),)
# reorganize each train set from <video>/{Frame,GT} into Frame/<video> and GT/<video>
for base_path in [f'{dataset_root}/MICCAI-VPS-dataset/VPS-TrainSet/CVC-ColonDB-300/Train',
f'{dataset_root}/MICCAI-VPS-dataset/VPS-TrainSet/ASU-Mayo_Clinic/Train',
f'{dataset_root}/MICCAI-VPS-dataset/VPS-TrainSet/CVC-ClinicDB-612/Train']:
video_ids = os.listdir(base_path)
frame_path = os.path.join(base_path, 'Frame')
gt_path = os.path.join(base_path, 'GT')
os.makedirs(frame_path, exist_ok=True)
os.makedirs(gt_path, exist_ok=True)
# Iterate through each video ID directory
for vid in video_ids:
shutil.copytree(os.path.join(base_path, vid, 'Frame'), os.path.join(frame_path, vid))
shutil.copytree(os.path.join(base_path, vid, 'GT'), os.path.join(gt_path, vid))
# WARNING: the loop below permanently deletes every training frame whose GT mask has no
# foreground pixel (together with that mask); remove this block if you want to keep them
SET_NAME = [
'Kvasir-train',
'Mayo-train',
'300-train',
'612-train',
]
SET_NAME_TO_DIR = {
'Kvasir-train': 'MICCAI-VPS-dataset/Kvasir-SEG',
'Mayo-train': 'MICCAI-VPS-dataset/VPS-TrainSet/ASU-Mayo_Clinic/Train',
'300-train': 'MICCAI-VPS-dataset/VPS-TrainSet/CVC-ColonDB-300/Train',
'612-train': 'MICCAI-VPS-dataset/VPS-TrainSet/CVC-ClinicDB-612/Train',
}
SET_NAME_TO_NUM_VIDEOS = {
'Kvasir-train': 1,
'Mayo-train': 10,
'300-train': 6,
'612-train': 18,
'300-tv': 6,
'612-test': 5,
'612-val': 5
}
SET_NAME_TO_PREFIX = {
'Kvasir-train': 'Kvasir-train',
'Mayo-train': 'Mayo-train',
'300-train': '300-train',
'612-train': '612-train',
}
root = os.getenv('DATASET_PATH')
def get_frames_mask(mask_path, video_id, frames):
# masks = [cv2.imread(os.path.join(mask_path, video_id, f'{f}.jpg')) for f in frames]
if os.path.exists(os.path.join(mask_path, video_id, f'{frames[0]}.png')):
masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.png')).convert('L') for f in frames]
elif os.path.exists(os.path.join(mask_path, video_id, f'{frames[0]}.jpg')):
masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.jpg')).convert('L') for f in frames]
else:
raise ValueError(f'no .png or .jpg mask for frame {frames[0]} in {os.path.join(mask_path, video_id)}')
masks = [np.array(mk) for mk in masks]
masks = torch.stack([torch.from_numpy(mk) for mk in masks], dim=0) # t h w
# assert set(masks.unique().tolist()) == set([0, 255]), f'{masks.unique().tolist()}'
masks = (masks > 0).int()
return masks, torch.ones(len(frames)).bool()
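# Sketch (hypothetical helper, not called by this script): the loop below deletes
# files irreversibly, so it can help to list the candidates first. Given masks
# already loaded per frame (for example with get_frames_mask), this returns the
# frame names whose mask, binarized with > 0 as in get_frames_mask, has no
# foreground pixel. Loading from disk is deliberately left out.
import numpy as np


def frames_without_foreground(frame_masks):
    """frame_masks: {frame_name: (h, w) grayscale array} -> sorted names with empty masks."""
    return sorted(name for name, mask in frame_masks.items() if not (np.asarray(mask) > 0).any())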
num_delted_frames = 0
for train_set_name in SET_NAME:
set_dir = SET_NAME_TO_DIR[train_set_name]
frames_dir = os.path.join(root, set_dir, 'Frame')
mask_dir = os.path.join(root, set_dir, 'GT')
video_ids = os.listdir(frames_dir)
for vid in tqdm(video_ids):
frames = [haosen[:-4] for haosen in os.listdir(os.path.join(frames_dir, vid))]
frame_has_fore = [get_frames_mask(mask_dir, vid, [haosen])[0].any() for haosen in tqdm(frames)] # list[t]
assert len(frame_has_fore) == len(frames)
num_delted_frames += (~ torch.tensor(frame_has_fore)).int().sum()
for haosen, frame_name in tqdm(zip(frame_has_fore, frames)):
if not haosen:
os.remove(os.path.join(frames_dir, vid, f'{frame_name}.jpg'))
if os.path.exists(os.path.join(mask_dir, vid, f'{frame_name}.jpg')):
os.remove(os.path.join(mask_dir, vid, f'{frame_name}.jpg'))
elif os.path.exists(os.path.join(mask_dir, vid, f'{frame_name}.png')):
os.remove(os.path.join(mask_dir, vid, f'{frame_name}.png'))
else:
raise ValueError(f'no .jpg or .png mask for frame {frame_name} of video {vid}')
print(f'deleted {int(num_delted_frames)} frames (expected 1546).')
================================================
FILE: main.py
================================================
import os
import argparse
import logging
import importlib
from trainers import task_to_trainer
import detectron2.utils.comm as comm
from termcolor import colored
import yaml
import torch
from utils.misc import setup_for_distributed
def _highlight(code, filename):
try:
import pygments
except ImportError:
return code
from pygments.lexers import Python3Lexer, YamlLexer
from pygments.formatters import Terminal256Formatter
lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
return code
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
self._root_name = kwargs.pop("root_name") + "."
self._abbrev_name = kwargs.pop("abbrev_name", "")
if len(self._abbrev_name):
self._abbrev_name = self._abbrev_name + "."
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
record.name = record.name.replace(self._root_name, self._abbrev_name)
message = record.message
# message, asctime, name, filename = record.message, record.asctime, record.name, record.filename
log = super(_ColorfulFormatter, self).formatMessage(record)
if (record.levelno == logging.WARNING) or (record.levelno == logging.ERROR) or (record.levelno == logging.CRITICAL):
colored_message = colored(message, "red", attrs=["blink", "underline"])
elif record.levelno == logging.DEBUG:
colored_message = colored(message, "yellow", attrs=["blink", "underline"])
else: # INFO/NOTSET
colored_message = colored(message, "white")
return log + colored_message
def set_logging_file(output_dir, file_name, mode='a'):
handler1 = logging.StreamHandler()
handler2 = logging.FileHandler(os.path.join(output_dir, file_name), mode=mode)
formatter = _ColorfulFormatter(
colored("[%(asctime)s %(name)s %(filename)s]: ", "green"),
datefmt="%m/%d %H:%M:%S",
root_name=os.path.join(output_dir, file_name),
abbrev_name=str('grey'),
)
handler1.setFormatter(formatter)
handler2.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(handler1)
logger.addHandler(handler2)
logger.setLevel(logging.DEBUG)
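# Sketch (hypothetical, stdlib only): _ColorfulFormatter relies on termcolor and
# the same formatter is set on both handlers, so the log file can also receive
# escape codes. A minimal alternative, assuming an ANSI-capable terminal, colours
# whole lines only and would be attached to the StreamHandler, leaving the
# FileHandler with a plain logging.Formatter.
import logging

ANSI_RED, ANSI_YELLOW, ANSI_RESET = '\033[31m', '\033[33m', '\033[0m'


class MinimalColorFormatter(logging.Formatter):
    """Red for WARNING and above, yellow for DEBUG, unchanged otherwise."""

    def format(self, record):
        text = super().format(record)
        if record.levelno >= logging.WARNING:
            return f'{ANSI_RED}{text}{ANSI_RESET}'
        if record.levelno == logging.DEBUG:
            return f'{ANSI_YELLOW}{text}{ANSI_RESET}'
        return text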
def init_process_group_and_set_device(world_size, process_id, device_id):
"""
This function needs to be called on each spawned process to initiate learning using DistributedDataParallel.
The function initiates the process' process group and assigns it a single GPU to use during training.
"""
torch.cuda.set_device(device_id)
device = torch.device(f'cuda:{device_id}')
if world_size > 1:
torch.distributed.init_process_group(
torch.distributed.Backend.NCCL,
world_size=world_size,
rank=process_id
)
comm.create_local_process_group(world_size)
torch.distributed.barrier(device_ids=[device_id])
setup_for_distributed(process_id == 0)
return device
def run(rank, configs, world_size):
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['PYDEVD_WARN_SLOW_RESOLVE_TIMEOUT'] = "4"
os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = "1"
os.environ["DGLBACKEND"] = "pytorch"
logging.getLogger('penman').setLevel(logging.WARNING)
logging.getLogger('PIL').setLevel(logging.WARNING)
logging.getLogger('PIL.PngImagePlugin').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('urllib3').setLevel(logging.WARNING)
logging.getLogger('h5py').setLevel(logging.WARNING)
init_process_group_and_set_device(world_size, process_id=rank, device_id=rank)
if comm.is_main_process():
mode = configs['trainer_mode']
out_dir = configs['out_dir']
if mode == 'eval':
num_of_eval_times = len([eval_txt for eval_txt in os.listdir(out_dir) if eval_txt.endswith('eval.txt')])
set_logging_file(out_dir, f"eval.txt", mode='a')
path = os.path.join(out_dir, f"config_eval.yaml")
else:
num_of_train_times = len([train_txt for train_txt in os.listdir(out_dir) if train_txt.endswith('train.txt')])
if 'resume' in mode:
set_logging_file(out_dir, f"train.txt", mode='a')
else:
set_logging_file(out_dir, f"train.txt", mode='w')
path = os.path.join(out_dir, f"config_train.yaml")
logging.debug("Running with full config:\n{}".format(_highlight(yaml.dump(configs, default_flow_style=False), ".yaml")))
with open(path, "w") as f:
f.write(yaml.dump(configs, default_flow_style=False))
logging.debug("Full config saved to {}".format(path))
comm.synchronize()
trainer = task_to_trainer[configs['task']](configs=configs)
comm.synchronize()
if configs['trainer_mode'] == 'eval':
eval_ckpts = configs['eval_ckpts']
for lunch in eval_ckpts:
trainer.load_ckpt(lunch, load_model=True, load_schedule=True, load_random=False, load_optimize=False)
trainer.evaluate()
else:
if configs['trainer_mode'] == 'train_resume':
ckpt_dirs = os.listdir(configs['out_dir'])
ckpt_dirs = sorted([a for a in ckpt_dirs if a.startswith('epc')], key=lambda x:int(x.split('sap[')[-1][:-1]))
trainer_ckpt = '/'.join([configs['out_dir'], ckpt_dirs[-1], 'ckpt.pth.tar'])
trainer.load_ckpt(trainer_ckpt, load_model=True, load_schedule=True, load_random=True, load_optimize=True)
trainer.train()
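# Sketch (hypothetical helper): run() and the __main__ block pick checkpoints by
# parsing the number inside "sap[...]" from directory names such as
# "epc[1]_iter[5000]_sap[60009]" with x.split('sap[')[-1][:-1], which raises on
# an "epc..." name without that suffix. A regex-based variant skips such names.
import re

_SAP_PATTERN = re.compile(r'sap\[(\d+)\]')


def sort_ckpt_dirs_by_sap(dir_names):
    keyed = []
    for name in dir_names:
        match = _SAP_PATTERN.search(name)
        if name.startswith('epc') and match:
            keyed.append((int(match.group(1)), name))
    return [name for _, name in sorted(keyed)]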
if __name__=="__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--config_file', type=str, required=True)
parser.add_argument('--trainer_mode', type=str, default='train_attmpt')
parser.add_argument('--eval_path', type=str, default='')
args = parser.parse_args()
task, group, config, config2 = args.config_file.split('/')[-4:]
assert config == config2[:-3]
config_file = '.'.join(['output', task, group, config, config])
configs = importlib.import_module(config_file).trainer_configs
configs['task'], configs['group'], configs['config'] = task, group, config
configs['out_dir'] = os.path.join('./', 'output', task, group, config)
configs['trainer_mode'] = args.trainer_mode
if configs['trainer_mode'] == 'eval':
eval_ckpts = []
eval_path = args.eval_path
assert eval_path != '', '--eval_path is required when --trainer_mode is eval'
if os.path.isfile(eval_path):
eval_ckpts.append(eval_path)
elif os.path.isdir(eval_path):
ckpt_dirs = os.listdir(eval_path)
ckpt_dirs = [taylor for taylor in ckpt_dirs if os.path.isdir(os.path.join(eval_path, taylor))]
# epc[1]_iter[5000]_sap[60009]
ckpt_dirs = sorted([billie for billie in ckpt_dirs if billie.startswith('epc')], key=lambda x:int(x.split('sap[')[-1][:-1]))
eval_ckpts = [os.path.join(eval_path, cd, f'ckpt.pth.tar') for cd in ckpt_dirs]
eval_ckpts = [eval_c for eval_c in eval_ckpts if os.path.exists(eval_c)]
else:
raise ValueError(f'eval_path {eval_path} is neither a file nor a directory')
configs['eval_ckpts'] = eval_ckpts
else:
pass
gpu_ids = list(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
assert len(set(gpu_ids)) == len(gpu_ids)
gpu_ids = list(range(len(gpu_ids)))
if len(gpu_ids) > 1:
torch.multiprocessing.spawn(run, nprocs=len(gpu_ids), args=(configs, len(gpu_ids)))
elif len(gpu_ids) == 1:
run(rank=0, configs=configs, world_size=len(gpu_ids))
================================================
FILE: models/VIS/BackboneEncoderDecoder_WithScaleConsistency.py
================================================
import matplotlib.pyplot as plt
from typing import Any, Optional, List, Dict, Set, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from data_schedule import build_schedule
from torch import Tensor
from einops import repeat, rearrange, reduce
from functools import partial
from einops.layers.torch import Rearrange
from torch import einsum
import numpy as np
import logging
from data_schedule.vis.apis import VIS_TrainAPI_clipped_video, VIS_Aug_CallbackAPI
from data_schedule.vis.apis import VIS_EvalAPI_clipped_video_request_ann
import torchvision.transforms.functional as Trans_F
import copy
from models.registry import register_model
from models.optimization.optimizer import get_optimizer
from models.optimization.scheduler import build_scheduler
from models.backbone.utils import VideoMultiscale_Shape
from detectron2.modeling import BACKBONE_REGISTRY, META_ARCH_REGISTRY
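# Sketch (numpy stand-in, hypothetical helper names): the model below moves clips
# between layouts in three places. forward() normalizes with ImageNet mean/std,
# flattens (b, T, 3, h, w) to (b*T, 3, h, w), resizes to a square size drawn from
# [256, ..., 448] and rearranges to (b, 3, T, s, s). At test time model_preds()
# runs the network on windows of test_clip_size frames and concatenates the
# outputs along time. sample() undoes the normalization with two Trans_F.normalize
# calls, i.e. x * std + mean. Nearest-neighbour resizing stands in for bilinear.
import numpy as np

SKETCH_PIXEL_MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
SKETCH_PIXEL_STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def sketch_normalize(frames):
    """(..., 3, h, w) in [0, 1] -> per-channel standardized values."""
    return (frames - SKETCH_PIXEL_MEAN) / SKETCH_PIXEL_STD


def sketch_denormalize(frames):
    """Inverse of sketch_normalize: x * std + mean."""
    return frames * SKETCH_PIXEL_STD + SKETCH_PIXEL_MEAN


def sketch_resize_clip(videos, size):
    """(b, T, 3, h, w) -> (b, 3, T, size, size) via nearest-neighbour sampling."""
    b, t, c, h, w = videos.shape
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    frames = videos.reshape(b * t, c, h, w)[:, :, rows[:, None], cols[None, :]]
    return frames.reshape(b, t, c, size, size).transpose(0, 2, 1, 3, 4)


def sketch_run_in_time_chunks(video, chunk_size, fn, in_axis=2, out_axis=1):
    """Call fn on windows of at most chunk_size frames along in_axis and
    concatenate the results along out_axis (fn must return one slice per frame)."""
    num_frames = video.shape[in_axis]
    outputs = []
    for start in range(0, num_frames, chunk_size):
        window = np.arange(start, min(start + chunk_size, num_frames))
        outputs.append(fn(np.take(video, window, axis=in_axis)))
    return np.concatenate(outputs, axis=out_axis)

# Chunking bounds memory, but a temporal encoder then only sees frames inside its
# own window, so chunked outputs can differ from a single full-clip pass.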
class BackboneEncoderDecoder_WithScaleConsistency(nn.Module):
def __init__(
self,
configs,
pixel_mean = [0.485, 0.456, 0.406],
pixel_std = [0.229, 0.224, 0.225],):
super().__init__()
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) # 3 1 1
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
self.loss_weight = configs['model']['loss_weight']
video_backbone_configs = configs['model']['video_backbone']
video_backbone_cls = BACKBONE_REGISTRY.get(video_backbone_configs['name'])
self.video_backbone = video_backbone_cls(video_backbone_configs)
self.max_stride = self.video_backbone.max_stride
self.fusion_encoder = META_ARCH_REGISTRY.get(configs['model']['fusion']['name'])(configs['model']['fusion'],
multiscale_shapes=self.video_backbone.multiscale_shapes)
same_dim_multiscale_shapes = VideoMultiscale_Shape.set_multiscale_same_dim(shape_by_dim=self.video_backbone.multiscale_shapes,
same_dim=configs['model']['fusion']['d_model'])
self.decoder = META_ARCH_REGISTRY.get(configs['model']['decoder']['name'])(configs['model']['decoder'],
multiscale_shapes=same_dim_multiscale_shapes)
if configs['model']['fusion']['name'] == 'Video_Deform2D_DividedTemporal_MultiscaleEncoder_v2':
self.fusion_encoder.hack_ref(query_norm=self.decoder.temporal_query_norm, mask_mlp=self.decoder.query_mask)
self.test_clip_size = configs['model']['test_clip_size']
@property
def device(self):
return self.pixel_mean.device
def model_preds(self, videos, video_aux_dict,):
if (not self.training) and (self.test_clip_size is not None):
nf = videos.shape[2]
clip_outputs = [] # list[dict]
for start_idx in range(0, nf, self.test_clip_size):
multiscales = self.video_backbone(x=videos[:, :, start_idx:(start_idx + self.test_clip_size)]) # b c t h w
multiscales = self.fusion_encoder(multiscales, video_aux_dict=video_aux_dict)
clip_outputs.append(self.decoder(multiscales, video_aux_dict=video_aux_dict)[-1]) # b t nq h w
return [{
'pred_masks': torch.cat([haosen['pred_masks'] for haosen in clip_outputs], dim=1), # b t n h w
'pred_class': torch.cat([haosen['pred_class'] for haosen in clip_outputs], dim=1),
}]
# b 3 t h w -> b 3 t h w
multiscales = self.video_backbone(x=videos) # b c t h w
multiscales = self.fusion_encoder(multiscales, video_aux_dict=video_aux_dict)
return self.decoder(multiscales, video_aux_dict=video_aux_dict)
def forward(self, batch_dict):
assert self.training
VIS_TrainAPI_clipped_video
videos = batch_dict['video_dict']['videos']
targets = batch_dict['targets']
batch_size, nf = videos.shape[:2]
videos = (videos - self.pixel_mean) / self.pixel_std
size1 = np.random.choice([256, 288, 320, 352, 384, 416, 448])
vid_1 = F.interpolate(videos.flatten(0, 1), size=size1, mode='bilinear')
vid_1 = rearrange(vid_1, '(b T) c h w -> b c T h w',b=batch_size, T=nf)
pred1 = self.model_preds(vid_1, video_aux_dict=batch_dict['video_dict']) # {pred_masks: b 1 t h w}
pred1_loss = self.decoder.compute_loss(pred1, targets=targets, frame_targets=batch_dict['frame_targets'],
video_aux_dict=batch_dict['video_dict'])
loss_value_dict = {key: pred1_loss[key] for key in list(self.loss_weight.keys())}
return loss_value_dict, self.loss_weight
@torch.no_grad()
def sample(self, batch_dict):
assert not self.training
VIS_EvalAPI_clipped_video_request_ann
videos = batch_dict['video_dict']['videos'] # b t 3 h w, 0-1
orig_t, _, orig_h, orig_w = batch_dict['video_dict']['orig_sizes'][0]
videos = (videos - self.pixel_mean) / self.pixel_std
assert videos.shape[0] == 1
batch_size, T, _, H, W = videos.shape
videos = videos.permute(0, 2, 1,3,4) # b c t h w
decoder_output = self.model_preds(videos, video_aux_dict=batch_dict['video_dict']) # {pred_masks: b 1 t h w}
if isinstance(decoder_output, list):
decoder_output = decoder_output[-1]
pred_masks = decoder_output['pred_masks'][0] # T n h w
pred_masks = F.interpolate(pred_masks, size=(H, W), mode='bilinear') > 0 # T n h w
pred_masks = pred_masks[:orig_t, :, :orig_h, :orig_w] # T n h w
#
pred_classes = decoder_output['pred_class'][0][:orig_t, :,:] # T n c, probability
pred_classes = pred_classes.cpu().unbind(0) # list[n c], T
pred_masks = pred_masks.cpu().unbind(0) # list[n h w], T
VIS_Aug_CallbackAPI
orig_video = videos[0][:, :orig_t, :orig_h, :orig_w].permute(1,0,2,3) # T 3 h w
orig_video = Trans_F.normalize(orig_video, [0, 0, 0], 1 / self.pixel_std)
orig_video = Trans_F.normalize(orig_video, -self.pixel_mean, [1, 1, 1]).cpu()
return {
'video': [orig_video], # [t 3 h w], 1
'pred_masks': [pred_masks], # [list[n h w], t, bool], 1
'pred_class': [pred_classes], # [list[n c], t, probability], 1
}

    @staticmethod
    def get_optim_params_group(model, configs):
        weight_decay_norm = configs['optim']['weight_decay_norm']
        weight_decay_embed = configs['optim']['weight_decay_embed']
        defaults = {}
        defaults['lr'] = configs['optim']['base_lr']
        defaults['weight_decay'] = configs['optim']['weight_decay']
        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            # NaiveSyncBatchNorm inherits from BatchNorm2d
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
        )
        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        # index of the first param group of each kind, used for lr logging
        log_lr_group_idx = {'backbone': None, 'base': None}
        for module_name, module in model.named_modules():
            for module_param_name, value in module.named_parameters(recurse=False):
                if not value.requires_grad:
                    continue
                # Avoid duplicating parameters
                if value in memo:
                    continue
                memo.add(value)
                hyperparams = copy.copy(defaults)
                if "video_backbone" in module_name:
                    hyperparams["lr"] = hyperparams["lr"] * configs['optim']['backbone_lr_multiplier']
                    if log_lr_group_idx['backbone'] is None:
                        log_lr_group_idx['backbone'] = len(params)
                else:
                    if log_lr_group_idx['base'] is None:
                        log_lr_group_idx['base'] = len(params)
                # positional embeddings, norm layers and embeddings get special weight decay
                if (
                    "relative_position_bias_table" in module_param_name
                    or "absolute_pos_embed" in module_param_name
                ):
                    logging.debug(f'setting weight decay of {module_name}.{module_param_name} to zero')
                    hyperparams["weight_decay"] = 0.0
                if isinstance(module, norm_module_types):
                    hyperparams["weight_decay"] = weight_decay_norm
                if isinstance(module, torch.nn.Embedding):
                    hyperparams["weight_decay"] = weight_decay_embed
                params.append({"params": [value], **hyperparams})
        return params, log_lr_group_idx

@register_model
def backbone_encoder_decoder_withScaleConsistency(configs, device):
    from .aux_mapper import AUXMapper_v1
    model = BackboneEncoderDecoder_WithScaleConsistency(configs)
    model.to(device)
    params_group, log_lr_group_idx = BackboneEncoderDecoder_WithScaleConsistency.get_optim_params_group(model=model, configs=configs)
    to_train_num_parameters = len([n for n, p in model.named_parameters() if p.requires_grad])
    # every trainable parameter must end up in exactly one param group
    assert len(params_group) == to_train_num_parameters, \
        f'{len(params_group)} param groups vs {to_train_num_parameters} trainable parameters'
    optimizer = get_optimizer(params_group, configs)
    scheduler = build_scheduler(configs=configs, optimizer=optimizer)
    model_input_mapper = AUXMapper_v1(configs['model']['input_aux'])
    train_samplers, train_loaders, eval_function = build_schedule(configs,
                                                                  model_input_mapper.mapper,
                                                                  partial(model_input_mapper.collate, max_stride=model.max_stride))
    return model, optimizer, scheduler, train_samplers, train_loaders, log_lr_group_idx, eval_function
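

# Illustrative sketch (hypothetical helper, not used by the training pipeline):
# the per-parameter rules applied in `get_optim_params_group`, restated in plain
# Python. Later rules override earlier ones, so a norm or embedding module's
# weight decay wins over the default and over the positional-embedding rule.
# The argument names (e.g. `module_kind`) are assumptions made for illustration.
def _sketch_param_hyperparams(module_name, param_name, module_kind,
                              base_lr, weight_decay, backbone_lr_multiplier,
                              weight_decay_norm, weight_decay_embed):
    hp = {'lr': base_lr, 'weight_decay': weight_decay}
    if 'video_backbone' in module_name:
        # backbone parameters are trained with a scaled learning rate
        hp['lr'] = base_lr * backbone_lr_multiplier
    if ('relative_position_bias_table' in param_name
            or 'absolute_pos_embed' in param_name):
        hp['weight_decay'] = 0.0
    if module_kind == 'norm':
        hp['weight_decay'] = weight_decay_norm
    if module_kind == 'embedding':
        hp['weight_decay'] = weight_decay_embed
    return hp

# Example: a backbone LayerNorm weight gets lr 1e-4 * 0.1 and the norm weight decay.
# _sketch_param_hyperparams('video_backbone.norm1', 'weight', 'norm', 1e-4, 0.05, 0.1, 0.0, 0.0)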
================================================
FILE: models/VIS/__init__.py
================================================
from . import BackboneEncoderDecoder_WithScaleConsistency
from .. import modality_input_mappers
from .. import backbone
from .. import decoder
from .. import encoder
================================================
FILE: models/VIS/aux_mapper.py
================================================
import torch
from torch.nn import functional as F
from models.registry import register_model
from data_schedule.utils.box_ops import box_xyxy_to_cxcywh
from models.registry import MODELITY_INPUT_MAPPER_REGISTRY
from data_schedule.vis.apis import VIS_TrainAPI_clipped_video
from data_schedule.vis.apis import VIS_EvalAPI_clipped_video_request_ann
from utils.misc import nested_tensor_from_videos_list_with_stride

class AUXMapper_v1:
    def __init__(self, aux_configs):
        video_auxes = aux_configs['video_auxes']
        video_auxes_names = [config['name'] for config in video_auxes]
        assert len(list(set(video_auxes_names))) == len(video_auxes_names), 'each aux must have a unique name'
        self.video_auxes_names = video_auxes_names
        self.video_auxes = [MODELITY_INPUT_MAPPER_REGISTRY.get(config['name'])(config) for config in video_auxes]
        self.targets_auxes = None

    def mapper(self, data_dict, mode,):
        if mode == 'train':
            VIS_TrainAPI_clipped_video  # no-op reference to the API spec
            video = data_dict['video_dict']['video']
            for aux, aux_name in zip(self.video_auxes, self.video_auxes_names):
                data_dict['video_dict'][aux_name] = aux.mapper(video)
        elif mode == 'evaluate':
            VIS_EvalAPI_clipped_video_request_ann  # no-op reference to the API spec
            video = data_dict['video_dict']['video']
            for aux, aux_name in zip(self.video_auxes, self.video_auxes_names):
                data_dict['video_dict'][aux_name] = aux.mapper(video)
        else:
            raise ValueError(f'unknown mode: {mode}')
        return data_dict

    def collate(self, batch_dict, mode, max_stride):
        if mode == 'train':
            VIS_TrainAPI_clipped_video  # no-op reference to the API spec
            video_dict = self.collate_video_dict(batch_dict, max_stride=max_stride)
            targets = [sample['targets'] for sample in batch_dict]
            frame_has_ann = [clip_tgt['has_ann'] for clip_tgt in targets]  # list[t], b
            frame_targets = [sample['frame_targets'] for sample in batch_dict]
            _, pad_T, _, pad_H, pad_W = video_dict['videos'].shape
            targets = self.collate_targets(targets=targets, pad_H=pad_H, pad_W=pad_W, pad_T=pad_T)
            frame_targets = self.collate_frame_targets(frame_targets=frame_targets,
                                                       frame_has_ann=frame_has_ann,
                                                       pad_H=pad_H, pad_W=pad_W, pad_T=pad_T)
            ret = {
                'video_dict': video_dict,
                'targets': targets,
                'frame_targets': frame_targets,
                'meta_idxs': [sample['meta_idx'] for sample in batch_dict],
                'visualize': [sample['visualize'] for sample in batch_dict],
            }
        elif mode == 'evaluate':
            VIS_EvalAPI_clipped_video_request_ann  # no-op reference to the API spec
            assert len(batch_dict) == 1
            video_dict = self.collate_video_dict(batch_dict, max_stride=max_stride)  # single sample: only stride padding
            metas = [sample['meta'] for sample in batch_dict]
            collated_metas = {}
            for key in metas[0].keys():
                collated_metas[key] = [mt[key] for mt in metas]
            ret = {
                'video_dict': video_dict,
                'metas': collated_metas,
                'meta_idxs': [sample['meta_idx'] for sample in batch_dict],
                'visualize': [sample['visualize'] for sample in batch_dict],
            }
        else:
            raise ValueError(f'unknown mode: {mode}')
        debug_data = False
        if debug_data:
            self.visualize_input_target_for_debug_data(ret)  # ./test.png
        return ret

    def collate_video_dict(self, batch_dict, max_stride):
        videos = [sample['video_dict']['video'] for sample in batch_dict]  # list[ti 3 hi wi] -> b T 3 H W
        orig_sizes = [list(vid.shape) for vid in videos]  # t 3 h w
        if isinstance(max_stride, int):
            # int: temporal stride is 1, max_stride is the spatial stride
            pad_stride = [1, max_stride]
        elif isinstance(max_stride, list) and len(max_stride) == 2:
            # [temporal stride, spatial stride]
            pad_stride = max_stride
        else:
            raise ValueError(f'max_stride must be an int or a [temporal, spatial] list, got {max_stride}')
        videos = nested_tensor_from_videos_list_with_stride(videos, max_stride=pad_stride).tensors  # b t c h w
        video_dicts = {'videos': videos, 'orig_sizes': orig_sizes}
        for aux_name, aux in zip(self.video_auxes_names, self.video_auxes):
            auxes = [sample['video_dict'][aux_name] for sample in batch_dict]  # list[dict] / list[tensor]
            collated_auxes = aux.collate(auxes, batch_videos=videos)  # dict / tensor
            if isinstance(auxes[0], dict):
                keys = collated_auxes.keys()
                for key in keys:
                    assert key not in video_dicts
                    video_dicts[key] = collated_auxes[key]
            else:
                video_dicts[aux_name] = collated_auxes
        return video_dicts

    def collate_frame_targets(self, frame_targets, frame_has_ann, pad_H, pad_W, pad_T):
        VIS_TrainAPI_clipped_video  # no-op reference to the API spec
        ret = {}
        # pad each clip's per-frame annotation flags to pad_T, then flatten over the batch
        has_ann = torch.stack([F.pad(ha.float(), pad=(0, pad_T - len(ha)), value=0.).bool() for ha in frame_has_ann], dim=0).flatten()  # bT
        ret['has_ann'] = has_ann
        masks = [ftarget['masks'] for sample in frame_targets for ftarget in sample]  # list[ni h w], bt'
        masks = [F.pad(m.float(), pad=(0, pad_W - m.shape[-1], 0, pad_H - m.shape[-2])).bool() for m in masks]  # list[ni H W], bt'
        ret['masks'] = masks  # list[ni H W], bt'
        classes = [ftarget['classes'] for sample in frame_targets for ftarget in sample]  # list[ni], bt'
        ret['classes'] = classes
        if 'boxes' in frame_targets[0][0]:
            boxes = [ftarget['boxes'] for sample in frame_targets for ftarget in sample]  # list[ni 4], x1y1x2y2, bt'
            boxes = [box_xyxy_to_cxcywh(bx) for bx in boxes]
            boxes = [bx / torch.tensor([pad_W, pad_H, pad_W, pad_H], dtype=bx.dtype) for bx in boxes]  # normalized to 0-1
            ret['boxes'] = boxes  # list[ni 4], bt'
        return ret

    def collate_targets(self, targets, pad_H, pad_W, pad_T):
        VIS_TrainAPI_clipped_video  # no-op reference to the API spec
        has_ann = [sample['has_ann'] for sample in targets]  # list[t], bool
        has_ann = torch.stack([F.pad(ha.float(), pad=(0, pad_T - len(ha)), value=0.).bool() for ha in has_ann], dim=0)  # b T
        masks = [sample['masks'] for sample in targets]
        masks = [F.pad(m.float(), pad=(0, pad_W - m.shape[-1], 0, pad_H - m.shape[-2]), value=0.).bool()
                 for m in masks]  # list[ni T' H W]
        classes = [sample['classes'] for sample in targets]
        ret = {
            'masks': masks,  # list[ni T' H W]
            'has_ann': has_ann,  # b T
            'classes': classes,  # list[ni], b
        }
        if 'boxes' in targets[0]:
            boxes = [sample['boxes'] for sample in targets]  # list[ni T' 4], x1y1x2y2
            boxes = [box_xyxy_to_cxcywh(bx) for bx in boxes]
            boxes = [bx / torch.tensor([pad_W, pad_H, pad_W, pad_H], dtype=torch.float) for bx in boxes]  # normalized to 0-1
            ret.update({'boxes': boxes,})
        return ret

    def visualize_input_target_for_debug_data(self, ret):
        # placeholder for debugging visualizations; currently does nothing
        videos = ret['video_dict']['videos']  # b T 3 H W
        pass
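

# Illustrative sketch (hypothetical helpers, not part of AUXMapper_v1): the shape
# arithmetic behind `collate_video_dict` and the box normalization in
# `collate_targets`. It ASSUMES that `nested_tensor_from_videos_list_with_stride`
# pads every video to the batch maximum T/H/W rounded up to a multiple of the
# temporal/spatial stride; check utils/misc.py before relying on this.
import math


def _sketch_padded_size(sizes, pad_stride):
    """sizes: list of (t, h, w); pad_stride: [temporal, spatial]."""
    t_stride, s_stride = pad_stride
    max_t = max(s[0] for s in sizes)
    max_h = max(s[1] for s in sizes)
    max_w = max(s[2] for s in sizes)
    return (math.ceil(max_t / t_stride) * t_stride,
            math.ceil(max_h / s_stride) * s_stride,
            math.ceil(max_w / s_stride) * s_stride)


def _sketch_normalize_box(box_xyxy, pad_h, pad_w):
    """x1y1x2y2 in pixels -> cxcywh divided by the *padded* size (0-1)."""
    x1, y1, x2, y2 = box_xyxy
    cx, cy, w, h = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1
    return (cx / pad_w, cy / pad_h, w / pad_w, h / pad_h)

# Because the boxes are divided by the padded size rather than each clip's
# original size, the 0-1 coordinates are relative to the padded canvas.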
================================================
FILE: models/__init__.py
================================================
import os
from .registry import model_entrypoint
if os.getenv('CURRENT_TASK') == 'VIS':
    from . import VIS
else:
    raise ValueError(f"unsupported CURRENT_TASK: {os.getenv('CURRENT_TASK')}")
================================================
FILE: models/backbone/__init__.py
================================================
from . import res2net, pvtv2
================================================
FILE: models/backbone/pvtv2.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
from timm.models.registry import register_model

class DWConv(nn.Module):
    def __init__(self, dim):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)  # depth-wise 3x3 conv

    def forward(self, x, H, W):
        B, N, C = x.shape
        # tokens (B, H*W, C) -> feature map (B, C, H, W) -> tokens
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = F.gelu(self.dwconv(x, H, W))
        x = self.fc2(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads, sr_ratio):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** (-0.5)
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # spatial reduction: downsample keys/values by sr_ratio
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x
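

# Illustrative sketch (hypothetical helper): how `sr_ratio` shrinks the key/value
# sequence in `Attention` above. `self.sr` is a Conv2d with kernel_size =
# stride = sr_ratio and no padding, so each spatial side becomes
# floor((side - sr) / sr) + 1, while the queries keep all H * W tokens.
def _sketch_sr_attention_sizes(h, w, sr_ratio):
    n_query = h * w
    if sr_ratio > 1:
        n_key = ((h - sr_ratio) // sr_ratio + 1) * ((w - sr_ratio) // sr_ratio + 1)
    else:
        n_key = n_query
    # the attention matrix of each head is n_query x n_key
    return n_query, n_key

# Example: a 352x352 input gives an 88x88 stage-1 map; with sr_ratio=8 the
# 7744 queries attend to only 121 keys.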

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio, drop_path, sr_ratio):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, num_heads=num_heads, sr_ratio=sr_ratio)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio))

    def forward(self, x, H, W):
        # pre-norm residual blocks
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x

class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size, stride, in_chans, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size // 2, patch_size // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # B (H W) C
        x = self.norm(x)
        return x, H, W
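

# Illustrative sketch (hypothetical helper): spatial sizes produced by the four
# `OverlapPatchEmbed` stages used in `PVT` below (patch 7 / stride 4, then
# patch 3 / stride 2 three times, padding = patch_size // 2), using the standard
# Conv2d formula out = floor((in + 2 * pad - kernel) / stride) + 1.
def _sketch_pvt_stage_sizes(size, stages=((7, 4), (3, 2), (3, 2), (3, 2))):
    sizes = []
    for patch_size, stride in stages:
        size = (size + 2 * (patch_size // 2) - patch_size) // stride + 1
        sizes.append(size)
    return sizes

# For inputs divisible by 32 the stages are exactly 1/4, 1/8, 1/16 and 1/32 of
# the input, matching the strides declared in `multiscale_shapes` and the
# `max_stride = 32` padding used by the collate function.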

class PVT(nn.Module):
    def __init__(self, embed_dims, mlp_ratios, depths, snapshot, sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.depths = depths
        self.snapshot = snapshot
        # patch_embed
        self.patch_embed1 = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=3, embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
        # transformer encoder
        dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(dim=embed_dims[0], num_heads=1, mlp_ratio=mlp_ratios[0], drop_path=dpr[cur + i], sr_ratio=sr_ratios[0]) for i in range(depths[0])])
        self.norm1 = nn.LayerNorm(embed_dims[0], eps=1e-6)
        cur += depths[0]
        self.block2 = nn.ModuleList([Block(dim=embed_dims[1], num_heads=2, mlp_ratio=mlp_ratios[1], drop_path=dpr[cur + i], sr_ratio=sr_ratios[1]) for i in range(depths[1])])
        self.norm2 = nn.LayerNorm(embed_dims[1], eps=1e-6)
        cur += depths[1]
        self.block3 = nn.ModuleList([Block(dim=embed_dims[2], num_heads=5, mlp_ratio=mlp_ratios[2], drop_path=dpr[cur + i], sr_ratio=sr_ratios[2]) for i in range(depths[2])])
        self.norm3 = nn.LayerNorm(embed_dims[2], eps=1e-6)
        cur += depths[2]
        self.block4 = nn.ModuleList([Block(dim=embed_dims[3], num_heads=8, mlp_ratio=mlp_ratios[3], drop_path=dpr[cur + i], sr_ratio=sr_ratios[3]) for i in range(depths[3])])
        self.norm4 = nn.LayerNorm(embed_dims[3], eps=1e-6)
        state_dict: dict = torch.load(self.snapshot, map_location='cpu')
        # drop the ImageNet classification head before loading the backbone weights
        state_dict.pop("head.weight")
        state_dict.pop("head.bias")
        self.load_state_dict(state_dict, strict=True)
        del state_dict

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def forward(self, x):
        B = x.shape[0]
        # stage 1
        out1, H, W = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            out1 = blk(out1, H, W)
        out1 = self.norm1(out1).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        # stage 2
        out2, H, W = self.patch_embed2(out1)
        for i, blk in enumerate(self.block2):
            out2 = blk(out2, H, W)
        out2 = self.norm2(out2).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        # stage 3
        out3, H, W = self.patch_embed3(out2)
        for i, blk in enumerate(self.block3):
            out3 = blk(out3, H, W)
        out3 = self.norm3(out3).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        # stage 4
        out4, H, W = self.patch_embed4(out3)
        for i, blk in enumerate(self.block4):
            out4 = blk(out4, H, W)
        out4 = self.norm4(out4).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        return out1, out2, out3, out4

from detectron2.modeling import BACKBONE_REGISTRY
from einops import rearrange, reduce, repeat
from .utils import VideoMultiscale_Shape, ImageMultiscale_Shape
import os
import time
@BACKBONE_REGISTRY.register()
class PVT_V2(nn.Module):
    def __init__(self, configs) -> None:
        super().__init__()
        pt_path = os.getenv('PT_PATH')
        pvt_v2 = PVT(embed_dims=[64, 128, 320, 512], mlp_ratios=[8, 8, 4, 4],
                     depths=[3, 4, 6, 3], snapshot=os.path.join(pt_path, 'pvt_v2/pvt_v2_b2.pth'))
        self.pvt_v2 = pvt_v2
        freeze = configs['freeze']
        if freeze:
            for p in self.parameters():
                p.requires_grad_(False)
        self.multiscale_shapes = {}
        for name, spatial_stride, dim in zip(['res2', 'res3', 'res4', 'res5'],
                                             [4, 8, 16, 32],
                                             [64, 128, 320, 512]):
            self.multiscale_shapes[name] = ImageMultiscale_Shape(spatial_stride=spatial_stride, dim=dim)
        self.max_stride = 32

    def forward(self, x):
        if not self.training:
            # evaluation: run the frames one at a time (lower peak memory)
            batch_feats = []
            for frame in x:
                feats = self.pvt_v2(frame.unsqueeze(0))
                batch_feats.append(feats)
            batch_feats = list(zip(*batch_feats))  # 4 levels, each a tuple of per-frame features
            batch_feats = [torch.cat(level_feats, dim=0) for level_feats in batch_feats]  # list[bt c h w]
            ret = {}
            names = ['res2', 'res3', 'res4', 'res5']
            for name, feat in zip(names, batch_feats):
                ret[name] = feat
            return ret
        else:
            layer_outputs = self.pvt_v2(x)
            ret = {}
            names = ['res2', 'res3', 'res4', 'res5']
            for name, feat in zip(names, layer_outputs):
                ret[name] = feat
            return ret

    def num_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

@BACKBONE_REGISTRY.register()
class Video2D_PVT_V2(nn.Module):
    def __init__(self, configs) -> None:
        super().__init__()
        self.image_homo = PVT_V2(configs=configs)
        self.multiscale_shapes = {}
        for name, temporal_stride, spatial_stride, dim in zip(['res2', 'res3', 'res4', 'res5'],
                                                              [1, 1, 1, 1],
                                                              [4, 8, 16, 32],
                                                              [64, 128, 320, 512]):
            self.multiscale_shapes[name] = VideoMultiscale_Shape(temporal_stride=temporal_stride,
                                                                 spatial_stride=spatial_stride, dim=dim)
        self.max_stride = [1, 32]

    def forward(self, x):
        # apply the image backbone to every frame independently
        batch_size, _, T = x.shape[:3]
        x = rearrange(x, 'b c t h w -> (b t) c h w').contiguous()
        layer_outputs = self.image_homo(x)
        layer_outputs = {key: rearrange(value.contiguous(), '(b t) c h w -> b c t h w', b=batch_size, t=T).contiguous()
                         for key, value in layer_outputs.items()}
        return layer_outputs

    def num_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
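

# Illustrative sketch (hypothetical helpers): the index mapping behind
# rearrange(x, 'b c t h w -> (b t) c h w') in `Video2D_PVT_V2.forward`.
# In a grouped axis such as '(b t)' the first name varies slowest, so frame t of
# video b lands at flat index b * T + t, and '(b t) c h w -> b c t h w' inverts
# the mapping exactly.
def _sketch_flat_frame_index(b, t, num_frames):
    return b * num_frames + t


def _sketch_unflatten_frame_index(flat, num_frames):
    return divmod(flat, num_frames)  # -> (b, t)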
================================================
FILE: models/backbone/res2net.py
================================================
import math
import torch
import torch.nn as nn
import os
import torch.nn.functional as F
from detectron2.modeling import BACKBONE_REGISTRY
from einops import rearrange, reduce, repeat
from .utils import VideoMultiscale_Shape

class Bottle2neck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'):
        super(Bottle2neck, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = 1 if scale == 1 else scale - 1
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
        convs, bns = [], []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.downsample = downsample
        self.stype = stype
        self.scale = scale
        self.width = width

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            # hierarchical residual-like connections between the splits
            sp = spx[i] if i == 0 or self.stype == 'stage' else sp + spx[i]
            sp = self.convs[i](sp)
            sp = F.relu(self.bns[i](sp), inplace=True)
            out = sp if i == 0 else torch.cat((out, sp), 1)
        if self.scale != 1 and self.stype == 'normal':
            out = torch.cat((out, spx[self.nums]), 1)
        elif self.scale != 1 and self.stype == 'stage':
            out = torch.cat((out, self.pool(spx[self.nums])), 1)
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            x = self.downsample(x)
        return F.relu(out + x, inplace=True)
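

# Illustrative sketch (hypothetical helper): channel bookkeeping of `Bottle2neck`.
# conv1 expands to width * scale channels, which `torch.split` cuts into `scale`
# groups of `width`; `scale - 1` groups go through 3x3 convs (in 'normal' blocks
# each one is fed its own split plus the previous conv output) and the last group
# is passed through unchanged ('normal') or average-pooled ('stage').
import math


def _sketch_bottle2neck_channels(planes, base_width=26, scale=4, expansion=4):
    width = int(math.floor(planes * (base_width / 64.0)))
    return {
        'width': width,                              # channels per split
        'conv1_out': width * scale,                  # input to torch.split
        'num_3x3_convs': 1 if scale == 1 else scale - 1,
        'block_out': planes * expansion,             # output of conv3
    }

# With planes 64/128/256/512 the block outputs are 256/512/1024/2048, the dims
# declared for res2..res5 in `Res2Net_50_EachFrame.multiscale_shapes`.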

class Res2Net(nn.Module):
    def __init__(self, layers, snapshot, baseWidth=26, scale=4):
        super(Res2Net, self).__init__()
        self.inplanes = 64
        self.snapshot = snapshot
        self.baseWidth = baseWidth
        self.scale = scale
        # deep stem: three 3x3 convs instead of a single 7x7 conv
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(Bottle2neck, 64, layers[0])
        self.layer2 = self._make_layer(Bottle2neck, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(Bottle2neck, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(Bottle2neck, 512, layers[3], stride=2)
        state_dict: dict = torch.load(self.snapshot, map_location='cpu')
        self.load_state_dict(state_dict, strict=False)
        del state_dict

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False),
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample=downsample, stype='stage', baseWidth=self.baseWidth, scale=self.scale)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale))
        return nn.Sequential(*layers)

    def forward(self, x):
        out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1)
        out2 = self.layer1(out1)
        out3 = self.layer2(out2)
        out4 = self.layer3(out3)
        out5 = self.layer4(out4)
        return out2, out3, out4, out5

    def initialize(self):
        self.load_state_dict(torch.load(self.snapshot), strict=False)

@BACKBONE_REGISTRY.register()
class Res2Net_50_EachFrame(nn.Module):
    def __init__(self, configs) -> None:
        super().__init__()
        pt_path = os.getenv('PT_PATH')
        res2net = Res2Net([3, 4, 6, 3], os.path.join(pt_path, 'res2net/res2net50_v1b_26w_4s-3cf99910.pth'))
        self.res2net = res2net
        freeze = configs['freeze']
        if freeze:
            for p in self.parameters():
                p.requires_grad_(False)
        self.multiscale_shapes = {}
        for name, temporal_stride, spatial_stride, dim in zip(['res2', 'res3', 'res4', 'res5'],
                                                              [1, 1, 1, 1],
                                                              [4, 8, 16, 32],
                                                              [256, 512, 1024, 2048]):
            self.multiscale_shapes[name] = VideoMultiscale_Shape(temporal_stride=temporal_stride,
                                                                 spatial_stride=spatial_stride, dim=dim)
        self.max_stride = [1, 32]

    def forward(self, x):
        # apply the image backbone to every frame independently
        batch_size, _, T = x.shape[:3]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        layer_outputs = self.res2net(x)
        ret = {}
        names = ['res2', 'res3', 'res4', 'res5']
        for name, feat in zip(names, layer_outputs):
            ret[name] = rearrange(feat.contiguous(), '(b t) c h w -> b c t h w', b=batch_size, t=T)
        return ret

    def num_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
================================================
FILE: models/backbone/utils.py
================================================
class VideoMultiscale_Shape:
    def __init__(self, temporal_stride, spatial_stride, dim) -> None:
        self.temporal_stride = temporal_stride
        self.spatial_stride = spatial_stride
        self.dim = dim

    @staticmethod
    def set_multiscale_same_dim(shape_by_dim, same_dim):
        # keep the strides of every scale but give all of them the same channel dim
        return {
            key: VideoMultiscale_Shape(temporal_stride=value.temporal_stride,
                                       spatial_stride=value.spatial_stride,
                                       dim=same_dim) for key, value in shape_by_dim.items()
        }


class ImageMultiscale_Shape:
    def __init__(self, spatial_stride, dim) -> None:
        self.spatial_stride = spatial_stride
        self.dim = dim
================================================
FILE: models/decoder/__init__.py
================================================
from . import mask2former_video
================================================
FILE: models/decoder/mask2former_video.py
================================================
# multi-scale features, b c h w -> module -> obj queries, predictions, b nq c
import torch.nn as nn
from models.layers.decoder_layers import CrossAttentionLayer, SelfAttentionLayer, FFNLayer
from models.layers.anyc_trans import MLP
import torch.nn.functional as F
import torch
import copy
from models.layers.utils import zero_module, _get_clones
from models.layers.position_encoding import build_position_encoding
from einops import rearrange, reduce, repeat
from scipy.optimize import linear_sum_assignment
from models.layers.matching import batch_dice_loss, batch_sigmoid_ce_loss, batch_sigmoid_focal_loss, dice_loss, ce_mask_loss
from detectron2.modeling import META_ARCH_REGISTRY
import detectron2.utils.comm as comm
import data_schedule.utils.box_ops as box_ops
from utils.misc import is_dist_avail_and_initialized
from collections import defaultdict
from torch.cuda.amp import autocast
from detectron2.projects.point_rend.point_features import (
    get_uncertain_point_coords_with_randomness,
    point_sample,
)

def calculate_uncertainty(logits):
    """
    We estimate uncertainty as the negative L1 distance between 0.0 and the logit
    prediction in `logits`, i.e. -|logit|, so locations whose logits are closest to
    the decision boundary get the highest score.
    Args:
        logits (Tensor): A tensor of shape (R, 1, ...) of class-agnostic mask logits,
            where R is the total number of predicted masks in all images.
    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    assert logits.shape[1] == 1
    gt_class_logits = logits.clone()
    return -(torch.abs(gt_class_logits))
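

# Illustrative sketch (hypothetical helpers, stdlib only): what `matching` in
# `Video_SetMatchingLoss` below computes. For one video it builds a cost matrix
# C[q][j] (queries x ground-truth objects) as a weighted sum of per-pair costs
# (negative class probability, mask CE, mask dice) and asks scipy's
# `linear_sum_assignment` for the one-to-one assignment with the lowest total
# cost. The brute force below returns the same optimum for tiny matrices with
# num_queries >= num_targets (ties aside); it is exponential and only shows the
# semantics. The dice formula ASSUMES the usual Mask2Former definition
# 1 - (2 * sum(p * t) + 1) / (sum(p) + sum(t) + 1).
import itertools


def _sketch_dice_cost(prob, target):
    inter = sum(p * t for p, t in zip(prob, target))
    return 1 - (2 * inter + 1) / (sum(prob) + sum(target) + 1)


def _sketch_min_cost_matching(cost):
    """cost: list of num_queries rows, each holding num_targets costs."""
    num_queries, num_targets = len(cost), len(cost[0])
    best, best_rows = None, None
    # try every way of giving each target a distinct query
    for rows in itertools.permutations(range(num_queries), num_targets):
        total = sum(cost[r][j] for j, r in enumerate(rows))
        if best is None or total < best:
            best, best_rows = total, rows
    # report (query_indices, target_indices) sorted by query, as scipy does
    pairs = sorted(zip(best_rows, range(num_targets)))
    return [q for q, _ in pairs], [t for _, t in pairs]

# Unmatched queries are later supervised as the no-object class in `loss_class_ce`.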

class Video_SetMatchingLoss(nn.Module):
    def __init__(self,
                 loss_config,
                 num_classes,) -> None:
        super().__init__()
        self.num_classes = num_classes  # number of foreground classes (0, 1 or more)
        # cost terms used for matching, e.g. class_prob / mask_dice_ce / point_mask_dice_ce, with their weights
        self.matching_metrics = loss_config['matching_metrics']
        self.losses = loss_config['losses']
        self.aux_layer_weights = loss_config['aux_layer_weights']  # int/list
        # class weights for cross-entropy; the extra last entry is the no-object class
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = loss_config['background_cls_eos']
        self.register_buffer('empty_weight', empty_weight)
        # self.register_buffer('small_obj_weight', torch.tensor(loss_config['small_obj_weight']).float())
        self._warmup_iters = 2000
        self.register_buffer("_iter", torch.zeros([1]))

    @property
    def device(self,):
        return self.empty_weight.device

    def compute_loss(self,
                     model_outs,
                     targets,
                     video_aux_dict,
                     **kwargs):
        # count the ground-truth objects that appear in at least one annotated frame
        if 'masks' in targets:
            # list[n t' h w], batch
            num_objs = sum([m.flatten(1).any(-1).int().sum().item() for m in targets['masks']])
        elif 'boxes' in targets:
            # list[n t' 4], batch; an object exists where w, h > 0: n t' 2 -> n t' -> n
            num_objs = sum([(bx[:, :, 2:] > 0).all(-1).any(-1).int().sum().item() for bx in targets['boxes']])
        else:
            raise ValueError('targets contain neither boxes nor masks, so the number of objects cannot be determined')
        num_objs = torch.as_tensor([num_objs], dtype=torch.float, device=self.device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_objs)
        num_objs = torch.clamp(num_objs / comm.get_world_size(), min=1).item()
        if isinstance(self.aux_layer_weights, list):
            assert len(self.aux_layer_weights) == (len(model_outs) - 1)
            aux_layer_weights = self.aux_layer_weights
        else:
            aux_layer_weights = [self.aux_layer_weights] * (len(model_outs) - 1)
        layer_weights = aux_layer_weights + [1.]
        loss_values = {
            'mask_dice': 0., 'mask_ce': 0.,
            'box_l1': 0., 'box_giou': 0.,
            'class_ce': 0.,
            'mask_dice_smobj': 0., 'mask_ce_smobj': 0.,
            'boxMask_dice': 0., 'boxMask_ce': 0.,
        }
        if ('mask_dice_ce' in self.matching_metrics) or ('mask_dice_ce' in self.losses):
            # dense mask costs/losses need predictions at the (padded) target resolution
            tgt_mask_shape = targets['masks'][0].shape[-2:]  # list[n t H W], b
            for layer_idx in range(len(model_outs)):
                # b t nq h w
                batch_size, nf = model_outs[layer_idx]['pred_masks'].shape[:2]
                model_outs[layer_idx]['pred_masks'] = rearrange(F.interpolate(model_outs[layer_idx]['pred_masks'].flatten(0, 1),
                                                                              size=tgt_mask_shape, mode='bilinear', align_corners=False),
                                                                '(b t) n h w -> b t n h w', b=batch_size, t=nf)
        for layer_weight, layer_out in zip(layer_weights, model_outs):
            # note: the layer weight only switches a layer on/off; the losses are summed unweighted
            if layer_weight != 0:
                matching_indices = self.matching(layer_out, targets)
                for loss in self.losses:
                    loss_extra_param = self.losses[loss]
                    if loss == 'mask_dice_ce':
                        loss_dict = self.loss_mask_dice_ce(layer_out, targets, matching_indices, num_objs,
                                                           loss_extra_param=loss_extra_param)
                    elif loss == 'class_ce':
                        loss_dict = self.loss_class_ce(layer_out, targets, matching_indices, num_objs,
                                                       loss_extra_param=loss_extra_param)
                    elif loss == 'point_mask_dice_ce':
                        loss_dict = self.loss_point_mask_dice_ce(layer_out, targets, matching_indices, num_objs,
                                                                 loss_extra_param=loss_extra_param)
                    else:
                        raise ValueError(f'unknown loss: {loss}')
                    for key, value in loss_dict.items():
                        loss_values[key] = loss_values[key] + value
        return loss_values

    @torch.no_grad()
    def matching(self, layer_out, targets):
        batch_size = len(targets['masks']) if 'masks' in targets else len(targets['boxes'])
        indices = []
        has_ann = targets['has_ann']
        for i in range(batch_size):
            C = 0.
            if 'class_prob' in self.matching_metrics:
                out_cls = layer_out['pred_class'][i].softmax(-1)  # nq c
                tgt_cls = targets['classes'][i]  # n
                cost_class = - out_cls[:, tgt_cls]  # nq n
                C += self.matching_metrics['class_prob']['prob'] * cost_class
            if 'mask_dice_ce' in self.matching_metrics:
                out_mask = layer_out['pred_masks'][i][has_ann[i]].permute(1, 0, 2, 3).contiguous()  # nq t' h w
                tgt_mask = targets['masks'][i].to(out_mask)  # ni t' H W
                cost_mask = batch_sigmoid_ce_loss(out_mask.flatten(1), tgt_mask.flatten(1))
                cost_dice = batch_dice_loss(out_mask.flatten(1), tgt_mask.flatten(1))
                C += self.matching_metrics['mask_dice_ce']['ce'] * cost_mask + \
                    self.matching_metrics['mask_dice_ce']['dice'] * cost_dice
            if 'point_mask_dice_ce' in self.matching_metrics:
                out_mask = layer_out['pred_masks'][i][has_ann[i]].permute(1, 0, 2, 3).contiguous()  # nq t' h w
                tgt_mask = targets['masks'][i].to(out_mask)  # ni t' H W
                nf = out_mask.shape[1]
                out_mask = out_mask.flatten(0, 1)[:, None]  # (nq t') 1 h w
                tgt_mask = tgt_mask.flatten(0, 1)[:, None]  # (ni t') 1 H W
                # all masks share the same set of points for efficient matching!
                point_coords = torch.rand(1, self.matching_metrics['point_mask_dice_ce']['num_points'],
                                          2, device=self.device)
                # get gt labels
                tgt_mask = point_sample(
                    tgt_mask,
                    point_coords.repeat(tgt_mask.shape[0], 1, 1),
                    align_corners=False,
                ).squeeze(1)  # (ni t') s
                tgt_mask = rearrange(tgt_mask, '(n t) s -> n t s', t=nf)
                out_mask = point_sample(
                    out_mask,
                    point_coords.repeat(out_mask.shape[0], 1, 1),
                    align_corners=False,
                ).squeeze(1)  # (nq t') s
                out_mask = rearrange(out_mask, '(nq t) s -> nq t s', t=nf)
                with autocast(enabled=False):
                    out_mask = out_mask.float().flatten(1)  # nq (t' s)
                    tgt_mask = tgt_mask.float().flatten(1)  # ni (t' s)
                    cost_mask = batch_sigmoid_ce_loss(out_mask, tgt_mask)
                    cost_dice = batch_dice_loss(out_mask, tgt_mask)
                C += self.matching_metrics['point_mask_dice_ce']['ce'] * cost_mask + \
                    self.matching_metrics['point_mask_dice_ce']['dice'] * cost_dice
            C = C.cpu()
            # Hungarian matching: minimum-cost one-to-one assignment of queries to objects
            indices.append(linear_sum_assignment(C))
        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

    def loss_mask_dice_ce(self, outputs, targets, indices, num_objs, loss_extra_param):
        has_ann = targets['has_ann']  # b t
        src_masks = outputs['pred_masks'].permute(0, 2, 1, 3, 4).contiguous()  # b nq t h w
        tgt_masks = targets['masks']  # list[n t' h w]
        # keep the matched queries on annotated frames: list[nq t' h w] -> n_sigma t' h w
        src_masks = torch.cat([t[J][:, ann] for t, (J, _), ann in zip(src_masks, indices, has_ann)], dim=0)
        tgt_masks = torch.cat([t[J] for t, (_, J) in zip(tgt_masks, indices)], dim=0)
        tgt_masks = tgt_masks.to(src_masks)
        losses = {
            "mask_ce": ce_mask_loss(src_masks.flatten(0, 1).flatten(1), tgt_masks.flatten(0, 1).flatten(1), num_boxes=num_objs),
            "mask_dice": dice_loss(src_masks.flatten(0, 1).flatten(1), tgt_masks.flatten(0, 1).flatten(1), num_boxes=num_objs),
        }
        return losses

    def loss_point_mask_dice_ce(self, outputs, targets, indices, num_objs, loss_extra_param):
        has_ann = targets['has_ann']  # b t
        src_masks = outputs['pred_masks'].permute(0, 2, 1, 3, 4).contiguous()  # b nq t h w
        tgt_masks = targets['masks']  # list[n t' h w]
        # keep the matched queries on annotated frames: list[nq t' h w] -> n_sigma t' h w
        src_masks = torch.cat([t[J][:, ann] for t, (J, _), ann in zip(src_masks, indices, has_ann)], dim=0)
        tgt_masks = torch.cat([t[J] for t, (_, J) in zip(tgt_masks, indices)], dim=0)
        tgt_masks = tgt_masks.to(src_masks)
        nf = src_masks.shape[1]
        # No need to upsample predictions as we are using normalized coordinates :)
        src_masks = src_masks.flatten(0, 1).unsqueeze(1).contiguous()  # nt' 1 h w
        target_masks = tgt_masks.flatten(0, 1).unsqueeze(1).contiguous()  # nt' 1 H W
        with torch.no_grad():
            # sample point_coords, biased towards uncertain (near-zero logit) locations
            point_coords = get_uncertain_point_coords_with_randomness(
                src_masks,
                lambda logits: calculate_uncertainty(logits),
                loss_extra_param['num_points'],
                loss_extra_param['oversample_ratio'],
                loss_extra_param['importance_sample_ratio'],
            )
            # get gt labels
            point_labels = point_sample(
                target_masks,
                point_coords,
                align_corners=False,
            ).squeeze(1)  # nt' s
        point_logits = point_sample(
            src_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)  # nt' s
        # point_logits = rearrange(point_logits, '(n t) s -> n (t s)',t=nf)
        # point_labels = rearrange(point_labels, '(n t) s -> n (t s)',t=nf)
        losses = {
            "mask_ce": ce_mask_loss(point_logits, point_labels, num_objs),
            "mask_dice": dice_loss(point_logits, point_labels, num_objs),
        }
        del src_masks
        del target_masks
        return losses

    def loss_class_ce(self, outputs, targets, indices, num_objs, loss_extra_param):
        """Classification loss (NLL).
        targets must contain the key "classes": a list (one entry per video) of tensors of shape [num_target_objects].
        Unmatched queries are assigned the extra "no-object" class index self.num_classes.
        """
        src_logits = outputs["pred_class"].float()  # b nq c
        idx = self._get_src_permutation_idx(indices)
        # list[n], b -> bn
        target_classes_o = torch.cat([t[J] for t, (_, J) in zip(targets['classes'], indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=self.device
        )
        target_classes[idx] = target_classes_o
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {"class_ce": loss_ce}
        return losses
    def loss_box_l1_giou(self, outputs, targets, indices, num_objs, loss_extra_param):
        tgt_boxes = targets['boxes']  # list[n tl 4], b
        has_ann = targets['has_ann']  # b t
        src_boxes = outputs['pred_boxes'].sigmoid().permute(0, 2, 1, 3).contiguous()  # b nq t 4
        src_boxes = torch.cat([t[J][:, frame_has_ann] for t, (J, _), frame_has_ann in zip(src_boxes, indices, has_ann)], dim=0)  # n_sum t' 4
        tgt_boxes = torch.cat([t[J] for t, (_, J) in zip(tgt_boxes, indices)], dim=0)  # n_sum t' 4
        nf = tgt_boxes.shape[1]
        loss_l1 = F.l1_loss(src_boxes, tgt_boxes, reduction='none').flatten(1)  # n_sum t'4
        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes.flatten(0, 1)),
            box_ops.box_cxcywh_to_xyxy(tgt_boxes.flatten(0, 1))))  # n_sum t'
        loss_giou = loss_giou.view(-1, nf).contiguous()
        return {
            'box_l1': loss_l1.sum(-1).sum() / num_objs,
            'box_giou': loss_giou.sum(-1).sum() / num_objs
        }
    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx
    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx
@META_ARCH_REGISTRY.register()
class Video_MaskedAttn_MultiscaleMaskDecoder_v3(nn.Module):
    def __init__(self,
                 configs,
                 multiscale_shapes):
        super().__init__()
        d_model = configs['d_model']
        attn_configs = configs['attn']
        self.video_nqueries = configs['video_nqueries']
        self.nlayers = configs['nlayers']
        self.memory_scales = configs['memory_scales']
        self.mask_scale = configs['mask_scale']
        self.mask_spatial_stride = multiscale_shapes[self.mask_scale].spatial_stride
        num_classes = configs['num_classes']
        inputs_projs = configs['inputs_projs']
        self.inputs_projs = nn.Sequential()
        if inputs_projs is not None:
            self.inputs_projs = META_ARCH_REGISTRY.get(inputs_projs['name'])(
                inputs_projs, multiscale_shapes=multiscale_shapes, out_dim=d_model)
        self.level_embeds = nn.Embedding(len(self.memory_scales), d_model)
        assert self.nlayers % len(self.memory_scales) == 0
        self.cross_layers = _get_clones(
            CrossAttentionLayer(d_model=d_model,
                                nhead=attn_configs['nheads'],
                                dropout=0.0,
                                normalize_before=attn_configs['normalize_before']),
            self.nlayers)
        self.self_layers = _get_clones(
            SelfAttentionLayer(d_model=d_model,
                               nhead=attn_configs['nheads'],
                               dropout=0.0,
                               normalize_before=attn_configs['normalize_before']),
            self.nlayers)
        self.ffn_layers = _get_clones(
            FFNLayer(d_model=d_model,
                     dim_feedforward=attn_configs['dim_feedforward'],
                     dropout=0.0,
                     normalize_before=attn_configs['normalize_before']),
            self.nlayers)
        self.nheads = attn_configs['nheads']
        self.temporal_query_poses = nn.Embedding(self.video_nqueries, d_model)
        self.temporal_query_feats = nn.Embedding(self.video_nqueries, d_model)
        self.temporal_query_norm = nn.LayerNorm(d_model)
        self.pos_3d = build_position_encoding(hidden_dim=d_model, position_embedding_name='3d')  # b t c h w
        self.head_outputs = configs['head_outputs']
        assert 'mask' in self.head_outputs
        self.query_mask = MLP(d_model, d_model, d_model, 3)
        if 'class' in self.head_outputs:
            self.query_class = nn.Linear(d_model, num_classes + 1)
        self.loss_module = Video_SetMatchingLoss(loss_config=configs['loss'], num_classes=num_classes)
    @property
    def device(self,):
        return self.temporal_query_feats.weight.device
    def get_memories_and_mask_features(self, multiscales):
        # b c t h w
        memories = [multiscales[scale] for scale in self.memory_scales]
        size_list = [mem_feat.shape[-2:] for mem_feat in memories]
        memories_poses = [self.pos_3d(mem.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4) for mem in memories]  # b c t h w
        memories = [rearrange(mem, 'b c t h w -> (t h w) b c').contiguous() for mem in memories]
        memories_poses = [rearrange(mem_pos, 'b c t h w -> (t h w) b c').contiguous() for mem_pos in memories_poses]
        mask_features = multiscales[self.mask_scale]  # b c t h w
        return memories, memories_poses, mask_features, size_list
    def forward(self,
                multiscales,  # b c t h w
                video_aux_dict=None
                ):
        multiscales = self.inputs_projs(multiscales[0])
        # thw b c; b c t h w
        memories, memories_poses, mask_features, size_list = self.get_memories_and_mask_features(multiscales)
        memories = [mem_feat + self.level_embeds.weight[i][None, None, :] for i, mem_feat in enumerate(memories)]
        batch_size, _, nf, *_ = mask_features.shape
        # nq b c
        temporal_query_poses = self.temporal_query_poses.weight.unsqueeze(1).repeat(1, batch_size, 1)
        temporal_query_feats = self.temporal_query_feats.weight.unsqueeze(1).repeat(1, batch_size, 1)
        vid_ret = []
        # b nq class, b nq t h w; b*head nq thw
        # predictions from the initial learnable queries, before any decoder layer
        vid_class, vid_mask, attn_mask = \
            self.forward_heads(temporal_query_feats=temporal_query_feats,
                               mask_features=mask_features, attn_mask_target_size=size_list[0])
        vid_ret.append({'pred_class': vid_class, 'pred_masks': vid_mask})
        for i in range(self.nlayers):
            level_index = i % len(self.memory_scales)
            # a query whose mask blocks every position (e.g. because of padding) attends to all positions instead
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
            temporal_query_feats = self.cross_layers[i](
                tgt=temporal_query_feats,  # nq b c
                memory=memories[level_index],  # thw b c
                memory_mask=attn_mask,  # b*h nq thw
                memory_key_padding_mask=None,
                pos=memories_poses[level_index],  # thw b c
                query_pos=temporal_query_poses,  # nq b c
            )
            temporal_query_feats = self.self_layers[i](
                temporal_query_feats,
                tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=temporal_query_poses,
            )
            temporal_query_feats = self.ffn_layers[i](
                temporal_query_feats
            )
            # b nq class, b nq t h w; the attention mask is sized for the next layer's memory scale
            vid_class, vid_mask, attn_mask = \
                self.forward_heads(temporal_query_feats=temporal_query_feats,
                                   mask_features=mask_features,
                                   attn_mask_target_size=size_list[(i + 1) % len(self.memory_scales)])
            vid_ret.append({'pred_class': vid_class, 'pred_masks': vid_mask})
        return vid_ret
    def forward_heads(self, temporal_query_feats, mask_features, attn_mask_target_size):  # nq b c; b c t h w
        batch_size, _, nf, *_ = mask_features.shape
        temporal_query_feats = self.temporal_query_norm(temporal_query_feats)  # nq b c
        temporal_query_feats = temporal_query_feats.transpose(0, 1).contiguous()  # b nq c
        class_logits = self.query_class(temporal_query_feats) if 'class' in self.head_outputs else None  # b n class+1
        mask_embeds = self.query_mask(temporal_query_feats)  # b n c
        mask_logits = torch.einsum("bqc,bcthw->bqthw", mask_embeds, mask_features)
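The `dice_loss` and `ce_mask_loss` helpers called by `loss_mask_dice_ce` and `loss_point_mask_dice_ce` are not part of this excerpt. The sketch below assumes they follow the usual Mask2Former formulation: sigmoid probabilities, soft Dice with +1 smoothing, per-pixel binary cross-entropy averaged over pixels, and both summed over masks and divided by the number of matched objects (`num_objs`, passed in the repo as `num_boxes`). The names `dice_loss_sketch` and `sigmoid_ce_loss_sketch` are hypothetical.

```python
import torch
import torch.nn.functional as F


def dice_loss_sketch(logits: torch.Tensor, targets: torch.Tensor, num_masks: float) -> torch.Tensor:
    """Soft Dice loss. logits/targets: (N, P), one row per mask, P flattened pixels or points."""
    probs = logits.sigmoid()
    numerator = 2 * (probs * targets).sum(-1)
    denominator = probs.sum(-1) + targets.sum(-1)
    # +1 smoothing keeps the loss finite when both masks are empty
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks


def sigmoid_ce_loss_sketch(logits: torch.Tensor, targets: torch.Tensor, num_masks: float) -> torch.Tensor:
    """Per-pixel BCE, averaged over pixels of each mask, summed over masks, normalized by num_masks."""
    loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    return loss.mean(1).sum() / num_masks


# Usage mirroring loss_mask_dice_ce: (n t' h w) -> (n*t', h*w), one row per (object, frame)
src = torch.randn(2, 3, 8, 8)                 # 2 matched objects, 3 annotated frames
tgt = (torch.rand(2, 3, 8, 8) > 0.5).float()
print(dice_loss_sketch(src.flatten(0, 1).flatten(1), tgt.flatten(0, 1).flatten(1), num_masks=2))
```

Note that flattening objects and frames together treats every annotated frame as its own mask while still normalizing by the number of objects, so the loss scale grows with the clip length.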
SYMBOL INDEX (978 symbols across 66 files)
FILE: data_schedule/__init__.py
function build_schedule (line 9) | def build_schedule(configs, model_input_mapper, model_input_collate_fn):
function composition (line 88) | def composition(data_dict, mappers):
function evaluate_call (line 95) | def evaluate_call(evaluators, model, output_dir):
function _infinite_indices (line 108) | def _infinite_indices(seed, dataset_length, shuffle=True,):
function infinite_indices (line 118) | def infinite_indices(seed,
FILE: data_schedule/registry.py
class Mapper (line 7) | class Mapper:
method __init__ (line 8) | def __init__(self,
method _call (line 14) | def _call(self, data_dict):
method __call__ (line 17) | def __call__(self, data_dict):
FILE: data_schedule/utils/box_ops.py
function box_cxcywh_to_xyxy (line 9) | def box_cxcywh_to_xyxy(x):
function box_xyxy_to_cxcywh (line 16) | def box_xyxy_to_cxcywh(x):
function box_iou (line 26) | def box_iou(boxes1, boxes2):
function generalized_box_iou (line 42) | def generalized_box_iou(boxes1, boxes2):
function masks_to_boxes (line 66) | def masks_to_boxes(masks):
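`loss_box_l1_giou` in `mask2former_video.py` converts boxes with `box_cxcywh_to_xyxy`, calls `box_ops.generalized_box_iou` on all n·t' predicted and target boxes, and keeps only the diagonal with `torch.diag`, so it materializes an (n·t')×(n·t') matrix to use n·t' values. The sketch below computes GIoU only for matched pairs. It assumes the DETR-style definition GIoU = IoU − (|C| − |A∪B|)/|C|, with C the smallest enclosing box; the function names are hypothetical and not part of `box_ops`.

```python
import torch


def cxcywh_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
    """(cx, cy, w, h) -> (x0, y0, x1, y1)."""
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)


def paired_giou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """GIoU between boxes1[i] and boxes2[i]; both (N, 4) in xyxy. Returns (N,)."""
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    # intersection rectangle (clamped to zero when boxes do not overlap)
    lt = torch.max(boxes1[:, :2], boxes2[:, :2])
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])
    inter = (rb - lt).clamp(min=0).prod(-1)
    union = area1 + area2 - inter
    iou = inter / union
    # smallest box enclosing both
    lt_c = torch.min(boxes1[:, :2], boxes2[:, :2])
    rb_c = torch.max(boxes1[:, 2:], boxes2[:, 2:])
    area_c = (rb_c - lt_c).clamp(min=0).prod(-1)
    return iou - (area_c - union) / area_c


# loss_giou = 1 - paired_giou(cxcywh_to_xyxy(src.flatten(0, 1)), cxcywh_to_xyxy(tgt.flatten(0, 1)))
```

Under these assumptions the result should equal `torch.diag(generalized_box_iou(a, b))` while using O(N) instead of O(N²) memory.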
FILE: data_schedule/utils/sampler.py
class TrainRandomSampler_ByEpoch (line 15) | class TrainRandomSampler_ByEpoch(Sampler[int]):
method __init__ (line 16) | def __init__(self,
method __iter__ (line 25) | def __iter__(self):
method __len__ (line 36) | def __len__(self) -> int:
method set_epoch (line 39) | def set_epoch(self, epoch: int) -> None:
class Train_InfiniteSampler_Distributed (line 49) | class Train_InfiniteSampler_Distributed(Sampler[T_co]):
method __init__ (line 50) | def __init__(self,
method set_iter_first_sample_idx (line 62) | def set_iter_first_sample_idx(self, idx):
method set_iter_last_sample_idx (line 65) | def set_iter_last_sample_idx(self, idx):
method __iter__ (line 68) | def __iter__(self) -> Iterator[T_co]:
class Evaluate_ExactSampler_Distributed (line 72) | class Evaluate_ExactSampler_Distributed(Sampler[T_co]):
method __init__ (line 73) | def __init__(self, dataset) -> None:
method __iter__ (line 81) | def __iter__(self):
method __len__ (line 84) | def __len__(self):
class TrainRandomSampler_ByEpoch_Distributed (line 88) | class TrainRandomSampler_ByEpoch_Distributed(Sampler[T_co]):
method __init__ (line 89) | def __init__(self,
method __iter__ (line 104) | def __iter__(self) -> Iterator[T_co]:
method __len__ (line 124) | def __len__(self) -> int:
method set_epoch (line 127) | def set_epoch(self, epoch: int) -> None:
class InferenceSampler (line 139) | class InferenceSampler(Sampler):
method __init__ (line 147) | def __init__(self, size: int):
method _get_local_indices (line 159) | def _get_local_indices(total_size, world_size, rank):
method __iter__ (line 168) | def __iter__(self):
method __len__ (line 171) | def __len__(self):
FILE: data_schedule/utils/segmentation.py
function bounding_box_from_mask (line 4) | def bounding_box_from_mask(mask):
FILE: data_schedule/vis/apis.py
class VIS_Dataset (line 2) | class VIS_Dataset:
class VIS_Aug_CallbackAPI (line 6) | class VIS_Aug_CallbackAPI:
class VIS_Evaluator_OutAPI_EvalFn_API (line 10) | class VIS_Evaluator_OutAPI_EvalFn_API:
class VIS_TrainAPI_clipped_video (line 14) | class VIS_TrainAPI_clipped_video:
class VIS_EvalAPI_clipped_video_request_ann (line 18) | class VIS_EvalAPI_clipped_video_request_ann:
class VIS_FrameSampler_InputOutput_API (line 22) | class VIS_FrameSampler_InputOutput_API:
class GetFrames (line 26) | class GetFrames:
FILE: data_schedule/vis/evaluator_fast.py
class VIS_Evaluator_FrameFast (line 17) | class VIS_Evaluator_FrameFast:
method __init__ (line 18) | def __init__(self,
method visualize_path (line 41) | def visualize_path(self, meta_idxs, visualize, evaluator_path):
method __call__ (line 45) | def __call__(self, model, output_dir):
FILE: data_schedule/vis/evaluator_utils.py
function register_vis_metric (line 3) | def register_vis_metric(fn):
function vis_metric_entrypoint (line 11) | def vis_metric_entrypoint(vis_metric_name):
function _prepare_data (line 21) | def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple:
class Smeasure (line 31) | class Smeasure(object):
method __init__ (line 32) | def __init__(self, length, alpha: float = 0.5):
method step (line 36) | def step(self, pred: np.ndarray, gt: np.ndarray, idx):
method cal_sm (line 42) | def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float:
method object (line 53) | def object(self, pred: np.ndarray, gt: np.ndarray) -> float:
method s_object (line 60) | def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float:
method region (line 66) | def region(self, pred: np.ndarray, gt: np.ndarray) -> float:
method centroid (line 79) | def centroid(self, matrix: np.ndarray) -> tuple:
method divide_with_xy (line 97) | def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict:
method ssim (line 120) | def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float:
method get_results (line 142) | def get_results(self):
function mask_dice_iou (line 152) | def mask_dice_iou(frame_pred, dataset_meta, **kwargs):
function mask_dice_iou_sen_mae_smeasure (line 173) | def mask_dice_iou_sen_mae_smeasure(frame_pred, dataset_meta, **kwargs):
function web (line 223) | def web(frame_pred, output_dir, **kwargs):
FILE: data_schedule/vis/fibroid/evals.py
function fibroid_other_medi (line 17) | def fibroid_other_medi(model_preds,
function fibroid_mask_dice_iou (line 82) | def fibroid_mask_dice_iou(frame_pred, dataset_meta, **kwargs):
function fibroid_metric_aggregator (line 104) | def fibroid_metric_aggregator(metrics_by_vid_frame, dataset_meta, eval_m...
FILE: data_schedule/vis/fibroid/fibroid_dataset.py
function fibroid_train (line 16) | def fibroid_train(step_size, # none / int; 0, 6, 13, 19 ...
function fibroid_evaluate (line 44) | def fibroid_evaluate(eval_video_ids,
FILE: data_schedule/vis/fibroid/fibroid_utils.py
function get_frames (line 21) | def get_frames(frames_path, video_id, frames):
function get_frames_mask (line 25) | def get_frames_mask(mask_path, video_id, frames):
FILE: data_schedule/vis/fibroid/metrics.py
function get_stats (line 62) | def get_stats(
function _get_stats_multiclass (line 165) | def _get_stats_multiclass(
function _get_stats_multilabel (line 206) | def _get_stats_multilabel(
function _handle_zero_division (line 228) | def _handle_zero_division(x, zero_division):
function _compute_metric (line 238) | def _compute_metric(
function _fbeta_score (line 312) | def _fbeta_score(tp, fp, fn, tn, beta=1):
function _iou_score (line 319) | def _iou_score(tp, fp, fn, tn):
function _accuracy (line 323) | def _accuracy(tp, fp, fn, tn):
function _sensitivity (line 327) | def _sensitivity(tp, fp, fn, tn):
function _specificity (line 331) | def _specificity(tp, fp, fn, tn):
function _balanced_accuracy (line 335) | def _balanced_accuracy(tp, fp, fn, tn):
function _dice (line 338) | def _dice(tp, fp, fn, tn):
function _positive_predictive_value (line 341) | def _positive_predictive_value(tp, fp, fn, tn):
function _negative_predictive_value (line 345) | def _negative_predictive_value(tp, fp, fn, tn):
function _false_negative_rate (line 349) | def _false_negative_rate(tp, fp, fn, tn):
function _false_positive_rate (line 353) | def _false_positive_rate(tp, fp, fn, tn):
function _false_discovery_rate (line 357) | def _false_discovery_rate(tp, fp, fn, tn):
function _false_omission_rate (line 361) | def _false_omission_rate(tp, fp, fn, tn):
function _positive_likelihood_ratio (line 365) | def _positive_likelihood_ratio(tp, fp, fn, tn):
function _negative_likelihood_ratio (line 369) | def _negative_likelihood_ratio(tp, fp, fn, tn):
function fbeta_score (line 373) | def fbeta_score(
function f1_score (line 397) | def f1_score(
function dice (line 419) | def dice(
function iou_score (line 440) | def iou_score(
function accuracy (line 462) | def accuracy(
function sensitivity (line 484) | def sensitivity(
function specificity (line 506) | def specificity(
function balanced_accuracy (line 528) | def balanced_accuracy(
function positive_predictive_value (line 550) | def positive_predictive_value(
function negative_predictive_value (line 572) | def negative_predictive_value(
function false_negative_rate (line 594) | def false_negative_rate(
function false_positive_rate (line 616) | def false_positive_rate(
function false_discovery_rate (line 638) | def false_discovery_rate(
function false_omission_rate (line 660) | def false_omission_rate(
function positive_likelihood_ratio (line 682) | def positive_likelihood_ratio(
function negative_likelihood_ratio (line 704) | def negative_likelihood_ratio(
FILE: data_schedule/vis/mapper.py
class VIS_Video_EvalMapper (line 21) | class VIS_Video_EvalMapper(VIS_EvalMapper):
method __init__ (line 22) | def __init__(self,
method _call (line 36) | def _call(self, data_dict):
class VIS_Video_or_Step_To_Clip_TrainMapper (line 60) | class VIS_Video_or_Step_To_Clip_TrainMapper(VIS_TrainMapper):
method __init__ (line 61) | def __init__(self,
method _call (line 79) | def _call(self, data_dict):
FILE: data_schedule/vis/mapper_utils.py
class VIS_Mapper (line 8) | class VIS_Mapper(Mapper):
method __init__ (line 9) | def __init__(self,
class VIS_TrainMapper (line 15) | class VIS_TrainMapper(VIS_Mapper):
method __init__ (line 17) | def __init__(self,
method map_to_frame_targets (line 26) | def map_to_frame_targets(self, clip_targets):
method map_global_targets_to_local_targets (line 51) | def map_global_targets_to_local_targets(self, ret):
class VIS_EvalMapper (line 62) | class VIS_EvalMapper(VIS_Mapper):
method __init__ (line 63) | def __init__(self,
FILE: data_schedule/vis/polyp/evals.py
function polyp_metric_aggregator (line 10) | def polyp_metric_aggregator(metrics_by_vid_frame, dataset_meta, eval_met...
FILE: data_schedule/vis/polyp/polyp_dataset.py
function polyp_train (line 14) | def polyp_train(step_size,
function polyp_evaluate (line 46) | def polyp_evaluate(eval_video_ids,
FILE: data_schedule/vis/polyp/polyp_utils.py
function get_frames (line 103) | def get_frames(frames_path, video_id, frames):
function get_frames_mask (line 105) | def get_frames_mask(mask_path, video_id, frames):
FILE: data_schedule/vis/vis_aug_eval.py
class RandomResize (line 10) | class RandomResize:
method __init__ (line 11) | def __init__(self, sizes, max_size=None):
method __call__ (line 16) | def __call__(self, ret):
class VideoToPIL (line 46) | class VideoToPIL:
method __call__ (line 47) | def __call__(self, ret):
class VideoToTensor (line 56) | class VideoToTensor:
method __call__ (line 57) | def __call__(self, ret):
class WeakPolyP_EvalAug (line 70) | class WeakPolyP_EvalAug:
method __init__ (line 71) | def __init__(self, configs) -> None:
method __call__ (line 77) | def __call__(self, ret):
FILE: data_schedule/vis/vis_aug_train.py
class RandomRotate90 (line 22) | class RandomRotate90:
method __init__ (line 23) | def __init__(self) -> None:
method __call__ (line 28) | def __call__(self, ret):
class ComputeBox (line 51) | class ComputeBox:
method __call__ (line 52) | def __call__(self, ret):
class VideoToTensor (line 64) | class VideoToTensor:
method __call__ (line 65) | def __call__(self, ret):
class Compose (line 71) | class Compose:
method __init__ (line 72) | def __init__(self, transforms):
method __call__ (line 75) | def __call__(self, ret):
method __repr__ (line 80) | def __repr__(self):
class WeakPolyP_TrainAug (line 90) | class WeakPolyP_TrainAug:
method __init__ (line 91) | def __init__(self, configs) -> None:
method __call__ (line 102) | def __call__(self, ret):
class WeakPolyP_TrainAug_RotateImageToClip (line 131) | class WeakPolyP_TrainAug_RotateImageToClip:
method __init__ (line 132) | def __init__(self, configs) -> None:
method apply_random_sequence_shuffle (line 149) | def apply_random_sequence_shuffle(self, images, instance_masks):
method __call__ (line 156) | def __call__(self, ret):
class ImageToSeqAugmenter (line 197) | class ImageToSeqAugmenter(object):
method __init__ (line 198) | def __init__(self, perspective=True, affine=True, motion_blur=True,
method condense_masks (line 233) | def condense_masks(instance_masks):
method expand_masks (line 241) | def expand_masks(condensed_mask, num_instances):
method __call__ (line 244) | def __call__(self, image, masks=None, boxes=None): # n h w
FILE: data_schedule/vis/vis_aug_utils.py
function get_size_with_aspect_ratio (line 11) | def get_size_with_aspect_ratio(image_size, size, max_size=None):
function get_tgt_size (line 31) | def get_tgt_size(image_size, size, max_size=None):
function pil_torch_to_numpy (line 37) | def pil_torch_to_numpy(video, masks, has_ann, float_image=True):
function numpy_to_pil_torch (line 62) | def numpy_to_pil_torch(video, masks, has_ann):
FILE: data_schedule/vis/vis_frame_sampler.py
class Naive_ReferenceFrame_FrameSampler (line 14) | class Naive_ReferenceFrame_FrameSampler:
method __init__ (line 15) | def __init__(self, sampler_configs, dataset_meta, **kwargs):
method __call__ (line 26) | def __call__(self,
FILE: handle_vps.py
function get_frames_mask (line 80) | def get_frames_mask(mask_path, video_id, frames):
FILE: main.py
function _highlight (line 13) | def _highlight(code, filename):
class _ColorfulFormatter (line 26) | class _ColorfulFormatter(logging.Formatter):
method __init__ (line 27) | def __init__(self, *args, **kwargs):
method formatMessage (line 33) | def formatMessage(self, record):
function set_logging_file (line 46) | def set_logging_file(output_dir, file_name, mode='a'):
function init_process_group_and_set_device (line 63) | def init_process_group_and_set_device(world_size, process_id, device_id):
function run (line 81) | def run(rank, configs, world_size):
FILE: models/VIS/BackboneEncoderDecoder_WithScaleConsistency.py
class BackboneEncoderDecoder_WithScaleConsistency (line 24) | class BackboneEncoderDecoder_WithScaleConsistency(nn.Module):
method __init__ (line 25) | def __init__(
method device (line 52) | def device(self):
method model_preds (line 55) | def model_preds(self, videos, video_aux_dict,):
method forward (line 72) | def forward(self, batch_dict):
method sample (line 89) | def sample(self, batch_dict):
method get_optim_params_group (line 121) | def get_optim_params_group(model, configs):
function backbone_encoder_decoder_withScaleConsistency (line 180) | def backbone_encoder_decoder_withScaleConsistency(configs, device):
FILE: models/VIS/aux_mapper.py
class AUXMapper_v1 (line 12) | class AUXMapper_v1:
method __init__ (line 13) | def __init__(self, aux_configs):
method mapper (line 23) | def mapper(self, data_dict, mode,):
method collate (line 40) | def collate(self, batch_dict, mode, max_stride):
method collate_video_dict (line 82) | def collate_video_dict(self, batch_dict, max_stride):
method collate_frame_targets (line 105) | def collate_frame_targets(self, frame_targets, frame_has_ann, pad_H, p...
method collate_targets (line 124) | def collate_targets(self, targets, pad_H, pad_W, pad_T):
method visualize_input_target_for_debug_data (line 144) | def visualize_input_target_for_debug_data(self, ret):
FILE: models/backbone/pvtv2.py
class DWConv (line 8) | class DWConv(nn.Module):
method __init__ (line 9) | def __init__(self, dim):
method forward (line 13) | def forward(self, x, H, W):
class Mlp (line 20) | class Mlp(nn.Module):
method __init__ (line 21) | def __init__(self, in_features, hidden_features):
method forward (line 27) | def forward(self, x, H, W):
class Attention (line 33) | class Attention(nn.Module):
method __init__ (line 34) | def __init__(self, dim, num_heads, sr_ratio):
method forward (line 48) | def forward(self, x, H, W):
class Block (line 67) | class Block(nn.Module):
method __init__ (line 68) | def __init__(self, dim, num_heads, mlp_ratio, drop_path, sr_ratio):
method forward (line 76) | def forward(self, x, H, W):
class OverlapPatchEmbed (line 81) | class OverlapPatchEmbed(nn.Module):
method __init__ (line 82) | def __init__(self, patch_size, stride, in_chans, embed_dim):
method forward (line 87) | def forward(self, x):
class PVT (line 94) | class PVT(nn.Module):
method __init__ (line 95) | def __init__(self, embed_dims, mlp_ratios, depths, snapshot, sr_ratios...
method no_weight_decay (line 129) | def no_weight_decay(self):
method forward (line 132) | def forward(self, x):
class PVT_V2 (line 168) | class PVT_V2(nn.Module):
method __init__ (line 169) | def __init__(self, configs) -> None:
method forward (line 189) | def forward(self, x):
method num_parameters (line 212) | def num_parameters(self):
class Video2D_PVT_V2 (line 216) | class Video2D_PVT_V2(nn.Module):
method __init__ (line 217) | def __init__(self, configs) -> None:
method forward (line 231) | def forward(self, x):
method num_parameters (line 240) | def num_parameters(self):
FILE: models/backbone/res2net.py
class Bottle2neck (line 10) | class Bottle2neck(nn.Module):
method __init__ (line 12) | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWi...
method forward (line 33) | def forward(self, x):
class Res2Net (line 51) | class Res2Net(nn.Module):
method __init__ (line 52) | def __init__(self, layers, snapshot, baseWidth=26, scale=4):
method _make_layer (line 77) | def _make_layer(self, block, planes, blocks, stride=1):
method forward (line 92) | def forward(self, x):
method initialize (line 101) | def initialize(self):
class Res2Net_50_EachFrame (line 105) | class Res2Net_50_EachFrame(nn.Module):
method __init__ (line 106) | def __init__(self, configs) -> None:
method forward (line 129) | def forward(self, x):
method num_parameters (line 142) | def num_parameters(self):
FILE: models/backbone/utils.py
class VideoMultiscale_Shape (line 2) | class VideoMultiscale_Shape:
method __init__ (line 3) | def __init__(self, temporal_stride, spatial_stride, dim) -> None:
method set_multiscale_same_dim (line 9) | def set_multiscale_same_dim(shape_by_dim, same_dim):
class ImageMultiscale_Shape (line 16) | class ImageMultiscale_Shape:
method __init__ (line 17) | def __init__(self, spatial_stride, dim) -> None:
FILE: models/decoder/mask2former_video.py
function calculate_uncertainty (line 26) | def calculate_uncertainty(logits):
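`loss_point_mask_dice_ce` evaluates the mask losses on sampled points rather than full masks, using `get_uncertain_point_coords_with_randomness`, `point_sample` and `calculate_uncertainty`; their bodies are not shown here. The sketch below assumes the PointRend-style scheme these names usually denote: oversample random points, score them with uncertainty −|logit| (highest near the decision boundary), keep the top `importance_ratio` fraction, and fill the rest with uniform random points. All helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def point_sample(inputs: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
    """inputs: (N, C, H, W); coords: (N, P, 2) as (x, y) in [0, 1]. Returns (N, C, P)."""
    grid = 2.0 * coords.unsqueeze(2) - 1.0  # (N, P, 1, 2) in [-1, 1]
    return F.grid_sample(inputs, grid, align_corners=False).squeeze(3)


def uncertainty(logits: torch.Tensor) -> torch.Tensor:
    # logits near 0 are the least certain, so they get the highest score
    return -logits.abs()


def uncertain_point_coords(logits, num_points, oversample_ratio, importance_ratio):
    """logits: (N, 1, H, W). Returns (N, num_points, 2) sampling coordinates."""
    n = logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    coords = torch.rand(n, num_sampled, 2, device=logits.device)
    point_unc = uncertainty(point_sample(logits, coords))  # (N, 1, S)
    num_uncertain = int(importance_ratio * num_points)
    num_random = num_points - num_uncertain
    idx = torch.topk(point_unc[:, 0, :], k=num_uncertain, dim=1).indices  # (N, k)
    picked = torch.gather(coords, 1, idx.unsqueeze(-1).expand(-1, -1, 2))
    if num_random > 0:
        extra = torch.rand(n, num_random, 2, device=logits.device)
        picked = torch.cat([picked, extra], dim=1)
    return picked
```

As in the excerpt, the coordinates would be chosen under `torch.no_grad()` and then used to sample both predicted logits and ground-truth masks, so the loss cost depends on `num_points` rather than on H×W.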
class Video_SetMatchingLoss (line 43) | class Video_SetMatchingLoss(nn.Module):
method __init__ (line 44) | def __init__(self,
method device (line 61) | def device(self,):
method compute_loss (line 64) | def compute_loss(self,
method matching (line 131) | def matching(self, layer_out, targets):
method loss_mask_dice_ce (line 193) | def loss_mask_dice_ce(self, outputs, targets, indices, num_objs, loss_...
method loss_point_mask_dice_ce (line 207) | def loss_point_mask_dice_ce(self, outputs, targets, indices, num_objs,...
method loss_class_ce (line 256) | def loss_class_ce(self, outputs, targets, indices, num_objs, loss_extr...
method loss_box_l1_giou (line 274) | def loss_box_l1_giou(self, outputs, targets, indices, num_objs, loss_e...
method _get_src_permutation_idx (line 296) | def _get_src_permutation_idx(self, indices):
method _get_tgt_permutation_idx (line 302) | def _get_tgt_permutation_idx(self, indices):
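`loss_class_ce` uses `_get_src_permutation_idx` to turn per-video matching results into (batch, query) index pairs, writes the matched class ids there, and labels every other query with the "no-object" index `num_classes`. The sketch below reproduces that flow outside the class. The `eos_coef` down-weighting of the no-object class is an assumption about how `self.empty_weight` is built (its construction is not in this excerpt); the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def src_permutation_idx(indices):
    """indices: list over videos of (query_ids, target_ids). Returns (batch_idx, query_idx)."""
    batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    src_idx = torch.cat([src for src, _ in indices])
    return batch_idx, src_idx


def build_class_targets(indices, classes_per_video, num_queries, num_classes):
    batch_idx, src_idx = src_permutation_idx(indices)
    # class of the target each matched query was assigned to, in matching order
    matched = torch.cat([cls[tgt] for cls, (_, tgt) in zip(classes_per_video, indices)])
    target = torch.full((len(indices), num_queries), num_classes, dtype=torch.int64)
    target[batch_idx, src_idx] = matched
    return target


def class_ce_loss(pred_class, indices, classes_per_video, num_classes, eos_coef=0.1):
    """pred_class: (b, nq, num_classes + 1) logits."""
    target = build_class_targets(indices, classes_per_video, pred_class.shape[1], num_classes)
    weight = torch.ones(num_classes + 1, device=pred_class.device)
    weight[-1] = eos_coef  # hypothetical down-weighting of the no-object class
    return F.cross_entropy(pred_class.transpose(1, 2), target.to(pred_class.device), weight)


# Two videos, four queries, one foreground class (e.g. polyp): no-object index is 1
indices = [(torch.tensor([2]), torch.tensor([0])), (torch.tensor([0, 3]), torch.tensor([1, 0]))]
classes = [torch.tensor([0]), torch.tensor([0, 0])]
print(build_class_targets(indices, classes, num_queries=4, num_classes=1))
```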
class Video_MaskedAttn_MultiscaleMaskDecoder_v3 (line 309) | class Video_MaskedAttn_MultiscaleMaskDecoder_v3(nn.Module):
method __init__ (line 310) | def __init__(self,
method device (line 362) | def device(self,):
method get_memories_and_mask_features (line 365) | def get_memories_and_mask_features(self, multiscales):
method forward (line 375) | def forward(self,
method forward_heads (line 425) | def forward_heads(self, temporal_query_feats, mask_features, attn_mas...
method compute_loss (line 448) | def compute_loss(self, outputs, targets, video_aux_dict, **kwargs):
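In `Video_MaskedAttn_MultiscaleMaskDecoder_v3.forward`, any query whose attention mask blocks every memory position is reset to attend everywhere before cross-attention. The sketch below shows why: a fully blocked row becomes all −inf scores, and softmax over it yields NaN. The way the boolean mask is built from mask logits (sigmoid < 0.5, repeated per head, batch-major `b*nheads` layout as `nn.MultiheadAttention` expects) is an assumption based on the Mask2Former convention, since the excerpt of `forward_heads` is truncated; resizing to `attn_mask_target_size` is omitted, and the function names are hypothetical.

```python
import torch


def mask_logits_to_attn_mask(mask_logits: torch.Tensor, nheads: int, threshold: float = 0.5) -> torch.Tensor:
    """mask_logits: (b, nq, t, h, w) -> bool (b*nheads, nq, t*h*w); True = position blocked."""
    blocked = mask_logits.sigmoid().flatten(2) < threshold  # b nq thw
    return blocked.unsqueeze(1).repeat(1, nheads, 1, 1).flatten(0, 1)


def unblock_empty_rows(attn_mask: torch.Tensor) -> torch.Tensor:
    """Let queries whose every key is blocked attend to all keys instead (avoids NaN softmax)."""
    fully_blocked = attn_mask.sum(-1) == attn_mask.shape[-1]  # (b*nheads, nq)
    attn_mask = attn_mask.clone()
    attn_mask[fully_blocked] = False
    return attn_mask


logits = torch.full((1, 2, 1, 2, 2), -5.0)  # 1 video, 2 queries, 1 frame, 2x2 memory
logits[0, 0, 0, 0, 0] = 5.0                 # query 0 predicts one foreground pixel
mask = mask_logits_to_attn_mask(logits, nheads=2)
print(unblock_empty_rows(mask)[0])          # query 1 row is now all False
```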
FILE: models/encoder/input_projs.py
class VideoConv3d_TextLinear (line 10) | class VideoConv3d_TextLinear(nn.Module):
method __init__ (line 15) | def __init__(self,
method forward (line 50) | def forward(self, multiscales, text_dict):
class VideoConv2d_TextLinear (line 70) | class VideoConv2d_TextLinear(nn.Module):
method __init__ (line 75) | def __init__(self,
method forward (line 109) | def forward(self, multiscales, text_dict):
class ImageConv_MultiscaleProj (line 131) | class ImageConv_MultiscaleProj(nn.Module):
method __init__ (line 132) | def __init__(self,
method forward (line 159) | def forward(self, multiscales):
class Video2D_ImageConv_MultiscaleProj (line 171) | class Video2D_ImageConv_MultiscaleProj(nn.Module):
method __init__ (line 172) | def __init__(self,
method forward (line 180) | def forward(self, multiscales):
class VideoConv_MultiscaleProj (line 190) | class VideoConv_MultiscaleProj(nn.Module):
method __init__ (line 191) | def __init__(self,
method forward (line 218) | def forward(self, multiscales):
class FrameQueryLinear_TextLinear (line 234) | class FrameQueryLinear_TextLinear(nn.Module):
method __init__ (line 235) | def __init__(self,
method forward (line 253) | def forward(self, frame_query, text_dict):
class VideoConv3d_FrameQueryLinear_TextLinear (line 268) | class VideoConv3d_FrameQueryLinear_TextLinear(nn.Module):
method __init__ (line 269) | def __init__(self,
method forward (line 289) | def forward(self, mask_feat, frame_query, text_dict):
class VideoConv3d_FrameQueryLinear (line 302) | class VideoConv3d_FrameQueryLinear(nn.Module):
method __init__ (line 307) | def __init__(self,
method forward (line 342) | def forward(self, multiscales, frame_queries):
class FrameQueryLinear (line 358) | class FrameQueryLinear(nn.Module):
method __init__ (line 359) | def __init__(self,
method forward (line 370) | def forward(self, frame_query):
FILE: models/encoder/localGlobal.py
class MSDeformAttnTransformerEncoderOnly (line 21) | class MSDeformAttnTransformerEncoderOnly(nn.Module):
method __init__ (line 22) | def __init__(self, d_model=256, nhead=8,
method _reset_parameters (line 57) | def _reset_parameters(self):
method get_valid_ratio (line 66) | def get_valid_ratio(self, mask):
method forward (line 75) | def forward(self,
class MSDeformAttnTransformerEncoderLayer (line 115) | class MSDeformAttnTransformerEncoderLayer(nn.Module):
method __init__ (line 116) | def __init__(self,
method with_pos_embed (line 160) | def with_pos_embed(tensor, pos):
method forward_ffn (line 163) | def forward_ffn(self, src):
method get_attn_mask (line 170) | def get_attn_mask(self, frame_query_feats, src, spatial_shapes, level_...
method forward (line 185) | def forward(self,
class MSDeformAttnTransformerEncoder (line 238) | class MSDeformAttnTransformerEncoder(nn.Module):
method __init__ (line 239) | def __init__(self,
method get_reference_points (line 252) | def get_reference_points(spatial_shapes, valid_ratios, device):
method forward (line 267) | def forward(self,
class Video_Deform2D_DividedTemporal_MultiscaleEncoder_localGlobal (line 305) | class Video_Deform2D_DividedTemporal_MultiscaleEncoder_localGlobal(nn.Mo...
method __init__ (line 306) | def __init__(
method forward (line 369) | def forward(self,
FILE: models/encoder/neighborhood_qk.py
class NeighborhoodAttention2D_qk (line 13) | class NeighborhoodAttention2D_qk(nn.Module):
method __init__ (line 18) | def __init__(
method forward (line 57) | def forward(self,
method extra_repr (line 99) | def extra_repr(self) -> str:
class NA_qk_Layer (line 107) | class NA_qk_Layer(nn.Module):
method __init__ (line 109) | def __init__(self, d_model, configs):
method forward (line 122) | def forward(self, tgt=None, scale_shapes=None, level_start_idxs=None, ...
class NA_qk_Layer_v2 (line 146) | class NA_qk_Layer_v2(nn.Module):
method __init__ (line 148) | def __init__(self, configs):
method forward (line 158) | def forward(self,
FILE: models/encoder/ops/attention.py
function exists (line 11) | def exists(val):
function uniq (line 15) | def uniq(arr):
function default (line 19) | def default(val, d):
function max_neg_value (line 25) | def max_neg_value(t):
function init_ (line 29) | def init_(tensor):
class GEGLU (line 37) | class GEGLU(nn.Module):
method __init__ (line 38) | def __init__(self, dim_in, dim_out):
method forward (line 42) | def forward(self, x):
class FeedForward (line 47) | class FeedForward(nn.Module):
method __init__ (line 48) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 63) | def forward(self, x):
function zero_module (line 67) | def zero_module(module):
function Normalize (line 76) | def Normalize(in_channels):
class LinearAttention (line 80) | class LinearAttention(nn.Module):
method __init__ (line 81) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 88) | def forward(self, x):
class SpatialSelfAttention (line 99) | class SpatialSelfAttention(nn.Module):
method __init__ (line 100) | def __init__(self, in_channels):
method forward (line 126) | def forward(self, x):
class CrossAttention (line 152) | class CrossAttention(nn.Module):
method __init__ (line 153) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 170) | def forward(self, x, context=None, mask=None):
class BasicTransformerBlock (line 195) | class BasicTransformerBlock(nn.Module):
method __init__ (line 196) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
method forward (line 206) | def forward(self, x, context=None):
function _is_power_of_2 (line 213) | def _is_power_of_2(n):
class DeformAttn (line 218) | class DeformAttn(nn.Module):
method __init__ (line 219) | def __init__(self,
method _reset_parameters (line 253) | def _reset_parameters(self):
method forward (line 269) | def forward(self, query, reference_points,
class ContextuallSelfAttention (line 323) | class ContextuallSelfAttention(nn.Module):
method __init__ (line 324) | def __init__(self,
method forward (line 355) | def forward(self,
class BasicTransformerBlock_v2 (line 412) | class BasicTransformerBlock_v2(nn.Module):
method __init__ (line 413) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
method forward (line 420) | def forward(self, x, context=None):
class SpatialTransformer (line 427) | class SpatialTransformer(nn.Module):
method __init__ (line 435) | def __init__(self, in_channels, n_heads, d_head,
method forward (line 459) | def forward(self, x, context=None):
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-311/functions/ms_deform_attn_func.py
class MSDeformAttnFunction (line 21) | class MSDeformAttnFunction(Function):
method forward (line 23) | def forward(ctx, value, value_spatial_shapes, value_level_start_index,...
method backward (line 32) | def backward(ctx, grad_output):
function ms_deform_attn_core_pytorch (line 41) | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_lo...
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-311/modules/ms_deform_attn.py
function _is_power_of_2 (line 25) | def _is_power_of_2(n):
class MSDeformAttn (line 31) | class MSDeformAttn(nn.Module):
method __init__ (line 32) | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
method _reset_parameters (line 63) | def _reset_parameters(self):
method forward (line 79) | def forward(self, query, reference_points, input_flatten, input_spatia...
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-38/functions/ms_deform_attn_func.py
class MSDeformAttnFunction (line 21) | class MSDeformAttnFunction(Function):
method forward (line 23) | def forward(ctx, value, value_spatial_shapes, value_level_start_index,...
method backward (line 32) | def backward(ctx, grad_output):
function ms_deform_attn_core_pytorch (line 41) | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_lo...
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-38/modules/ms_deform_attn.py
function _is_power_of_2 (line 25) | def _is_power_of_2(n):
class MSDeformAttn (line 31) | class MSDeformAttn(nn.Module):
method __init__ (line 32) | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
method _reset_parameters (line 63) | def _reset_parameters(self):
method forward (line 79) | def forward(self, query, reference_points, input_flatten, input_spatia...
FILE: models/encoder/ops/functions/ms_deform_attn_func.py
class MSDeformAttnFunction (line 21) | class MSDeformAttnFunction(Function):
method forward (line 23) | def forward(ctx, value, value_spatial_shapes, value_level_start_index,...
method backward (line 32) | def backward(ctx, grad_output):
function ms_deform_attn_core_pytorch (line 41) | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_lo...
FILE: models/encoder/ops/modules/frame_query_ss2d.py
class SS2D (line 20) | class SS2D(nn.Module):
method __init__ (line 21) | def __init__(
method dt_init (line 93) | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0...
method A_log_init (line 120) | def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
method D_init (line 137) | def D_init(d_inner, copies=1, device=None, merge=True):
method forward_corev0 (line 148) | def forward_corev0(self, x: torch.Tensor):
method forward (line 187) | def forward(self, x: torch.Tensor, **kwargs):
class SS2D_FrameQuery (line 205) | class SS2D_FrameQuery(nn.Module):
method __init__ (line 206) | def __init__(self, configs,):
method forward (line 225) | def forward(self,
class FrameQuery_SS2DLayer (line 247) | class FrameQuery_SS2DLayer(nn.Module):
method __init__ (line 248) | def __init__(self,
method forward (line 264) | def forward(self,
class TemporalQuery_CrossSelf (line 281) | class TemporalQuery_CrossSelf(nn.Module):
method __init__ (line 282) | def __init__(self, configs) -> None:
method forward (line 299) | def forward(self,
class SS2D_FrameQuery_v2 (line 326) | class SS2D_FrameQuery_v2(nn.Module):
method __init__ (line 327) | def __init__(self, configs,):
method forward (line 348) | def forward(self,
class FrameQuery_SS2DLayer_v2 (line 375) | class FrameQuery_SS2DLayer_v2(nn.Module):
method __init__ (line 376) | def __init__(self,
method forward (line 390) | def forward(self,
class Hilbert_2DSelectiveScan (line 405) | class Hilbert_2DSelectiveScan(nn.Module):
method __init__ (line 406) | def __init__(
method dt_init (line 476) | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0...
method A_log_init (line 503) | def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
method D_init (line 520) | def D_init(d_inner, copies=1, device=None, merge=True):
method forward_corev0 (line 531) | def forward_corev0(self, x: torch.Tensor, hilbert_curve):
method forward (line 653) | def forward(self, x: torch.Tensor, hilbert_curve, **kwargs):
class SS2D_FrameQuery_hilbert (line 672) | class SS2D_FrameQuery_hilbert(nn.Module):
method __init__ (line 673) | def __init__(self, configs,):
method forward (line 695) | def forward(self,
class FrameQuery_SS2DLayer_hilbert (line 724) | class FrameQuery_SS2DLayer_hilbert(nn.Module):
method __init__ (line 725) | def __init__(self,
method forward (line 738) | def forward(self,
FILE: models/encoder/ops/modules/ms_deform_attn.py
function _is_power_of_2 (line 25) | def _is_power_of_2(n):
class MSDeformAttn (line 31) | class MSDeformAttn(nn.Module):
method __init__ (line 32) | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
method _reset_parameters (line 63) | def _reset_parameters(self):
method forward (line 79) | def forward(self, query, reference_points, input_flatten, input_spatia...
FILE: models/encoder/ops/setup.py
function get_extensions (line 23) | def get_extensions():
FILE: models/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp
function ms_deform_attn_cpu_forward (line 17) | at::Tensor
function ms_deform_attn_cpu_backward (line 29) | std::vector<at::Tensor>
FILE: models/encoder/ops/src/ms_deform_attn.h
function im2col_step (line 27) | int im2col_step)
FILE: models/encoder/ops/src/vision.cpp
function PYBIND11_MODULE (line 13) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: models/encoder/ops/test.py
function check_forward_equal_with_pytorch_double (line 32) | def check_forward_equal_with_pytorch_double():
function check_forward_equal_with_pytorch_float (line 48) | def check_forward_equal_with_pytorch_float():
function check_gradient_numerical (line 63) | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_...
FILE: models/layers/anyc_trans.py
class MLP (line 10) | class MLP(nn.Module):
method __init__ (line 12) | def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
method forward (line 18) | def forward(self, x):
class Linear_NormAct (line 23) | class Linear_NormAct(nn.Linear):
method __init__ (line 24) | def __init__(self, *args, **kwargs):
method forward (line 38) | def forward(self, x):
class Conv2d_NormAct (line 46) | class Conv2d_NormAct(torch.nn.Conv2d):
method __init__ (line 47) | def __init__(self, *args, **kwargs):
method forward (line 65) | def forward(self, x):
class Conv3d_NormAct (line 74) | class Conv3d_NormAct(torch.nn.Conv3d):
method __init__ (line 75) | def __init__(self, *args, **kwargs):
method forward (line 89) | def forward(self, x):
FILE: models/layers/decoder_layers.py
class SelfAttentionLayer (line 10) | class SelfAttentionLayer(nn.Module):
method __init__ (line 12) | def __init__(self, d_model, nhead, dropout=0.0,
method _reset_parameters (line 25) | def _reset_parameters(self):
method with_pos_embed (line 30) | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
method forward_post (line 33) | def forward_post(self, tgt,
method forward_pre (line 48) | def forward_pre(self, tgt,
method forward (line 63) | def forward(self, tgt,
class CrossAttentionLayer (line 73) | class CrossAttentionLayer(nn.Module):
method __init__ (line 75) | def __init__(self, d_model, nhead, dropout=0.0,
method _reset_parameters (line 88) | def _reset_parameters(self):
method with_pos_embed (line 93) | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
method forward_post (line 96) | def forward_post(self, tgt, memory,
method forward_pre (line 112) | def forward_pre(self, tgt, memory,
method forward (line 126) | def forward(self, tgt, memory,
class FFNLayer (line 138) | class FFNLayer(nn.Module):
method __init__ (line 140) | def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
method _reset_parameters (line 155) | def _reset_parameters(self):
method with_pos_embed (line 160) | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
method forward_post (line 163) | def forward_post(self, tgt):
method forward_pre (line 169) | def forward_pre(self, tgt):
method forward (line 175) | def forward(self, tgt):
FILE: models/layers/gilbert/demo/script.js
function download (line 25) | function download(filename, text) {
function dl_svg (line 37) | function dl_svg() {
function update_color (line 50) | function update_color() {
function update_wh (line 57) | function update_wh(w,h) {
function update_preset (line 71) | function update_preset() {
function update_num (line 106) | function update_num() {
function draw_curve (line 123) | function draw_curve() {
function init (line 214) | function init() {
FILE: models/layers/gilbert/demo/two.js
function decomposeMatrix (line 92) | function decomposeMatrix(matrix, b, c, d, e, f) {
function setMatrix (line 112) | function setMatrix(matrix) {
function getComputedMatrix (line 115) | function getComputedMatrix(object, matrix) {
function lerp (line 142) | function lerp(a, b, t) {
function getPoT (line 146) | function getPoT(value) {
function mod (line 153) | function mod(v, l) {
function toFixed (line 161) | function toFixed(v) {
method constructor (line 184) | constructor() {
method addEventListener (line 186) | addEventListener(name, handler) {
method on (line 192) | on() {
method bind (line 195) | bind() {
method removeEventListener (line 198) | removeEventListener(name, handler) {
method off (line 227) | off() {
method unbind (line 230) | unbind() {
method dispatchEvent (line 233) | dispatchEvent(name) {
method trigger (line 246) | trigger() {
method listen (line 249) | listen(obj, name, handler) {
method ignore (line 262) | ignore(obj, name, handler) {
method constructor (line 325) | constructor(x = 0, y = 0) {
method add (line 333) | static add(v1, v2) {
method sub (line 336) | static sub(v1, v2) {
method subtract (line 339) | static subtract(v1, v2) {
method ratioBetween (line 342) | static ratioBetween(v1, v2) {
method angleBetween (line 345) | static angleBetween(v1, v2) {
method distanceBetween (line 355) | static distanceBetween(v1, v2) {
method distanceBetweenSquared (line 358) | static distanceBetweenSquared(v1, v2) {
method set (line 363) | set(x, y) {
method copy (line 368) | copy(v) {
method clear (line 373) | clear() {
method clone (line 378) | clone() {
method add (line 381) | add(x, y) {
method addSelf (line 398) | addSelf(v) {
method sub (line 401) | sub(x, y) {
method subtract (line 418) | subtract() {
method subSelf (line 421) | subSelf(v) {
method subtractSelf (line 424) | subtractSelf(v) {
method multiply (line 427) | multiply(x, y) {
method multiplySelf (line 444) | multiplySelf(v) {
method multiplyScalar (line 447) | multiplyScalar(s) {
method divide (line 450) | divide(x, y) {
method divideSelf (line 473) | divideSelf(v) {
method divideScalar (line 476) | divideScalar(s) {
method negate (line 479) | negate() {
method dot (line 482) | dot(v) {
method length (line 485) | length() {
method lengthSquared (line 488) | lengthSquared() {
method normalize (line 491) | normalize() {
method distanceTo (line 494) | distanceTo(v) {
method distanceToSquared (line 497) | distanceToSquared(v) {
method setLength (line 502) | setLength(l) {
method equals (line 505) | equals(v, eps) {
method lerp (line 509) | lerp(v, t) {
method isZero (line 514) | isZero(eps) {
method toString (line 518) | toString() {
method toObject (line 521) | toObject() {
method rotate (line 524) | rotate(radians) {
method constructor (line 554) | constructor(x = 0, y = 0, ax = 0, ay = 0, bx = 0, by = 0, command = Comm...
method makeBroadcast (line 565) | static makeBroadcast(scope) {
method copy (line 573) | copy(v) {
method clone (line 607) | clone() {
method toObject (line 610) | toObject() {
method toString (line 627) | toString() {
function getComponentOnCubicBezier (line 796) | function getComponentOnCubicBezier(t, a, b, c, d) {
function subdivide (line 800) | function subdivide(x1, y1, x2, y2, x3, y3, x4, y4, limit) {
function getCurveLength (line 815) | function getCurveLength(x1, y1, x2, y2, x3, y3, x4, y4, limit) {
function getCurveBoundingBox (line 833) | function getCurveBoundingBox(x1, y1, x2, y2, x3, y3, x4, y4) {
function integrate (line 890) | function integrate(f, a, b, n) {
function getCurveFromPoints (line 898) | function getCurveFromPoints(points, closed2) {
function getControlPoints (line 911) | function getControlPoints(a, b, c) {
function getReflection (line 944) | function getReflection(a, b, relative) {
function getAnchorsFromArcData (line 950) | function getAnchorsFromArcData(center, xAxisRotation, rx, ry, ts, td, cc...
function getBackingStoreRatio (line 969) | function getBackingStoreRatio(ctx) {
function getRatio (line 972) | function getRatio(ctx) {
function isArrayLike (line 978) | function isArrayLike(collection) {
method constructor (line 1038) | constructor() {
method flagReset (line 1044) | flagReset() {
method constructor (line 1096) | constructor(a, b, c, d, e, f) {
method Multiply (line 1107) | static Multiply(A, B, C) {
method set (line 1135) | set(a, b, c, d, e, f, g, h, i) {
method copy (line 1159) | copy(m) {
method identity (line 1172) | identity() {
method multiply (line 1184) | multiply(a, b, c, d, e, f, g, h, i) {
method inverse (line 1229) | inverse(out) {
method scale (line 1254) | scale(sx, sy) {
method rotate (line 1261) | rotate(Number2) {
method translate (line 1266) | translate(x, y) {
method skewX (line 1269) | skewX(Number2) {
method skewY (line 1273) | skewY(Number2) {
method toString (line 1277) | toString(fullMatrix) {
method toTransformArray (line 1282) | toTransformArray(fullMatrix, output) {
method toArray (line 1337) | toArray(fullMatrix, output) {
method toObject (line 1392) | toObject() {
method clone (line 1398) | clone() {
method constructor (line 1427) | constructor() {
method renderer (line 1443) | get renderer() {
method renderer (line 1446) | set renderer(v) {
method translation (line 1449) | get translation() {
method translation (line 1452) | set translation(v) {
method addTo (line 1455) | addTo(group) {
method remove (line 1459) | remove() {
method clone (line 1466) | clone(parent) {
method _update (line 1481) | _update(bubbles) {
method flagReset (line 1500) | flagReset() {
function FlagMatrix (line 1589) | function FlagMatrix() {
method _bound (line 1596) | get _bound() {
method _bound (line 1599) | set _bound(v) {
method addEventListener (line 1602) | addEventListener() {
method on (line 1605) | on() {
method bind (line 1608) | bind() {
method removeEventListener (line 1611) | removeEventListener() {
method off (line 1614) | off() {
method unbind (line 1617) | unbind() {
method dispatchEvent (line 1620) | dispatchEvent() {
method trigger (line 1623) | trigger() {
method listen (line 1626) | listen() {
method ignore (line 1629) | ignore() {
method constructor (line 1632) | constructor() {
method pop (line 1642) | pop() {
method shift (line 1647) | shift() {
method push (line 1652) | push() {
method unshift (line 1657) | unshift() {
method splice (line 1662) | splice() {
method sort (line 1672) | sort() {
method reverse (line 1677) | reverse() {
method indexOf (line 1682) | indexOf() {
method map (line 1685) | map(func, scope) {
method constructor (line 1704) | constructor(children) {
method attach (line 1711) | attach(children) {
method detach (line 1720) | detach(children) {
method constructor (line 1755) | constructor(children) {
method InsertChildren (line 1765) | static InsertChildren(children) {
method RemoveChildren (line 1770) | static RemoveChildren(children) {
method OrderChildren (line 1775) | static OrderChildren(children) {
method clone (line 1778) | clone(parent) {
method toObject (line 1800) | toObject() {
method corner (line 1818) | corner() {
method center (line 1831) | center() {
method getById (line 1848) | getById(id) {
method getByClassName (line 1865) | getByClassName(className) {
method getByType (line 1881) | getByType(type) {
method add (line 1897) | add(objects) {
method remove (line 1916) | remove(objects) {
method getBoundingClientRect (line 1939) | getBoundingClientRect(shallow) {
method noFill (line 1983) | noFill() {
method noStroke (line 1989) | noStroke() {
method subdivide (line 1995) | subdivide() {
method _update (line 2002) | _update() {
method flagReset (line 2035) | flagReset() {
function replaceParent (line 2274) | function replaceParent(child, newParent) {
method constructor (line 3040) | constructor(params) {
method setSize (line 3052) | setSize(width, height, ratio) {
method render (line 3066) | render() {
function renderArcEstimate (line 3083) | function renderArcEstimate(ctx, ox, oy, rx, ry, startAngle, endAngle, cl...
function svgAngle (line 3118) | function svgAngle(ux, uy, vx, vy) {
function isDefaultMatrix (line 3127) | function isDefaultMatrix(m) {
function fallbackRequest (line 3174) | function fallbackRequest(callback, element) {
method constructor (line 3204) | constructor(message) {
method constructor (line 3213) | constructor() {
method add (line 3215) | add(id, obj) {
method remove (line 3219) | remove(id) {
method get (line 3223) | get(id) {
method contains (line 3226) | contains(id) {
function contains (line 3232) | function contains(path, t) {
function getIdByLength (line 3248) | function getIdByLength(path, target) {
function getCurveLength2 (line 3264) | function getCurveLength2(a, b, limit) {
function getSubdivisions (line 3286) | function getSubdivisions(a, b, limit) {
method constructor (line 3317) | constructor(offset, color, opacity) {
method clone (line 3328) | clone(parent) {
method toObject (line 3338) | toObject() {
method flagReset (line 3345) | flagReset() {
method constructor (line 3403) | constructor(stops) {
method clone (line 3420) | clone(parent) {
method toObject (line 3433) | toObject() {
method _update (line 3444) | _update() {
method flagReset (line 3450) | flagReset() {
function FlagStops (line 3497) | function FlagStops() {
function BindStops (line 3500) | function BindStops(items) {
function UnbindStops (line 3508) | function UnbindStops(items) {
method constructor (line 3522) | constructor(x1, y1, x2, y2, stops) {
method clone (line 3544) | clone(parent) {
method toObject (line 3563) | toObject() {
method _update (line 3569) | _update() {
method flagReset (line 3575) | flagReset() {
function FlagEndPoints (line 3614) | function FlagEndPoints() {
method constructor (line 3626) | constructor(cx, cy, r, stops, fx, fy) {
method clone (line 3651) | clone(parent) {
method toObject (line 3671) | toObject() {
method _update (line 3680) | _update() {
method flagReset (line 3686) | flagReset() {
function FlagCenter (line 3735) | function FlagCenter() {
function FlagFocal (line 3738) | function FlagFocal() {
method constructor (line 3766) | constructor(src, callback) {
method getAbsoluteURL (line 3799) | static getAbsoluteURL(path) {
method loadHeadlessBuffer (line 3806) | static loadHeadlessBuffer(texture, loaded) {
method getTag (line 3810) | static getTag(image) {
method getImage (line 3813) | static getImage(src) {
method load (line 3835) | static load(texture, callback) {
method clone (line 3855) | clone() {
method toObject (line 3862) | toObject() {
method _update (line 3870) | _update() {
method flagReset (line 3886) | flagReset() {
function FlagOffset (line 4067) | function FlagOffset() {
function FlagScale (line 4070) | function FlagScale() {
method constructor (line 4110) | constructor(vertices, closed2, curved, manual) {
method clone (line 4141) | clone(parent) {
method toObject (line 4164) | toObject() {
method noFill (line 4190) | noFill() {
method noStroke (line 4194) | noStroke() {
method corner (line 4198) | corner() {
method center (line 4219) | center() {
method getBoundingClientRect (line 4234) | getBoundingClientRect(shallow) {
method getPointAt (line 4320) | getPointAt(t, obj) {
method plot (line 4421) | plot() {
method subdivide (line 4431) | subdivide(limit) {
method _updateLength (line 4483) | _updateLength(limit, silent) {
method _update (line 4514) | _update() {
method flagReset (line 4601) | flagReset() {
function FlagVertices (line 4839) | function FlagVertices() {
function BindVertices (line 4846) | function BindVertices(items) {
function UnbindVertices (line 4853) | function UnbindVertices(items) {
function FlagFill (line 4860) | function FlagFill() {
function FlagStroke (line 4863) | function FlagStroke() {
method constructor (line 4869) | constructor(x, y, width, height) {
method _update (line 4896) | _update() {
method flagReset (line 4914) | flagReset() {
method clone (line 4919) | clone(parent) {
method toObject (line 4938) | toObject() {
method constructor (line 5005) | constructor(path, ox, oy, cols, rows, frameRate) {
method play (line 5030) | play(firstFrame, lastFrame, onLastFrame) {
method pause (line 5051) | pause() {
method stop (line 5055) | stop() {
method clone (line 5060) | clone(parent) {
method toObject (line 5078) | toObject() {
method _update (line 5090) | _update() {
method flagReset (line 5154) | flagReset() {
method constructor (line 5227) | constructor(ox, oy, r, resolution) {
method _update (line 5248) | _update() {
method flagReset (line 5276) | flagReset() {
method clone (line 5281) | clone(parent) {
method toObject (line 5300) | toObject() {
method constructor (line 5332) | constructor(x, y, rx, ry, resolution) {
method _update (line 5359) | _update() {
method flagReset (line 5387) | flagReset() {
method clone (line 5392) | clone(parent) {
method toObject (line 5414) | toObject() {
method constructor (line 5450) | constructor(x1, y1, x2, y2) {
method constructor (line 5503) | constructor(x, y, width, height, radius) {
method _update (line 5545) | _update() {
method flagReset (line 5613) | flagReset() {
method clone (line 5618) | clone(parent) {
method toObject (line 5640) | toObject() {
function FlagRadius (line 5690) | function FlagRadius() {
method constructor (line 5735) | constructor(message, x, y, styles) {
method Measure (line 5762) | static Measure(text) {
method clone (line 5787) | clone(parent) {
method toObject (line 5804) | toObject() {
method noFill (line 5819) | noFill() {
method noStroke (line 5823) | noStroke() {
method getBoundingClientRect (line 5828) | getBoundingClientRect(shallow) {
method flagReset (line 5874) | flagReset() {
function FlagFill2 (line 6087) | function FlagFill2() {
function FlagStroke2 (line 6090) | function FlagStroke2() {
function getAlignment (line 6107) | function getAlignment(anchor2) {
function getBaseline (line 6110) | function getBaseline(node) {
function getTagName (line 6115) | function getTagName(tag) {
function applyTransformsToVector (line 6118) | function applyTransformsToVector(transforms, vector2) {
function extractCSSText (line 6129) | function extractCSSText(text, styles) {
function getSvgStyles (line 6145) | function getSvgStyles(node) {
function getSvgAttributes (line 6161) | function getSvgAttributes(node) {
function applySvgViewBox (line 6172) | function applySvgViewBox(node, value) {
function applySvgAttributes (line 6213) | function applySvgAttributes(node, elem, parentStyles) {
function updateDefsCache (line 6445) | function updateDefsCache(node, defsCache) {
function getScene (line 6456) | function getScene(node) {
function xhr (line 7075) | function xhr(path, callback) {
method constructor (line 7103) | constructor(paths, ox, oy, frameRate) {
method play (line 7127) | play(firstFrame, lastFrame, onLastFrame) {
method pause (line 7148) | pause() {
method stop (line 7152) | stop() {
method clone (line 7157) | clone(parent) {
method toObject (line 7173) | toObject() {
method _update (line 7185) | _update() {
method flagReset (line 7247) | flagReset() {
function FlagTextures (line 7298) | function FlagTextures() {
function BindTextures (line 7301) | function BindTextures(items) {
function UnbindTextures (line 7308) | function UnbindTextures(items) {
function GenerateTexture (line 7315) | function GenerateTexture(obj) {
method constructor (line 7333) | constructor(x, y, ir, or, sa, ea, res) {
method _update (line 7363) | _update() {
method flagReset (line 7467) | flagReset() {
method clone (line 7472) | clone(parent) {
method toObject (line 7496) | toObject() {
method constructor (line 7574) | constructor(vertices) {
method clone (line 7598) | clone(parent) {
method toObject (line 7621) | toObject() {
method subdivide (line 7646) | subdivide(limit) {
method _update (line 7666) | _update() {
method flagReset (line 7693) | flagReset() {
method constructor (line 7867) | constructor(x, y, radius, sides) {
method _update (line 7889) | _update() {
method flagReset (line 7914) | flagReset() {
method clone (line 7919) | clone(parent) {
method toObject (line 7940) | toObject() {
method constructor (line 8007) | constructor(x, y, innerRadius, outerRadius, sides) {
method _update (line 8038) | _update() {
method flagReset (line 8064) | flagReset() {
method clone (line 8069) | clone(parent) {
method toObject (line 8091) | toObject() {
method constructor (line 8893) | constructor(params) {
method setSize (line 8904) | setSize(width, height) {
method render (line 8913) | render() {
method constructor (line 10040) | constructor(params) {
method setSize (line 10107) | setSize(width, height, ratio) {
method render (line 10132) | render() {
method _bound (line 10154) | get _bound() {
method _bound (line 10157) | set _bound(v) {
method addEventListener (line 10160) | addEventListener() {
method on (line 10163) | on() {
method bind (line 10166) | bind() {
method removeEventListener (line 10169) | removeEventListener() {
method off (line 10172) | off() {
method unbind (line 10175) | unbind() {
method dispatchEvent (line 10178) | dispatchEvent() {
method trigger (line 10181) | trigger() {
method listen (line 10184) | listen() {
method ignore (line 10187) | ignore() {
method constructor (line 10198) | constructor(options) {
method appendTo (line 10263) | appendTo(elem) {
method play (line 10274) | play() {
method pause (line 10279) | pause() {
method setPlaying (line 10283) | setPlaying(p) {
method release (line 10286) | release(obj) {
method update (line 10324) | update() {
method render (line 10345) | render() {
method add (line 10349) | add(objects) {
method remove (line 10356) | remove(objects) {
method clear (line 10363) | clear() {
method makeLine (line 10367) | makeLine(x1, y1, x2, y2) {
method makeArrow (line 10372) | makeArrow(x1, y1, x2, y2, size) {
method makeRectangle (line 10405) | makeRectangle(x, y, width, height) {
method makeRoundedRectangle (line 10410) | makeRoundedRectangle(x, y, width, height, sides) {
method makeCircle (line 10415) | makeCircle(x, y, radius, resolution) {
method makeEllipse (line 10420) | makeEllipse(x, y, rx, ry, resolution) {
method makeStar (line 10425) | makeStar(x, y, outerRadius, innerRadius, sides) {
method makeCurve (line 10430) | makeCurve(points) {
method makePolygon (line 10450) | makePolygon(x, y, radius, sides) {
method makeArcSegment (line 10455) | makeArcSegment(x, y, innerRadius, outerRadius, startAngle, endAngle, res...
method makePoints (line 10468) | makePoints(p) {
method makePath (line 10486) | makePath(p) {
method makeText (line 10509) | makeText(message, x, y, styles) {
method makeLinearGradient (line 10514) | makeLinearGradient(x1, y1, x2, y2) {
method makeRadialGradient (line 10520) | makeRadialGradient(x1, y1, radius) {
method makeSprite (line 10526) | makeSprite(pathOrTexture, x, y, columns, rows, frameRate, autostart) {
method makeImageSequence (line 10534) | makeImageSequence(pathsOrTextures, x, y, frameRate, autostart) {
method makeTexture (line 10542) | makeTexture(pathOrSource, callback) {
method makeGroup (line 10546) | makeGroup(objects) {
method interpret (line 10555) | interpret(svg2, shallow, add) {
method load (line 10569) | load(pathOrSVGContent, callback) {
function fitToWindow (line 10635) | function fitToWindow() {
function fitToParent (line 10641) | function fitToParent() {
function updateDimensions (line 10652) | function updateDimensions(width, height) {
function loop (line 10658) | function loop() {
FILE: models/layers/gilbert/gilbert2d.py
function gilbert2d (line 6) | def gilbert2d(width, height):
function gilbert2d_widthBigger (line 18) | def gilbert2d_widthBigger(width, height):
function sgn (line 27) | def sgn(x):
function generate2d (line 31) | def generate2d(x, y, ax, ay, bx, by):
FILE: models/layers/gilbert/gilbert3d.py
function gilbert3d (line 6) | def gilbert3d(width, height, depth):
function sgn (line 32) | def sgn(x):
function generate3d (line 36) | def generate3d(x, y, z,
FILE: models/layers/gilbert/gilbert_d2xy.py
function gilbert_d2xy (line 5) | def gilbert_d2xy(idx, w, h):
function sgn (line 16) | def sgn(x):
function gilbert_d2xy_r (line 19) | def gilbert_d2xy_r(dst_idx, cur_idx, x,y, ax,ay, bx,by):
FILE: models/layers/gilbert/gilbert_d2xyz.py
function gilbert_d2xyz (line 5) | def gilbert_d2xyz(idx, width, height, depth):
function sgn (line 33) | def sgn(x):
function gilbert_d2xyz_r (line 36) | def gilbert_d2xyz_r(dst_idx, cur_idx,
FILE: models/layers/gilbert/gilbert_xy2d.py
function gilbert_xy2d (line 5) | def gilbert_xy2d(x, y, w, h):
function sgn (line 17) | def sgn(x):
function in_bounds (line 21) | def in_bounds(x, y, x_s, y_s, ax, ay, bx, by):
function gilbert_xy2d_r (line 39) | def gilbert_xy2d_r(cur_idx, x_dst, y_dst, x, y, ax, ay, bx, by):
FILE: models/layers/gilbert/gilbert_xyz2d.py
function gilbert_xyz2d (line 5) | def gilbert_xyz2d(x, y, z, width, height, depth):
function sgn (line 34) | def sgn(x):
function in_bounds (line 37) | def in_bounds(x, y, z, x_s, y_s, z_s, ax, ay, az, bx, by, bz, cx, cy, cz):
function gilbert_xyz2d_r (line 61) | def gilbert_xyz2d_r(cur_idx,
FILE: models/layers/gilbert/ports/gilbert.c
function gilbert_xy2d (line 19) | int gilbert_xy2d(int x, int y, int w, int h) {
function gilbert_d2xy (line 26) | int gilbert_d2xy(int *x, int *y, int idx,int w,int h) {
function gilbert_xyz2d (line 50) | int gilbert_xyz2d(int x, int y, int z, int width, int height, int depth) {
function gilbert_d2xyz (line 74) | int gilbert_d2xyz(int *x, int *y, int *z, int idx, int width, int height...
function sgn (line 105) | static int sgn(int x) {
function in_bounds2 (line 111) | int in_bounds2(int x, int y,
function in_bounds3 (line 137) | int in_bounds3(int x, int y, int z,
function gilbert_d2xy_r (line 174) | int gilbert_d2xy_r(int dst_idx, int cur_idx,
function gilbert_xy2d_r (line 279) | int gilbert_xy2d_r(int cur_idx,
function gilbert_d2xyz_r (line 362) | int gilbert_d2xyz_r(int dst_idx, int cur_idx,
function gilbert_xyz2d_r (line 612) | int gilbert_xyz2d_r(int cur_idx,
function main (line 869) | int main(int argc, char **argv) {
FILE: models/layers/gilbert/ports/gilbert.js
function sgn (line 15) | function sgn(x) {
function in_bounds2 (line 21) | function in_bounds2(p, s, a, b) {
function in_bounds3 (line 37) | function in_bounds3(p, s, a, b, c) {
function gilbert_xy2d (line 58) | function gilbert_xy2d(x,y,w,h) {
function gilbert_d2xy (line 71) | function gilbert_d2xy(idx,w,h) {
function gilbert_xyz2d (line 83) | function gilbert_xyz2d(x,y,z,w,h,d) {
function gilbert_d2xyz (line 99) | function gilbert_d2xyz(idx,w,h,d) {
function gilbert_d2xy_r (line 113) | function gilbert_d2xy_r( dst_idx,cur_idx, p, a, b) {
function gilbert_xy2d_r (line 187) | function gilbert_xy2d_r(idx, q, p, a, b) {
function gilbert_xyz2d_r (line 254) | function gilbert_xyz2d_r(cur_idx, q, p, a, b, c) {
function gilbert_d2xyz_r (line 414) | function gilbert_d2xyz_r(dst_idx, cur_idx, p, a, b, c) {
function _main (line 596) | function _main(argv) {
FILE: models/layers/matching.py
function dice_loss (line 7) | def dice_loss(inputs, targets, num_boxes):
function ber_loss (line 24) | def ber_loss(inputs, targets, num_boxes):
function ce_mask_loss (line 32) | def ce_mask_loss(inputs, targets, num_boxes,
function sigmoid_focal_loss (line 51) | def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, ...
function batch_sigmoid_focal_loss (line 78) | def batch_sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma...
function batch_dice_loss (line 94) | def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
function batch_sigmoid_ce_loss (line 111) | def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor,):
function get_src_permutation_idx (line 137) | def get_src_permutation_idx(indices):
function get_tgt_permutation_idx (line 143) | def get_tgt_permutation_idx(indices):
FILE: models/layers/position_encoding.py
class PositionEmbeddingSine (line 11) | class PositionEmbeddingSine(nn.Module):
method __init__ (line 16) | def __init__(self, num_pos_feats=64, temperature=10000, normalize=Fals...
method forward (line 27) | def forward(self, tensor_list: NestedTensor):
class PositionEmbeddingLearned (line 50) | class PositionEmbeddingLearned(nn.Module):
method __init__ (line 54) | def __init__(self, num_pos_feats=256):
method reset_parameters (line 60) | def reset_parameters(self):
method forward (line 64) | def forward(self, tensor_list: NestedTensor):
class PositionEmbeddingSine1D (line 78) | class PositionEmbeddingSine1D(nn.Module):
method __init__ (line 83) | def __init__(self, temperature=10000, normalize=True, scale=None):
method forward (line 93) | def forward(self, mask, hidden_dim):
class PositionEmbeddingSine3D (line 113) | class PositionEmbeddingSine3D(nn.Module):
method __init__ (line 119) | def __init__(self, num_pos_feats=64, temperature=10000, normalize=Fals...
method forward (line 130) | def forward(self, x, mask=None):
class PositionEmbeddingSine2D (line 161) | class PositionEmbeddingSine2D(nn.Module):
method __init__ (line 166) | def __init__(self, temperature=10000, normalize=True, scale=None):
method forward (line 176) | def forward(self, mask, hidden_dim: int):
class PositionEmbeddingLearned1D (line 203) | class PositionEmbeddingLearned1D(nn.Module):
method __init__ (line 207) | def __init__(self, num_pos_feats=256):
method reset_parameters (line 212) | def reset_parameters(self):
method forward (line 215) | def forward(self, tensor_list: NestedTensor):
function build_position_encoding (line 228) | def build_position_encoding(hidden_dim=None, position_embedding_name='2d'):
FILE: models/layers/utils.py
function zero_module (line 6) | def zero_module(module):
function _get_clones (line 14) | def _get_clones(module, N):
function _get_activation_fn (line 17) | def _get_activation_fn(activation):
function _get_activation_layer (line 29) | def _get_activation_layer(activation):
function pad_1d_feats (line 40) | def pad_1d_feats(feat_list):
FILE: models/modality_input_mappers/hilbert_curve.py
class HilbertCurve_FrameQuery (line 7) | class HilbertCurve_FrameQuery:
method __init__ (line 8) | def __init__(self,
method mapper (line 14) | def mapper(self, video):
method collate (line 19) | def collate(self, list_of_haosen, batch_videos):
FILE: models/optimization/optimizer.py
class GradientClipType (line 13) | class GradientClipType(Enum):
function maybe_add_full_model_gradient_clipping (line 18) | def maybe_add_full_model_gradient_clipping(optim, configs):
function _create_gradient_clipper (line 36) | def _create_gradient_clipper(cfg) -> _GradientClipper:
function _generate_optimizer_class_with_gradient_clipping (line 55) | def _generate_optimizer_class_with_gradient_clipping(
function maybe_add_gradient_clipping (line 88) | def maybe_add_gradient_clipping(
function get_optimizer (line 122) | def get_optimizer(params, configs):
FILE: models/optimization/scheduler.py
function build_scheduler (line 7) | def build_scheduler(configs, optimizer):
FILE: models/registry.py
function register_model (line 3) | def register_model(fn):
function model_entrypoint (line 11) | def model_entrypoint(model_name):
FILE: trainers/Trainer.py
class Trainer (line 19) | class Trainer:
method __init__ (line 20) | def __init__(self, configs):
method train (line 68) | def train(self):
method save_ckpt (line 99) | def save_ckpt(self):
method evaluate (line 129) | def evaluate(self):
method load_ckpt (line 164) | def load_ckpt(self,
method _log (line 202) | def _log(self,
method device (line 256) | def device(self):
method iteration_dir (line 260) | def iteration_dir(self):
method epoch (line 264) | def epoch(self):
method log_header (line 270) | def log_header(self, iteration_time, sample_idxs):
method visualize_path (line 275) | def visualize_path(self, meta_idxs, visualize):
method register_metric_logger (line 278) | def register_metric_logger(self, log_keys):
FILE: utils/misc.py
class SmoothedValue (line 31) | class SmoothedValue(object):
method __init__ (line 36) | def __init__(self, window_size=1, fmt='{value:.6f}', handler='value'):
method update (line 45) | def update(self, value, n=1):
method synchronize_between_processes (line 50) | def synchronize_between_processes(self):
method median (line 64) | def median(self):
method avg (line 69) | def avg(self):
method global_avg (line 74) | def global_avg(self):
method max (line 78) | def max(self):
method value (line 82) | def value(self):
method wandb_log_property (line 86) | def wandb_log_property(self):
method __str__ (line 94) | def __str__(self):
function all_gather (line 103) | def all_gather(data):
function all_gather_cpu (line 145) | def all_gather_cpu(data):
function reduce_dict (line 181) | def reduce_dict(input_dict, average=True):
function reduce_scalar (line 207) | def reduce_scalar(input, average=True):
class MetricLogger (line 226) | class MetricLogger(object):
method __init__ (line 227) | def __init__(self, delimiter="\t"):
method update (line 231) | def update(self, **kwargs):
method __getattr__ (line 238) | def __getattr__(self, attr):
method __str__ (line 246) | def __str__(self):
method to_dict (line 254) | def to_dict(self):
method synchronize_between_processes (line 260) | def synchronize_between_processes(self):
method add_meter (line 264) | def add_meter(self, name, meter):
method log_every (line 267) | def log_every(self, iterable, print_freq, header=None):
function get_sha (line 322) | def get_sha():
function collate_fn (line 342) | def collate_fn(batch):
function _max_by_axis (line 348) | def _max_by_axis(the_list):
class NestedTensor (line 357) | class NestedTensor(object):
method __init__ (line 358) | def __init__(self, tensors, mask: Optional[Tensor]):
method to (line 362) | def to(self, device):
method decompose (line 373) | def decompose(self):
method __repr__ (line 376) | def __repr__(self):
function nested_tensor_from_tensor_list (line 405) | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
function nested_tensor_from_tensor_list_visiblility (line 423) | def nested_tensor_from_tensor_list_visiblility(tensor_list: List[Tensor]...
function nested_tensor_from_tensor_list_with_stride (line 459) | def nested_tensor_from_tensor_list_with_stride(tensor_list: List[Tensor]...
function _get_nearest_scale_number (line 484) | def _get_nearest_scale_number(num, scale):
function nested_tensor_from_videos_list (line 491) | def nested_tensor_from_videos_list(videos_list: List[Tensor]):
function _onnx_nested_tensor_from_tensor_list (line 516) | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> N...
function setup_for_distributed (line 544) | def setup_for_distributed(is_master):
function is_dist_avail_and_initialized (line 559) | def is_dist_avail_and_initialized():
function get_world_size (line 567) | def get_world_size():
function get_rank (line 573) | def get_rank():
function is_main_process (line 579) | def is_main_process():
function save_on_master (line 583) | def save_on_master(*args, **kwargs):
function init_distributed_mode (line 588) | def init_distributed_mode(args):
function accuracy (line 614) | def accuracy(output, target, topk=(1,)):
function interpolate (line 632) | def interpolate(input, size=None, scale_factor=None, mode="nearest", ali...
function inverse_sigmoid (line 651) | def inverse_sigmoid(x, eps=1e-5):
function to_device (line 657) | def to_device(sample, device):
function nested_tensor_from_videos_list_with_stride (line 667) | def nested_tensor_from_videos_list_with_stride(videos_list, max_stride):
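The index lists `models/layers/gilbert/gilbert2d.py` (`gilbert2d`, `sgn`, `generate2d`), a generalized Hilbert ("gilbert") curve. `models/modality_input_mappers/hilbert_curve.py` appears to import from `models.layers.gilbert`. The sketch below illustrates the recursion and is not a copy of the repository file. It reconstructs the publicly known algorithm from Jakub Červený's BSD-2-Clause gilbert project from memory, has not been checked against this repository's copy, and leaves out the repo-specific `gilbert2d_widthBigger` variant. It yields every cell of a `width` by `height` grid exactly once, starting at `(0, 0)`. Each step moves to a neighbouring cell. For some odd/even size combinations, that neighbour may be diagonal.

```python
from typing import Iterator, Tuple


def sgn(x: int) -> int:
    """Return -1, 0 or 1 according to the sign of x."""
    return (x > 0) - (x < 0)


def generate2d(x: int, y: int, ax: int, ay: int, bx: int, by: int) -> Iterator[Tuple[int, int]]:
    """Recursively fill the rectangle at (x, y) spanned by major axis (ax, ay) and minor axis (bx, by)."""
    w = abs(ax + ay)
    h = abs(bx + by)
    dax, day = sgn(ax), sgn(ay)  # unit step along the major axis
    dbx, dby = sgn(bx), sgn(by)  # unit step along the minor axis

    if h == 1:  # a single row: walk straight along the major axis
        for _ in range(w):
            yield (x, y)
            x, y = x + dax, y + day
        return
    if w == 1:  # a single column: walk straight along the minor axis
        for _ in range(h):
            yield (x, y)
            x, y = x + dbx, y + dby
        return

    # Halve both axes (floor division, so negative axes round toward -inf).
    ax2, ay2 = ax // 2, ay // 2
    bx2, by2 = bx // 2, by // 2
    w2 = abs(ax2 + ay2)
    h2 = abs(bx2 + by2)

    if 2 * w > 3 * h:
        # Long rectangle: split only along the major axis, preferring an even first half.
        if w2 % 2 and w > 2:
            ax2, ay2 = ax2 + dax, ay2 + day
        yield from generate2d(x, y, ax2, ay2, bx, by)
        yield from generate2d(x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by)
    else:
        # Standard case: one block along the minor axis, one long block across, one block back.
        if h2 % 2 and h > 2:
            bx2, by2 = bx2 + dbx, by2 + dby
        yield from generate2d(x, y, bx2, by2, ax2, ay2)
        yield from generate2d(x + bx2, y + by2, ax, ay, bx - bx2, by - by2)
        yield from generate2d(x + (ax - dax) + (bx2 - dbx), y + (ay - day) + (by2 - dby),
                              -bx2, -by2, -(ax - ax2), -(ay - ay2))


def gilbert2d(width: int, height: int) -> Iterator[Tuple[int, int]]:
    """Yield (x, y) for every cell of a width x height grid along a generalized Hilbert curve."""
    if width >= height:
        yield from generate2d(0, 0, width, 0, 0, height)
    else:
        yield from generate2d(0, 0, 0, height, width, 0)
```

For example, `list(gilbert2d(3, 2))` gives `[(0, 0), (0, 1), (1, 1), (2, 1), (2, 0), (1, 0)]`. Because of this locality, serializing a frame's feature map along such a curve keeps spatially nearby positions close together in the resulting 1D sequence. That is presumably why a curve like this is used instead of a plain raster scan, although the source does not state the motivation.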
Condensed preview: 117 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 213,
"preview": "**/__pycache__/**\n**/wandb/**\n**/.vscode/**\n*.out\n*.err\n*.zip\n*.tar\n*.pth\nstdout_train.txt\nstdout_eval.txt\nstdout_visual"
},
{
"path": "README.md",
"chars": 7325,
"preview": "\n## LGRNet: Local-Global Reciprocal Network for Video Polyp Segmentation [`Paper`](https://arxiv.org/abs/2407.05703) | ["
},
{
"path": "assets/DATA.md",
"chars": 2363,
"preview": "\n\n\n\n\n\n# Data Preparation\n\n## UFUV (Private): \nplease email the second author for UFUV dataset if you want, I have no abs"
},
{
"path": "assets/INSTALL.md",
"chars": 1628,
"preview": "# Install\n## Requirements\nWe test the codes in the following environments\n\n- CUDA 12.1\n- Python 3.10.13\n- Pytorch 2.1.1\n"
},
{
"path": "assets/MODEL_ZOO.md",
"chars": 0,
"preview": ""
},
{
"path": "data_schedule/__init__.py",
"chars": 8143,
"preview": "\nimport os\n\nif os.getenv('CURRENT_TASK') == 'VIS':\n from . import vis\nelse:\n raise ValueError()\n\ndef build_schedul"
},
{
"path": "data_schedule/registry.py",
"chars": 793,
"preview": "from detectron2.utils.registry import Registry\n\n\nEVALUATOR_REGISTRY = Registry('EVALUATOR')\nMAPPER_REGISTRY = Registry('"
},
{
"path": "data_schedule/utils/box_ops.py",
"chars": 2629,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\"\"\"\nUtilities for bounding box manipulation and G"
},
{
"path": "data_schedule/utils/sampler.py",
"chars": 5759,
"preview": "\nimport math\nimport torch.distributed as dist\nfrom typing import TypeVar, Optional, Iterator \nT_co = TypeVar('T_c"
},
{
"path": "data_schedule/utils/segmentation.py",
"chars": 438,
"preview": "\nimport torch\n\ndef bounding_box_from_mask(mask):\n if not mask.any():\n return torch.zeros([4]).float()\n rows"
},
{
"path": "data_schedule/vis/__init__.py",
"chars": 163,
"preview": "from . import polyp\nfrom . import mapper \nfrom . import evaluator_fast\nfrom . import vis_aug_eval \nfrom . import vis_aug"
},
{
"path": "data_schedule/vis/apis.py",
"chars": 339,
"preview": "\nclass VIS_Dataset:\n \"\"\"\n \"\"\"\n\nclass VIS_Aug_CallbackAPI:\n \"\"\"\n \"\"\"\n\nclass VIS_Evaluator_OutAPI_EvalFn_API:\n"
},
{
"path": "data_schedule/vis/evaluator_fast.py",
"chars": 6700,
"preview": "\nimport os\nfrom tqdm import tqdm\nfrom functools import partial\nimport torch\nimport detectron2.utils.comm as comm\nfrom ut"
},
{
"path": "data_schedule/vis/evaluator_utils.py",
"chars": 8238,
"preview": "_vis_metric_entrypoints = {}\n\ndef register_vis_metric(fn):\n vis_metric_name = fn.__name__\n if vis_metric_name in _"
},
{
"path": "data_schedule/vis/fibroid/__init__.py",
"chars": 81,
"preview": "# 注册fibrois数据集\nfrom . import fibroid_dataset\n\n# 注册fibroid评估标准\nfrom . import evals"
},
{
"path": "data_schedule/vis/fibroid/evals.py",
"chars": 5105,
"preview": "from data_schedule.vis.evaluator_utils import register_vis_metric\nimport os\nfrom glob import glob\nfrom tqdm import tqdm\n"
},
{
"path": "data_schedule/vis/fibroid/fibroid_dataset.py",
"chars": 6438,
"preview": "from typing import Optional, Union\nimport json\nimport os\nfrom functools import partial\nimport numpy as np\nimport torch\ni"
},
{
"path": "data_schedule/vis/fibroid/fibroid_utils.py",
"chars": 3433,
"preview": "import wandb\nimport plotly.express as px\nimport logging\nimport os\nimport numpy as np\nimport torch\nimport json\nfrom jobli"
},
{
"path": "data_schedule/vis/fibroid/metrics.py",
"chars": 25014,
"preview": "import warnings\nfrom typing import Optional, List, Tuple, Union\nimport torch\n\n\"\"\"Various metrics based on Type I and Typ"
},
{
"path": "data_schedule/vis/mapper.py",
"chars": 5013,
"preview": "\nimport json\nimport os\nfrom typing import List\nimport copy\nfrom functools import partial\nimport random\nimport numpy as n"
},
{
"path": "data_schedule/vis/mapper_utils.py",
"chars": 3109,
"preview": "from .vis_aug_utils import VIS_EVAL_AUG_REGISTRY, VIS_TRAIN_AUG_REGISTRY\nimport torch\nfrom copy import deepcopy as dcopy"
},
{
"path": "data_schedule/vis/polyp/__init__.py",
"chars": 49,
"preview": "\nfrom . import polyp_dataset\nfrom . import evals\n"
},
{
"path": "data_schedule/vis/polyp/evals.py",
"chars": 1603,
"preview": "from data_schedule.vis.evaluator_utils import register_vis_metric\nimport os\nimport torch\nimport detectron2.utils.comm as"
},
{
"path": "data_schedule/vis/polyp/polyp_dataset.py",
"chars": 6164,
"preview": "from typing import Optional, Union\nimport json\nimport os\nfrom functools import partial\nimport numpy as np\nimport torch\ni"
},
{
"path": "data_schedule/vis/polyp/polyp_utils.py",
"chars": 4086,
"preview": "\nimport os\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nSET_NAME = ['polyp_train', \n 'polyp_hard_seen_"
},
{
"path": "data_schedule/vis/vis_aug_eval.py",
"chars": 2932,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\nimport random\nimport torch\nimport torchvision.tra"
},
{
"path": "data_schedule/vis/vis_aug_train.py",
"chars": 11342,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\nimport random\nfrom PIL import Image\nimport torch\n"
},
{
"path": "data_schedule/vis/vis_aug_utils.py",
"chars": 2508,
"preview": "from detectron2.utils.registry import Registry\nimport torch\nimport numpy as np\nimport torchvision.transforms.functional "
},
{
"path": "data_schedule/vis/vis_frame_sampler.py",
"chars": 3391,
"preview": "\nfrom detectron2.utils.registry import Registry\nimport random\nimport numpy as np\nimport torch\nimport logging\nfrom detect"
},
{
"path": "handle_vps.py",
"chars": 4908,
"preview": "\nimport cv2\nimport numpy as np\nimport os\nimport shutil\nfrom PIL import Image\nimport torch\nfrom tqdm import tqdm\ndataset_"
},
{
"path": "main.py",
"chars": 7842,
"preview": "import os\nimport argparse\nimport logging\nimport importlib\nfrom trainers import task_to_trainer\nimport detectron2.utils.c"
},
{
"path": "models/VIS/BackboneEncoderDecoder_WithScaleConsistency.py",
"chars": 10286,
"preview": "import matplotlib.pyplot as plt\nfrom typing import Any, Optional, List, Dict, Set, Callable\nimport torch\nimport torch.nn"
},
{
"path": "models/VIS/__init__.py",
"chars": 165,
"preview": "from . import BackboneEncoderDecoder_WithScaleConsistency\nfrom .. import modality_input_mappers\nfrom .. import backbone\n"
},
{
"path": "models/VIS/aux_mapper.py",
"chars": 7254,
"preview": "\nimport torch\n\nfrom torch.nn import functional as F\nfrom models.registry import register_model\nfrom data_schedule.utils."
},
{
"path": "models/__init__.py",
"chars": 149,
"preview": "import os\nfrom .registry import model_entrypoint\n\nif os.getenv('CURRENT_TASK') == 'VIS':\n from . import VIS\nelse:\n "
},
{
"path": "models/backbone/__init__.py",
"chars": 29,
"preview": "\nfrom . import res2net, pvtv2"
},
{
"path": "models/backbone/pvtv2.py",
"chars": 10045,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom timm.models.layers import DropPath\nfrom timm.mo"
},
{
"path": "models/backbone/res2net.py",
"chars": 6254,
"preview": "import math\nimport torch\nimport torch.nn as nn\nimport os\nimport torch.nn.functional as F\nfrom detectron2.modeling import"
},
{
"path": "models/backbone/utils.py",
"chars": 725,
"preview": "\nclass VideoMultiscale_Shape:\n def __init__(self, temporal_stride, spatial_stride, dim) -> None:\n self.tempora"
},
{
"path": "models/decoder/__init__.py",
"chars": 32,
"preview": "\nfrom . import mask2former_video"
},
{
"path": "models/decoder/mask2former_video.py",
"chars": 22909,
"preview": "\n# multi-scale features, b c h w -> module -> obj queries, predictions, b nq c\nimport torch.nn as nn\nfrom models.layers."
},
{
"path": "models/encoder/__init__.py",
"chars": 83,
"preview": "\nfrom . import localGlobal\nfrom . import input_projs\nfrom . import neighborhood_qk\n"
},
{
"path": "models/encoder/input_projs.py",
"chars": 14410,
"preview": "\nimport torch.nn as nn\nfrom detectron2.modeling import META_ARCH_REGISTRY\nfrom models.layers.anyc_trans import Linear_No"
},
{
"path": "models/encoder/localGlobal.py",
"chars": 20075,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport logging\nimport numpy as np\nfrom typing import Callable, Dict, "
},
{
"path": "models/encoder/neighborhood_qk.py",
"chars": 7731,
"preview": "from typing import Optional\n\nimport torch\nfrom torch import nn, Tensor\nfrom torch.nn.functional import pad\nfrom torch.nn"
},
{
"path": "models/encoder/ops/MultiScaleDeformableAttention.egg-info/PKG-INFO",
"chars": 233,
"preview": "Metadata-Version: 2.1\nName: MultiScaleDeformableAttention\nVersion: 1.0\nSummary: PyTorch Wrapper for CUDA Functions of Mu"
},
{
"path": "models/encoder/ops/attention.py",
"chars": 18280,
"preview": "from inspect import isfunction\nimport math\nimport torch\nfrom torch.nn.init import xavier_uniform_, constant_\nimport torc"
},
{
"path": "models/encoder/ops/build/lib.linux-x86_64-cpython-311/functions/__init__.py",
"chars": 598,
"preview": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# C"
},
{
"path": "models/encoder/ops/build/lib.linux-x86_64-cpython-311/functions/ms_deform_attn_func.py",
"chars": 3298,
"preview": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# C"
},
{
"path": "models/encoder/ops/build/lib.linux-x86_64-cpython-311/modules/__init__.py",
"chars": 584,
"preview": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# C"
},
{
"path": "models/encoder/ops/build/lib.linux-x86_64-cpython-311/modules/ms_deform_attn.py",
"chars": 6351,
"preview": "# Modify for sample points visualization\n# -----------------------------------------------------------------------------"
},
{
"path": "models/encoder/ops/build/lib.linux-x86_64-cpython-38/functions/__init__.py",
"chars": 598,
"preview": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# C"
},
{
"path": "models/encoder/ops/build/lib.linux-x86_64-cpython-38/functions/ms_deform_attn_func.py",
"chars": 3298,
"preview": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# C"
},
{
"path": "models/encoder/ops/build/lib.linux-x86_64-cpython-38/modules/__init__.py",
"chars": 584,
"preview": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# C"
},
{
"path": "models/encoder/ops/build/lib.linux-x86_64-cpython-38/modules/ms_deform_attn.py",
"chars": 6351,
"preview": "# Modify for sample points visualization\n# -----------------------------------------------------------------------------"
},
{
"path": "models/encoder/ops/build/temp.linux-x86_64-cpython-311/.ninja_log",
"chars": 564,
"preview": "# ninja log v5\n0\t5344\t1685604027\t/home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xh"
},
{
"path": "models/encoder/ops/build/temp.linux-x86_64-cpython-311/build.ninja",
"chars": 3111,
"preview": "ninja_required_version = 1.3\ncxx = c++\nnvcc = /usr/local/cuda/bin/nvcc\n\ncflags = -pthread -B /home/xhh/anaconda3/envs/na"
},
{
"path": "models/encoder/ops/functions/__init__.py",
"chars": 598,
"preview": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# C"
},
{
"path": "models/encoder/ops/functions/ms_deform_attn_func.py",
"chars": 3298,
"preview": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# C"
},
{
"path": "models/encoder/ops/make.sh",
"chars": 593,
"preview": "#!/usr/bin/env bash\n# ------------------------------------------------------------------------------------------------\n#"
},
{
"path": "models/encoder/ops/modules/__init__.py",
"chars": 615,
"preview": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# C"
},
{
"path": "models/encoder/ops/modules/frame_query_ss2d.py",
"chars": 33451,
"preview": "\nfrom models.layers.position_encoding import build_position_encoding\nfrom mamba_ssm.ops.selective_scan_interface import "
},
{
"path": "models/encoder/ops/modules/ms_deform_attn.py",
"chars": 6352,
"preview": "# Modify for sample points visualization\n# -----------------------------------------------------------------------------"
},
{
"path": "models/encoder/ops/setup.py",
"chars": 2559,
"preview": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# C"
},
{
"path": "models/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp",
"chars": 1256,
"preview": "/*!\n**************************************************************************************************\n* Deformable DETR"
},
{
"path": "models/encoder/ops/src/cpu/ms_deform_attn_cpu.h",
"chars": 1139,
"preview": "/*!\n**************************************************************************************************\n* Deformable DETR"
},
{
"path": "models/encoder/ops/src/cuda/ms_deform_attn_cuda.cu",
"chars": 7316,
"preview": "/*!\n**************************************************************************************************\n* Deformable DETR"
},
{
"path": "models/encoder/ops/src/cuda/ms_deform_attn_cuda.h",
"chars": 1140,
"preview": "/*!\n**************************************************************************************************\n* Deformable DETR"
},
{
"path": "models/encoder/ops/src/cuda/ms_deform_im2col_cuda.cuh",
"chars": 54694,
"preview": "/*!\n**************************************************************************\n* Deformable DETR\n* Copyright (c) 2020 Se"
},
{
"path": "models/encoder/ops/src/ms_deform_attn.h",
"chars": 1838,
"preview": "/*!\n**************************************************************************************************\n* Deformable DETR"
},
{
"path": "models/encoder/ops/src/vision.cpp",
"chars": 799,
"preview": "/*!\n**************************************************************************************************\n* Deformable DETR"
},
{
"path": "models/encoder/ops/test.py",
"chars": 4087,
"preview": "# ------------------------------------------------------------------------------------------------\n# Deformable DETR\n# C"
},
{
"path": "models/layers/anyc_trans.py",
"chars": 3178,
"preview": "\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import repeat, rearrange, reduce\nfrom ty"
},
{
"path": "models/layers/decoder_layers.py",
"chars": 6602,
"preview": "\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import repeat, rearrange, reduce\nfrom ty"
},
{
"path": "models/layers/gilbert/demo/index.html",
"chars": 4531,
"preview": "<!DOCTYPE html>\n<html>\n<head>\n <meta charset=\"UTF-8\">\n <title>Gilbert Curve</title>\n\n <meta name=\"viewport\" content=\""
},
{
"path": "models/layers/gilbert/demo/normalize.css",
"chars": 7797,
"preview": "/*! normalize.css v3.0.2 | MIT License | git.io/normalize */\n\n/**\n * 1. Set default font family to sans-serif.\n * 2. Pre"
},
{
"path": "models/layers/gilbert/demo/script.js",
"chars": 4719,
"preview": "// SPDX-License-Identifier: BSD-2-Clause\n// Copyright (c) 2024 abetusk\n\nvar info = {\n \"W\" : -1,\n \"H\": -1,\n \"default\":"
},
{
"path": "models/layers/gilbert/demo/skeleton.css",
"chars": 9952,
"preview": "/*\n* Skeleton V2.0.4\n* Copyright 2014, Dave Gamache\n* www.getskeleton.com\n* Free to use under the MIT license.\n* http://"
},
{
"path": "models/layers/gilbert/demo/two.js",
"chars": 329840,
"preview": "/*\nMIT License\n\nCopyright (c) 2012 - 2021 @jonobr1 / http://jono.fyi\n\nPermission is hereby granted, free of charge, to a"
},
{
"path": "models/layers/gilbert/gilbert2d.py",
"chars": 2510,
"preview": "#!/usr/bin/env python3\n# SPDX-License-Identifier: BSD-2-Clause\n# Copyright (c) 2018 Jakub Červený\n\n\ndef gilbert2d(width,"
},
{
"path": "models/layers/gilbert/gilbert3d.py",
"chars": 6025,
"preview": "#!/usr/bin/env python3\n# SPDX-License-Identifier: BSD-2-Clause\n# Copyright (c) 2018 Jakub Červený\n\n\ndef gilbert3d(width,"
},
{
"path": "models/layers/gilbert/gilbert_d2xy.py",
"chars": 2660,
"preview": "#!/usr/bin/env python3\n# SPDX-License-Identifier: BSD-2-Clause\n# Copyright (c) 2024 abetusk\n\ndef gilbert_d2xy(idx, w, h)"
},
{
"path": "models/layers/gilbert/gilbert_d2xyz.py",
"chars": 8742,
"preview": "#!/usr/bin/env python3\n# SPDX-License-Identifier: BSD-2-Clause\n# Copyright (c) 2024 abetusk\n\ndef gilbert_d2xyz(idx, widt"
},
{
"path": "models/layers/gilbert/gilbert_xy2d.py",
"chars": 3162,
"preview": "#!/usr/bin/env python3\n# SPDX-License-Identifier: BSD-2-Clause\n# Copyright (c) 2024 abetusk\n\ndef gilbert_xy2d(x, y, w, h"
},
{
"path": "models/layers/gilbert/gilbert_xyz2d.py",
"chars": 10682,
"preview": "#!/usr/bin/env python3\n# SPDX-License-Identifier: BSD-2-Clause\n# Copyright (c) 2024 abetusk\n\ndef gilbert_xyz2d(x, y, z, "
},
{
"path": "models/layers/gilbert/plotpath.m",
"chars": 416,
"preview": "# Octave helper function to plot a 2D or 3D colored curve\n\nfunction h = plotpath(P)\n\n x = P(:,1)';\n y = P(:,2)';\n\n"
},
{
"path": "models/layers/gilbert/ports/Makefile",
"chars": 173,
"preview": "\nCC := gcc\nCFLAGS :=\nOPT := -O3\n\nSRCFILES := gilbert.c\n\nall: gilbert\n\ngilbert: gilbert.c\n\t$(CC) gilbert.c -o gilbert $(C"
},
{
"path": "models/layers/gilbert/ports/gilbert.c",
"chars": 26608,
"preview": "// SPDX-License-Identifier: BSD-2-Clause\n// Copyright (c) 2024 abetusk\n\n#include <stdio.h>\n#include <stdlib.h>\n#include "
},
{
"path": "models/layers/gilbert/ports/gilbert.js",
"chars": 19618,
"preview": "// SPDX-License-Identifier: BSD-2-Clause\n// Copyright (c) 2024 abetusk\n\n\"use strict\";\n\n\nvar gilbert = {\n \"xy2d\": gilber"
},
{
"path": "models/layers/gilbert/test.py",
"chars": 0,
"preview": ""
},
{
"path": "models/layers/gilbert/tests/runtests.sh",
"chars": 4167,
"preview": "#!/bin/bash\n#\n# SPDX-License-Identifier: BSD-2-Clause\n# Copyright (c) 2018 abetusk\n\n\nln -f -s ../gilbert2d.py .\nln -f -s"
},
{
"path": "models/layers/matching.py",
"chars": 5490,
"preview": "import torch\nimport torch.nn.functional as F\n\n\n\n\ndef dice_loss(inputs, targets, num_boxes):\n \"\"\"\n Compute the DICE"
},
{
"path": "models/layers/position_encoding.py",
"chars": 10133,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\"\"\"\nVarious positional encodings for the transfor"
},
{
"path": "models/layers/utils.py",
"chars": 1568,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport copy\n\ndef zero_module(module):\n \"\"\"\n Zer"
},
{
"path": "models/modality_input_mappers/__init__.py",
"chars": 58,
"preview": "from .hilbert_curve import (\n HilbertCurve_FrameQuery\n)"
},
{
"path": "models/modality_input_mappers/hilbert_curve.py",
"chars": 1014,
"preview": "\nfrom models.registry import MODELITY_INPUT_MAPPER_REGISTRY\nimport logging\nimport torch\nfrom models.layers.gilbert.gilbe"
},
{
"path": "models/optimization/optimizer.py",
"chars": 5247,
"preview": "\nfrom detectron2.solver.build import maybe_add_gradient_clipping\nfrom collections import OrderedDict\nfrom typing import "
},
{
"path": "models/optimization/scheduler.py",
"chars": 673,
"preview": "import torch\nfrom functools import partial\nimport logging\nimport numpy as np\n\n\ndef build_scheduler(configs, optimizer):\n"
},
{
"path": "models/registry.py",
"chars": 530,
"preview": "_model_entrypoints = {}\n\ndef register_model(fn):\n model_name = fn.__name__\n if model_name in _model_entrypoints:\n "
},
{
"path": "output/VIS/cvc/pvt.py",
"chars": 7074,
"preview": "from copy import deepcopy as dcopy\nimport numpy as np\nframe_sampler = {\n 'name': 'VIS_Video_or_Step_To_Clip_TrainMapp"
},
{
"path": "output/VIS/fibroid/pvt.py",
"chars": 5809,
"preview": "from copy import deepcopy as dcopy\nimport numpy as np\nattention_defaults = {\n 'attn': {\n 'dropout': 0.1,\n "
},
{
"path": "output/VIS/sunseg/pvt/pvt.py",
"chars": 7150,
"preview": "from copy import deepcopy as dcopy\nimport numpy as np\nattention_defaults = {\n 'attn': {\n 'dropout': 0.1,\n "
},
{
"path": "output/VIS/sunseg/res/res.py",
"chars": 6216,
"preview": "from copy import deepcopy as dcopy\nimport numpy as np\nattention_defaults = {\n 'attn': {\n 'dropout': 0.1,\n "
},
{
"path": "reorganize_sunseg.py",
"chars": 3658,
"preview": "import os, shutil, glob\nfrom tqdm import tqdm\n\nSUN_root = f\"{os.getenv('DATASET_PATH')}/SUN-SEG/SUN-Positive/\"\nSUNSEG_ro"
},
{
"path": "trainers/Trainer.py",
"chars": 13684,
"preview": "import torch\nimport numpy as np\nimport random\nimport math\nimport logging\nimport time\nimport os\nfrom utils.misc import re"
},
{
"path": "trainers/__init__.py",
"chars": 80,
"preview": "\nfrom .Trainer import Trainer \ntask_to_trainer = {\n 'VIS': Trainer,\n \n}\n\n\n"
},
{
"path": "utils/__init__.py",
"chars": 18,
"preview": "from . import misc"
},
{
"path": "utils/misc.py",
"chars": 25190,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\"\"\"\nMisc functions, including distributed helpers"
}
]
// ... and 9 more files
About this extraction
This page contains the full source code of the bio-mlhui/LGRNet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 117 files (13.2 MB), approximately 250.7k tokens, and a symbol index with 978 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.