In ultrasound video, for example, a single frame is too noisy and is insufficient for accurate lesion diagnosis. In practice, doctors need to check neighboring frames (local) and collect all visual clues in the video (global) to predict possible lesion regions and filter out irrelevant surrounding tissues.
In CNP, each token takes the neighborhood tokens (defined by a kernel) in the cyclic frame as attention keys. CNP thereby aggregates the local (cyclic) temporal information into one token. In Hilbert Selective Scan, a set of frame bottleneck queries is used to aggregate spatial information from each frame. We then use Hilbert Selective Scan to efficiently parse the global temporal context based on these bottleneck queries. The global temporal context is propagated back to the feature maps by a Distribute layer. Based on Mask2Former, the decoder outputs a set of different mask predictions with corresponding confidence scores, which also facilitates comprehensive diagnosis.
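To give some intuition for why a Hilbert ordering is attractive for a sequential scan, the sketch below compares a Hilbert curve with a plain row-major (raster) ordering on a small square grid. It is a generic, textbook-style illustration and not the code used in this repository (which relies on the Gilbert implementation credited in the acknowledgments); the function names `hilbert_d2xy`, `hilbert_order`, `raster_order`, and `max_step` are hypothetical, and the grid side is assumed to be a power of two.

```python
def hilbert_d2xy(n, d):
    """Map distance d along a Hilbert curve to (x, y) on an n x n grid.

    Assumes n is a power of two. This is the standard iterative
    construction, not the generalized curve used by this repository.
    """
    x = y = 0
    t = d
    s = 1
    while s < n:
        rx = 1 & (t // 2)
        ry = 1 & (t ^ rx)
        if ry == 0:
            if rx == 1:
                # flip the quadrant before rotating it
                x = s - 1 - x
                y = s - 1 - y
            x, y = y, x
        x += s * rx
        y += s * ry
        t //= 4
        s *= 2
    return x, y


def hilbert_order(n):
    """Visit every cell of the n x n grid in Hilbert order."""
    return [hilbert_d2xy(n, d) for d in range(n * n)]


def raster_order(n):
    """Visit every cell row by row, left to right (baseline for comparison)."""
    return [(x, y) for y in range(n) for x in range(n)]


def max_step(order):
    """Largest Manhattan distance between two consecutively visited cells."""
    return max(abs(x1 - x0) + abs(y1 - y0)
               for (x0, y0), (x1, y1) in zip(order, order[1:]))


hilbert_max = max_step(hilbert_order(4))
raster_max = max_step(raster_order(4))
print(hilbert_max, raster_max)  # prints: 1 4
```

Consecutive positions on the Hilbert curve are always grid neighbors, whereas the raster order jumps across the whole row at every line end, so a sequence model scanning in Hilbert order sees spatially close tokens close together in the sequence.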
## Items
1. Installation: Please refer to [INSTALL.md](assets/INSTALL.md) for more details.
2. Data preparation: Please refer to [DATA.md](assets/DATA.md) for more details.
3. Training:
Replace PORT_NUM with the port to use for DDP and make sure $CURRENT_TASK is 'VIS':
```
export CURRENT_TASK=VIS
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=PORT_NUM
```
Make sure $PT_PATH and $DATASET_PATH were set correctly during installation and data preparation.
Training on SUN-SEG is conducted using two 4090-24GB GPUs:
```
CUDA_VISIBLE_DEVICES=0,1 TORCH_NUM_WORKERS=8 python main.py --config_file output/VIS/sunseg/pvt/pvt.py --trainer_mode train_attmpt
```
4. Logs, checkpoints, predictions
| Backbone | Dataset | Dice | mIoU | log | ckpt | predictions |
| :----: | :----: | :----: | :----: | :----: | :----: |:----: |
| PVTv2-B2 | SUN-SEG-Train | -- | -- | [log](https://drive.google.com/file/d/17MTOYW73RLbvZS3BLFBZEphY_0JzN6er/view?usp=sharing) | [ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing) | --
| PVTv2-B2 | SUN-SEG-Hard-Testing | 0.876 | 0.805 | [log](https://drive.google.com/file/d/1wdVMWMknSlURaBROWbMax4iS9V1Tbn9-/view?usp=sharing) |[ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing) | [mask predictions](https://drive.google.com/file/d/1V8CDMC87o7t4eyts4BVEwflDUrFpAOVX/view?usp=sharing)
| PVTv2-B2 | SUN-SEG-Easy-Testing | 0.875 | 0.810 | [log](https://drive.google.com/file/d/1wdVMWMknSlURaBROWbMax4iS9V1Tbn9-/view?usp=sharing) |[ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing) | [mask predictions](https://drive.google.com/file/d/1V8CDMC87o7t4eyts4BVEwflDUrFpAOVX/view?usp=sharing)
| PVTv2-B2 | SUN-SEG-Hard-Unseen-Testing | 0.865 | 0.792 | [log](https://drive.google.com/file/d/1obt_qvWCvslhRY-e4SrTJNS0r6Diad4e/view?usp=sharing) | [ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing) | [mask predictions](https://drive.google.com/file/d/1V8CDMC87o7t4eyts4BVEwflDUrFpAOVX/view?usp=sharing)
| PVTv2-B2 | SUN-SEG-Easy-Unseen-Testing | 0.853 | 0.783 | [log](https://drive.google.com/file/d/1obt_qvWCvslhRY-e4SrTJNS0r6Diad4e/view?usp=sharing) | [ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing)| [mask predictions](https://drive.google.com/file/d/1V8CDMC87o7t4eyts4BVEwflDUrFpAOVX/view?usp=sharing)
| Res2Net-50 | SUN-SEG-Hard-Testing | 0.841 | 0.765 | [log](https://drive.google.com/file/d/17pUxFMuHpPD_In5RVrJUsPFZGOgNFzb6/view?usp=sharing) | -- | -- |
| Res2Net-50 | SUN-SEG-Easy-Testing | 0.843 | 0.774 | [log](https://drive.google.com/file/d/17pUxFMuHpPD_In5RVrJUsPFZGOgNFzb6/view?usp=sharing) | -- | -- |
| PVTv2-B2 | CVC612V | 0.933 | 0.877 | [log](https://drive.google.com/file/d/1m36mJL0Fu3T9F73TqFGnFsWaCGh3JDeJ/view?usp=drive_link) | -- | -- |
| PVTv2-B2 | CVC300TV | 0.916 | 0.852 | [log](https://drive.google.com/file/d/1m36mJL0Fu3T9F73TqFGnFsWaCGh3JDeJ/view?usp=drive_link) | -- | -- |
| PVTv2-B2 | CVC612T | 0.875 | 0.814 | [log](https://drive.google.com/file/d/1m36mJL0Fu3T9F73TqFGnFsWaCGh3JDeJ/view?usp=drive_link) | -- | -- |
5. Evaluation:
Evaluate on SUN-SEG-Easy and SUN-SEG-Hard using one 4090-24GB GPU (**replace ckpt_path with the absolute path of the checkpoint**):
```
CUDA_VISIBLE_DEVICES=0 TORCH_NUM_WORKERS=8 python main.py --config_file output/VIS/sunseg/pvt/pvt.py --trainer_mode eval --eval_path ckpt_path
```
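For readers who want to sanity-check their own predictions, the sketch below mirrors the smoothed per-frame Dice and IoU formulas used by `mask_dice_iou` in `data_schedule/vis/evaluator_utils.py`. It is a simplified illustration on plain nested lists rather than tensors; the function name `dice_iou` is hypothetical, and how per-frame values are averaged into the reported numbers depends on the configured metrics aggregator, which is not reproduced here.

```python
def dice_iou(pred_mask, gt_mask):
    """Smoothed Dice and IoU for two binary masks given as nested lists of 0/1.

    Follows the formulas in mask_dice_iou: `total` is |pred| + |gt|
    (the repository code stores this sum in a variable named `union`).
    """
    pairs = [(p, g)
             for pred_row, gt_row in zip(pred_mask, gt_mask)
             for p, g in zip(pred_row, gt_row)]
    inter = sum(p * g for p, g in pairs)
    total = sum(p + g for p, g in pairs)
    dice = (2 * inter + 1) / (total + 1)
    iou = (inter + 1) / (total - inter + 1)
    return dice, iou


# The prediction covers two pixels, the ground truth covers one of them.
dice, iou = dice_iou([[1, 1], [0, 0]], [[1, 0], [0, 0]])
print(dice, iou)
```

The `+1` smoothing terms keep both metrics defined (and equal to 1) when the prediction and the ground truth are both empty.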
## Citing
```
@article{xu2024lgrnet,
title={LGRNet: Local-Global Reciprocal Network for Uterine Fibroid Segmentation in Ultrasound Videos},
author={Xu, Huihui and Yang, Yijun and Aviles-Rivero, Angelica I and Yang, Guang and Qin, Jing and Zhu, Lei},
journal={arXiv preprint arXiv:2407.05703},
year={2024}
}
```
## Acknowledgments
- Thanks to [Gilbert](https://github.com/jakubcerveny/gilbert) for the implementation of Hilbert curve generation.
- Thanks to GPT4 for helping me construct the idea of the Hilbert filling curve vs. the zigzag curve.
================================================
FILE: assets/DATA.md
================================================
# Data Preparation
## UFUV (Private):
Please email the second author to request the UFUV dataset; I am not authorized to distribute it myself.
## VPS (Public)
### CVC/Kvasir/Mayo
We follow [PNS-Net](https://github.com/GewelsJI/PNS-Net) to download the CVC/Kvasir/Mayo dataset. The download link is the same: [link](https://drive.google.com/file/d/1TyaRy4c4nHFDa3o2bOl4dP5Z7wes7HV2/view?usp=sharing)
Put MICCAI-VPS-dataset.zip in $DATASET_PATH, then run the following script to normalize the directory structure:
```
cd $DATASET_PATH
unzip -qq MICCAI-VPS-dataset.zip
# cd LGRNet directory
# normalize the VPS data structure
python handle_vps.py
```
Now the structure should be like:
```
${DATASET_PATH}
-- MICCAI-VPS-dataset
-- Kvasir-SEG
-- *
-- VPS-TestSet
-- CVC-ColonDB-300
-- *
-- CVC-ClinicDB-612-Valid
-- *
-- CVC-ClinicDB-612-Test
-- *
-- VPS-TrainSet
-- ASU-Mayo_Clinic
-- Train
-- *
-- CVC-ClinicDB-612
-- Train
-- *
-- CVC-ColonDB-300
-- Train
-- *
```
where * means the following structure:
```
-- Frame
-- vid1
-- img file
-- GT
-- vid1
-- mask file
```
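Before training, it can help to confirm that every video under `Frame` has a matching folder under `GT`. The sketch below is one way to check a single `*` directory from the layout above; the helper name `find_unpaired_videos` is hypothetical, and the demo builds a throwaway directory tree instead of touching a real dataset.

```python
import tempfile
from pathlib import Path


def find_unpaired_videos(split_dir):
    """Return video folder names present in only one of Frame/ and GT/."""
    split_dir = Path(split_dir)
    frame_videos = {p.name for p in (split_dir / "Frame").iterdir() if p.is_dir()}
    gt_videos = {p.name for p in (split_dir / "GT").iterdir() if p.is_dir()}
    # symmetric difference: videos missing on either side
    return sorted(frame_videos ^ gt_videos)


# Demo on a temporary tree: vid2 has frames but no masks.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "Frame" / "vid1").mkdir(parents=True)
    (root / "Frame" / "vid2").mkdir(parents=True)
    (root / "GT" / "vid1").mkdir(parents=True)
    missing = find_unpaired_videos(root)

print(missing)  # prints: ['vid2']
```

On a real dataset, pass the path of one split directory (for example a `Train` folder) and expect an empty list.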
### SUN-SEG
Please follow https://github.com/GewelsJI/VPS/blob/main/docs/DATA_PREPARATION.md and email the authors to request SUN-SEG.
Put part1, part2, and the annotation archive in $DATASET_PATH/SUN-SEG, then run:
```
# normalize the directory
unzip -qq $DATASET_PATH/SUN-SEG/sundatabase_positive_part1.zip -d $DATASET_PATH/SUN-SEG/SUN-Positive
unzip -qq $DATASET_PATH/SUN-SEG/sundatabase_positive_part2.zip -d $DATASET_PATH/SUN-SEG/SUN-Positive
tar -xf $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation.tar -C $DATASET_PATH/SUN-SEG/
rm -rf $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestEasyDataset/Unseen/Frame
find $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation -name "._.DS_Store" -type f -delete
python reorganize_sunseg.py
```
Now the structure should be like:
```
${DATASET_PATH}
-- SUN-SEG
-- SUN-SEG-Annotation
-- TrainDataset
-- *
-- TestEasyDataset
-- combine
-- *
-- TestHardDataset
-- combine
-- *
```
================================================
FILE: assets/INSTALL.md
================================================
# Install
## Requirements
We tested the code in the following environment:
- CUDA 12.1
- Python 3.10.13
- PyTorch 2.1.1
- Torchvision 0.16.1
- detectron2 0.6
- mamba_ssm 1.2.0.post1
- natten 0.15.1
- timm 0.9.12
## Install environment for LGRNet
```
conda create --name lgrnet python=3.10
conda activate lgrnet
# make sure CUDA-12.1 is installed and activated in env var.
# install torch
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121
# install detectron2, building may take much time.
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
# install mamba, see https://github.com/state-spaces/mamba
cd ..
git clone https://github.com/state-spaces/mamba.git
cd mamba
pip install . --no-build-isolation
cd ../LGRNet
# install natten, see https://github.com/SHI-Labs/NATTEN/blob/main/docs/install.md
pip install natten==0.15.1+torch210cu121 -f https://shi-labs.com/natten/wheels/cu121/torch2.1.0/natten-0.15.1%2Btorch210cu121-cp310-cp310-linux_x86_64.whl
# misc
pip install albumentations==1.3.1
pip install Pygments
pip install imgaug
pip install timm==0.9.12
# compile deform attention
cd models/encoder/ops/
python setup.py build install --user
# download resnet/pvtv2 ckpt, our model uses the same backbones as WeakPolyp (https://github.com/weijun88/WeakPolyp)
# -P takes the target directory; /resolve/ returns the raw file instead of the HTML page served by /blob/
wget -P $PT_PATH/pvt_v2 https://huggingface.co/huihuixu/lgrnet_ckpts/resolve/main/pvt_v2_b2.pth
wget -P $PT_PATH/res2net https://huggingface.co/huihuixu/lgrnet_ckpts/resolve/main/res2net50_v1b_26w_4s-3cf99910.pth
```
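After installation, a quick check that the pinned versions are actually active can save debugging time. The sketch below compares installed package versions with the list in the requirements section; the helper name `check_versions` is hypothetical, and the distribution names in `EXPECTED` (in particular `mamba_ssm` and `detectron2`) are assumptions about how the packages register themselves, so adjust them if your environment reports different names.

```python
from importlib import metadata

# Versions taken from the requirements list; names are assumed distribution names.
EXPECTED = {
    "torch": "2.1.1",
    "torchvision": "0.16.1",
    "detectron2": "0.6",
    "mamba_ssm": "1.2.0.post1",
    "natten": "0.15.1",
    "timm": "0.9.12",
}


def check_versions(expected):
    """Return {name: (wanted, installed)} for missing or mismatched packages."""
    mismatches = {}
    for name, wanted in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Ignore local build tags such as "+torch210cu121" when comparing.
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, installed) in check_versions(EXPECTED).items():
        print(f"{name}: expected {wanted}, found {installed}")
```

An empty result means every listed package matches; a `None` entry means the package was not found under the assumed name.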
================================================
FILE: assets/MODEL_ZOO.md
================================================
================================================
FILE: data_schedule/__init__.py
================================================
import os
if os.getenv('CURRENT_TASK') == 'VIS':
from . import vis
else:
raise ValueError()
def build_schedule(configs, model_input_mapper, model_input_collate_fn):
import logging
from functools import partial
import detectron2.utils.comm as comm
from torch.utils.data import DataLoader, ConcatDataset
from .registry import MAPPER_REGISTRY, EVALUATOR_REGISTRY
from detectron2.data import DatasetCatalog, DatasetFromList, MapDataset, MetadataCatalog
from data_schedule.utils.sampler import Evaluate_ExactSampler_Distributed, Train_InfiniteSampler_Distributed
datasets = {'train': [], 'evaluate': []}
meta_idx_shift = 0
for mode in ['train', 'evaluate']:
for dataset_name in configs['data'][mode].keys():
dataset_assume_mode = MetadataCatalog.get(dataset_name).get('mode')
if dataset_assume_mode != mode:
logging.warning(f'default mode of {dataset_name} is {dataset_assume_mode} not {mode}')
dataset_dicts = DatasetFromList(DatasetCatalog.get(dataset_name), copy=True, serialize=True)
mapper = MAPPER_REGISTRY.get(configs['data'][mode][dataset_name]['mapper']['name'])(mode=mode,
dataset_name=dataset_name,
configs=configs,
meta_idx_shift=meta_idx_shift if mode == 'train' else 0)
meta_idx_shift += len(dataset_dicts)
dataset = MapDataset(dataset_dicts, partial(composition, mappers=[mapper,
partial(model_input_mapper, mode=mode)]))
if mode == 'train':
datasets[mode].append(dataset)
else:
datasets[mode].append((dataset_name, dataset))
train_dataset = ConcatDataset(datasets['train'])
logging.debug(f'Total number of training meta: {len(train_dataset)}')
train_loader_splits = configs['optim']['splits']
batch_sizes = configs['optim']['batch_sizes']
splits = list(zip(train_loader_splits[:-1], train_loader_splits[1:]))
assert len(splits) == (len(batch_sizes))
inf_stream_fn = partial(infinite_indices,
seed=configs['stream_idx_seed'],
batch_sizes=configs['optim']['batch_sizes'],
splits=configs['optim']['splits'],
one_batch_two_epoch=configs['optim']['one_batch_two_epoch'],
dataset_length=len(train_dataset),
shuffle=True)
train_samplers = []
train_loaders = []
for btch_size, (range_start, range_end) in zip(batch_sizes, splits):
if range_end is not None:
assert (range_end - range_start) % btch_size == 0, ''
assert btch_size % comm.get_world_size() == 0, ''
each_process_batch_size = int(btch_size / comm.get_world_size())
loader_sampler = Train_InfiniteSampler_Distributed(inf_stream_fn=inf_stream_fn,
start_idx=range_start,
end_idx=range_end,)
train_samplers.append(loader_sampler)
train_loaders.append(DataLoader(train_dataset,
batch_size=each_process_batch_size,
sampler=loader_sampler,
collate_fn=partial(model_input_collate_fn, mode='train'),
num_workers=int(os.getenv('TORCH_NUM_WORKERS')),
pin_memory=True,
persistent_workers=True))
evaluators = []
for eval_dataset_name, eval_dataset in datasets['evaluate']:
logging.debug(f'Number of evaluate meta in {eval_dataset_name}: {len(eval_dataset)}')
loader = DataLoader(eval_dataset,
batch_size=1,
sampler=Evaluate_ExactSampler_Distributed(eval_dataset),
collate_fn=partial(model_input_collate_fn, mode='evaluate'),
num_workers=int(os.getenv('TORCH_NUM_WORKERS')),
pin_memory=True,
persistent_workers=True)
evaluator = EVALUATOR_REGISTRY.get(configs['data']['evaluate'][eval_dataset_name]['evaluator']['name'])(configs=configs,
dataset_name=eval_dataset_name,
data_loader=loader)
evaluators.append((eval_dataset_name, evaluator))
return train_samplers, train_loaders, partial(evaluate_call, evaluators=evaluators)
def composition(data_dict, mappers):
for mappper in mappers:
data_dict = mappper(data_dict)
if data_dict is None:
return None
return data_dict
def evaluate_call(evaluators, model, output_dir):
import detectron2.utils.comm as comm
ret = {}
for eval_dataset_name, evaluator in evaluators:
metric_dict = evaluator(model=model,output_dir=output_dir)
if comm.is_main_process():
for key, value in metric_dict.items():
assert f'{key}_{eval_dataset_name}' not in ret
ret[f'{key}_{eval_dataset_name}'] = value
comm.synchronize()
return ret
def _infinite_indices(seed, dataset_length, shuffle=True,):
import torch
g = torch.Generator()
g.manual_seed(seed)
while True:
if shuffle:
yield from torch.randperm(dataset_length, generator=g).tolist()
else:
yield from torch.arange(dataset_length).tolist()
def infinite_indices(seed,
dataset_length,
batch_sizes,
splits,
one_batch_two_epoch='just_use',
shuffle=True): # 'abandon', 'just_use', 'pad'
import torch
import math
g = torch.Generator()
g.manual_seed(seed)
split_ranges = list(zip(splits[:-1], splits[1:]))
assert len(split_ranges) == (len(batch_sizes))
stream = _infinite_indices(seed, dataset_length=dataset_length, shuffle=shuffle)
stream_throw_cnt = 0
cnt = 0
for (range_start, range_end), btch_size in zip(split_ranges, batch_sizes):
assert cnt == range_start
if range_end is None:
range_end = math.inf
while cnt < range_end:
epoch_milestone = ((stream_throw_cnt // dataset_length) + 1 ) * dataset_length
if (stream_throw_cnt < epoch_milestone) and (stream_throw_cnt + btch_size > epoch_milestone) and (one_batch_two_epoch != 'just_use'):
if one_batch_two_epoch == 'abandon':
for _ in range(epoch_milestone - stream_throw_cnt):
abandon = next(stream)
stream_throw_cnt += 1
elif one_batch_two_epoch == 'pad':
diff = stream_throw_cnt + btch_size - epoch_milestone
num_throw = btch_size - diff
rand_idxs = torch.randperm(dataset_length, generator=g)[:diff].tolist()
for _ in range(num_throw):
cnt += 1
stream_throw_cnt += 1
yield next(stream)
for idx in rand_idxs:
cnt += 1
yield idx
else:
raise ValueError()
else:
for _ in range(btch_size):
cnt += 1
stream_throw_cnt += 1
yield next(stream)
assert cnt == range_end
================================================
FILE: data_schedule/registry.py
================================================
from detectron2.utils.registry import Registry
EVALUATOR_REGISTRY = Registry('EVALUATOR')
MAPPER_REGISTRY = Registry('MAPPER')
class Mapper:
def __init__(self,
meta_idx_shift,
dataset_meta,) -> None:
self.meta_idx_shift = meta_idx_shift
self.visualized_meta_idxs = dataset_meta.get('visualize_meta_idxs')
def _call(self, data_dict):
pass
def __call__(self, data_dict):
meta_idx = data_dict['meta_idx']
ret = self._call(data_dict)
if ret is None:
return None
ret['meta_idx'] = meta_idx + self.meta_idx_shift
if meta_idx in self.visualized_meta_idxs:
ret['visualize'] = True
else:
ret['visualize'] = False
return ret
================================================
FILE: data_schedule/utils/box_ops.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
assert ((x1 - x0) >= 0).all()
assert ((y1 - y0) >= 0).all()
b = [(x0 + x1) / 2, (y0 + y1) / 2,
(x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / union
return iou, union
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
return iou - (area - union) / area
def masks_to_boxes(masks):
"""Compute the bounding boxes around the provided masks
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
Returns a [N, 4] tensors, with the boxes in xyxy format
"""
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device)
h, w = masks.shape[-2:]
y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x, indexing='ij')
x_mask = (masks * x.unsqueeze(0))
x_max = x_mask.flatten(1).max(-1)[0]
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
y_mask = (masks * y.unsqueeze(0))
y_max = y_mask.flatten(1).max(-1)[0]
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
return torch.stack([x_min, y_min, x_max, y_max], 1)
================================================
FILE: data_schedule/utils/sampler.py
================================================
import math
import torch.distributed as dist
from typing import TypeVar, Optional, Iterator
T_co = TypeVar('T_co', covariant=True)
from torch.utils.data.distributed import DistributedSampler
import detectron2.utils.comm as comm
from utils.misc import all_gather
from torch.utils.data import Sampler
import torch
import logging
import itertools
class TrainRandomSampler_ByEpoch(Sampler[int]):
def __init__(self,
data_source,
seed,
) -> None:
self.data_source = data_source
self.num_samples = len(self.data_source)
self.seed = seed
self.epoch = None
def __iter__(self):
seed = self.seed + self.epoch
print(f'generating a new indices permutations for this epoch using seed {seed}')
n = len(self.data_source)
g = torch.Generator()
g.manual_seed(seed)
for _ in range(self.num_samples // n):
yield from torch.randperm(n, generator=g).tolist()
yield from torch.randperm(n, generator=g).tolist()[:self.num_samples % n]
def __len__(self) -> int:
return self.num_samples
def set_epoch(self, epoch: int) -> None:
r"""
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
class Train_InfiniteSampler_Distributed(Sampler[T_co]):
def __init__(self,
inf_stream_fn,
start_idx: int = 0,
end_idx = None,
):
self.rank = comm.get_rank()
self.num_replicas = comm.get_world_size()
self.start_idx = start_idx
self.end_idx = end_idx
self.inf_stream_fn = inf_stream_fn
def set_iter_first_sample_idx(self, idx):
self.start_idx = idx
def set_iter_last_sample_idx(self, idx):
self.end_idx = idx
def __iter__(self) -> Iterator[T_co]:
logging.debug(f'Positioned at {self.start_idx} as the start of the infinite stream')
yield from itertools.islice(self.inf_stream_fn(), self.start_idx + self.rank, self.end_idx, self.num_replicas)
class Evaluate_ExactSampler_Distributed(Sampler[T_co]):
def __init__(self, dataset) -> None:
self.dataset = dataset
self.rank = comm.get_rank()
self.num_replicas = comm.get_world_size()
indices = list(range(len(self.dataset)))
self.indices = indices[self.rank:len(self.dataset):self.num_replicas]
def __iter__(self):
yield from self.indices
def __len__(self):
return len(self.indices)
class TrainRandomSampler_ByEpoch_Distributed(Sampler[T_co]):
def __init__(self,
dataset, num_replicas,
rank,
seed: int = 0) -> None:
if rank >= num_replicas or rank < 0:
raise ValueError("Invalid rank {}, rank should be in the interval"" [0, {}]".format(rank, num_replicas - 1))
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = None
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
self.seed = seed
def __iter__(self) -> Iterator[T_co]:
seed = self.seed + self.epoch
logging.debug(f'generating a new indices permutations for this epoch using seed {seed}')
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
self.epoch = None
return iter(indices)
def __len__(self) -> int:
return self.num_samples
def set_epoch(self, epoch: int) -> None:
r"""
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
class InferenceSampler(Sampler):
"""
Produce indices for inference across all workers.
Inference needs to run on the __exact__ set of samples,
therefore when the total number of samples is not divisible by the number of workers,
this sampler produces different number of samples on different workers.
"""
def __init__(self, size: int):
"""
Args:
size (int): the total number of data of the underlying dataset to sample from
"""
self._size = size
assert size > 0
self._rank = comm.get_rank()
self._world_size = comm.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[: rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
================================================
FILE: data_schedule/utils/segmentation.py
================================================
import torch
def bounding_box_from_mask(mask):
if not mask.any():
return torch.zeros([4]).float()
rows = torch.any(mask, dim=1) # h
cols = torch.any(mask, dim=0) # w
row_indexs = torch.where(rows)[0]
rmin, rmax = row_indexs.min(), row_indexs.max()
col_indexs = torch.where(cols)[0]
cmin, cmax = col_indexs.min(), col_indexs.max()
return torch.tensor([cmin, rmin, cmax, rmax]).float() # x1y1x2y2
================================================
FILE: data_schedule/vis/__init__.py
================================================
from . import polyp
from . import mapper
from . import evaluator_fast
from . import vis_aug_eval
from . import vis_aug_train
from . import vis_frame_sampler
================================================
FILE: data_schedule/vis/apis.py
================================================
class VIS_Dataset:
"""
"""
class VIS_Aug_CallbackAPI:
"""
"""
class VIS_Evaluator_OutAPI_EvalFn_API:
"""
"""
class VIS_TrainAPI_clipped_video:
"""
"""
class VIS_EvalAPI_clipped_video_request_ann:
"""
"""
class VIS_FrameSampler_InputOutput_API:
"""
"""
class GetFrames:
"""
"""
================================================
FILE: data_schedule/vis/evaluator_fast.py
================================================
import os
from tqdm import tqdm
from functools import partial
import torch
import detectron2.utils.comm as comm
from utils.misc import to_device
from detectron2.data import MetadataCatalog
from data_schedule.registry import EVALUATOR_REGISTRY
from .evaluator_utils import vis_metric_entrypoint
from data_schedule.vis.apis import VIS_EvalAPI_clipped_video_request_ann, VIS_Aug_CallbackAPI, VIS_Evaluator_OutAPI_EvalFn_API
from collections import defaultdict
@EVALUATOR_REGISTRY.register()
class VIS_Evaluator_FrameFast:
def __init__(self,
dataset_name,
data_loader,
configs) -> None:
self.dataset_name = dataset_name
self.loader = data_loader
frame_metrics = configs['data']['evaluate'][dataset_name]['evaluator']['frame_metrics']
dataset_meta = MetadataCatalog.get(dataset_name)
self.frame_metric_fns = []
for metric_name, metric_config in frame_metrics:
metric_fn = vis_metric_entrypoint(metric_name)
metric_fn = partial(metric_fn, dataset_meta=dataset_meta, **metric_config)
self.frame_metric_fns.append(metric_fn)
self.eval_meta_keys = dataset_meta.get('eval_meta_keys')
metrics_aggregator = configs['data']['evaluate'][dataset_name]['evaluator']['metrics_aggregator']
self.eval_meta_keys = dataset_meta.get('eval_meta_keys') # { video_id: list[fnames] }
self.metrics_aggregator = partial(vis_metric_entrypoint(metrics_aggregator[0]),
dataset_meta=dataset_meta,
eval_meta_keys=self.eval_meta_keys,
**metrics_aggregator[1])
def visualize_path(self, meta_idxs, visualize, evaluator_path):
return [os.path.join(evaluator_path, f'meta_{meta_idx}') if vis else None for (meta_idx, vis) in zip(meta_idxs, visualize)]
@torch.no_grad()
def __call__(self, model, output_dir):
evaluator_path = os.path.join(output_dir, f'eval_{self.dataset_name}')
os.makedirs(evaluator_path, exist_ok=True)
macs, params = None, None
metrics_by_video_id_frame = defaultdict(dict)
for batch_dict in tqdm(self.loader):
VIS_EvalAPI_clipped_video_request_ann
eval_metas = batch_dict.pop('metas')
request_anns = eval_metas['request_ann'][0] # t, bool tensor
frame_strs = eval_metas['frames'][0] # t', list[str]
video_id = eval_metas['video_id'][0] # str
assert request_anns.int().sum() == len(frame_strs)
callback_fns = eval_metas['callback_fns'][0] # list[fn]
visualize_path = self.visualize_path(meta_idxs=batch_dict['meta_idxs'], visualize=batch_dict['visualize'],
evaluator_path=os.path.join(evaluator_path, 'visualize_model'))
batch_dict['visualize_paths'] = visualize_path
batch_dict = to_device(batch_dict, device=model.device)
VIS_Aug_CallbackAPI
# if macs is None:
# from detectron2.utils.analysis import (
# FlopCountAnalysis,
# )
# flops = FlopCountAnalysis(model, batch_dict, inference_func=lambda model, *inputs: model.sample(*inputs))
# total_flops = flops.total()
# # counts = flops.by_operator()
# logging.debug(f'macs: {total_flops/ (10**9) / len(request_anns)}')
model_outputs = model.sample(batch_dict)
predictions = {
'video': model_outputs['video'][0], # t 3 h w
'pred_masks': [haosen for idx, haosen in enumerate(model_outputs['pred_masks'][0]) if request_anns[idx]], # list[nt h w], t'
'pred_class': [haosen for idx, haosen in enumerate(model_outputs['pred_class'][0]) if request_anns[idx]], # list[nt c], t',
}
if 'pred_boxes' in model_outputs:
predictions.update({'pred_boxes': [haosen for idx, haosen in enumerate(model_outputs['pred_boxes'][0]) if request_anns[idx]]}) # # list[nt 4], t,
for cardib in callback_fns:
predictions = cardib(predictions)
pred_masks = predictions['pred_masks']
pred_class = predictions['pred_class']
assert len(frame_strs) == len(pred_masks)
for idx, (fname, fmk, fclass) in enumerate(zip(frame_strs, pred_masks, pred_class)):
VIS_Evaluator_OutAPI_EvalFn_API
frame_pred = {'masks': fmk, 'classes': fclass.tolist(), 'video_id': video_id, 'frame_name': fname}
if 'pred_boxes' in predictions:
frame_pred.update({'boxes': predictions['pred_boxes'][idx]})
meta_key_metrics = {}
for metric_fn in self.frame_metric_fns:
metric_values = metric_fn(frame_pred=frame_pred, output_dir=evaluator_path)
for key, value in metric_values.items():
assert key not in meta_key_metrics
meta_key_metrics[key] = value
assert fname not in metrics_by_video_id_frame[video_id]
metrics_by_video_id_frame[video_id][fname] = meta_key_metrics
metrics_by_video_id_frame = comm.gather(dict(metrics_by_video_id_frame), dst=0)
eval_metrics = {}
if comm.is_main_process():
metrics_by_video = {}
for video_id in tqdm(self.eval_meta_keys.keys(), desc='gathering different processes'):
video_id_metrics = [haosen[video_id] for haosen in metrics_by_video_id_frame if video_id in haosen]
video_id_frame_names = [list(haosen.keys()) for haosen in video_id_metrics]
merged_video_id_frame_names = [item for sublist in video_id_frame_names for item in sublist]
assert len(set(merged_video_id_frame_names)) == len(merged_video_id_frame_names),''
assert set(merged_video_id_frame_names).issubset(set(self.eval_meta_keys[video_id]))
assert set(self.eval_meta_keys[video_id]).issubset(set(merged_video_id_frame_names))
# perframe metrics frame: predictions
vid_frame_metrics = video_id_metrics[0]
for haosen in video_id_metrics[1:]:
vid_frame_metrics.update(haosen)
metrics_by_video[video_id] = vid_frame_metrics
eval_metrics = self.metrics_aggregator(metrics_by_video)
comm.synchronize()
return eval_metrics
================================================
FILE: data_schedule/vis/evaluator_utils.py
================================================
_vis_metric_entrypoints = {}
def register_vis_metric(fn):
vis_metric_name = fn.__name__
if vis_metric_name in _vis_metric_entrypoints:
raise ValueError(f'vis_metric name {vis_metric_name} has been registered')
_vis_metric_entrypoints[vis_metric_name] = fn
return fn
def vis_metric_entrypoint(vis_metric_name):
try:
return _vis_metric_entrypoints[vis_metric_name]
except KeyError as e:
raise KeyError(f'vis_metric name {vis_metric_name} not found') from e
import numpy as np
_EPS = np.spacing(1)
_TYPE = np.float64
def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple:
gt = gt > 128
pred = pred / 255
if pred.max() != pred.min():
pred = (pred - pred.min()) / (pred.max() - pred.min())
return pred, gt
class Smeasure(object):
def __init__(self, length, alpha: float = 0.5):
self.sms = []
self.alpha = alpha
def step(self, pred: np.ndarray, gt: np.ndarray, idx):
pred, gt = _prepare_data(pred=pred, gt=gt)
sm = self.cal_sm(pred, gt)
self.sms.append(sm)
def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float:
y = np.mean(gt)
if y == 0:
sm = 1 - np.mean(pred)
elif y == 1:
sm = np.mean(pred)
else:
sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt)
sm = max(0, sm)
return sm
def object(self, pred: np.ndarray, gt: np.ndarray) -> float:
fg = pred * gt
bg = (1 - pred) * (1 - gt)
u = np.mean(gt)
object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt)
return object_score
def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float:
x = np.mean(pred[gt == 1])
sigma_x = np.std(pred[gt == 1], ddof=1)
score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS)
return score
def region(self, pred: np.ndarray, gt: np.ndarray) -> float:
x, y = self.centroid(gt)
part_info = self.divide_with_xy(pred, gt, x, y)
w1, w2, w3, w4 = part_info['weight']
pred1, pred2, pred3, pred4 = part_info['pred']
gt1, gt2, gt3, gt4 = part_info['gt']
score1 = self.ssim(pred1, gt1)
score2 = self.ssim(pred2, gt2)
score3 = self.ssim(pred3, gt3)
score4 = self.ssim(pred4, gt4)
return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4
def centroid(self, matrix: np.ndarray) -> tuple:
"""
To ensure consistency with the matlab code, one is added to the centroid coordinate,
so there is no need to use the redundant addition operation when dividing the region later,
because the sequence generated by ``1:X`` in matlab will contain ``X``.
:param matrix: a bool data array
:return: the centroid coordinate
"""
h, w = matrix.shape
area_object = np.count_nonzero(matrix)
if area_object == 0:
x = np.round(w / 2)
y = np.round(h / 2)
else:
# More details can be found at: https://www.yuque.com/lart/blog/gpbigm
y, x = np.argwhere(matrix).mean(axis=0).round()
return int(x) + 1, int(y) + 1
def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict:
h, w = gt.shape
area = h * w
gt_LT = gt[0:y, 0:x]
gt_RT = gt[0:y, x:w]
gt_LB = gt[y:h, 0:x]
gt_RB = gt[y:h, x:w]
pred_LT = pred[0:y, 0:x]
pred_RT = pred[0:y, x:w]
pred_LB = pred[y:h, 0:x]
pred_RB = pred[y:h, x:w]
w1 = x * y / area
w2 = y * (w - x) / area
w3 = (h - y) * x / area
w4 = 1 - w1 - w2 - w3
return dict(gt=(gt_LT, gt_RT, gt_LB, gt_RB),
pred=(pred_LT, pred_RT, pred_LB, pred_RB),
weight=(w1, w2, w3, w4))
def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float:
h, w = pred.shape
N = h * w
x = np.mean(pred)
y = np.mean(gt)
sigma_x = np.sum((pred - x) ** 2) / (N - 1)
sigma_y = np.sum((gt - y) ** 2) / (N - 1)
sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1)
alpha = 4 * x * y * sigma_xy
beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y)
if alpha != 0:
score = alpha / (beta + _EPS)
elif alpha == 0 and beta == 0:
score = 1
else:
score = 0
return score
def get_results(self):
sm = np.mean(np.array(self.sms, dtype=_TYPE))
return dict(Smeasure=sm)
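# Illustrative sketch, not part of the repository: the region term of Smeasure
# splits the image into four blocks at the ground-truth centroid and weights
# each block's SSIM by its area fraction. The pure-Python helpers below
# (centroid_1based, quadrant_weights) are hypothetical names that mirror
# Smeasure.centroid and the weight computation in Smeasure.divide_with_xy,
# assuming a small binary mask given as nested lists.

```python
def centroid_1based(mask):
    """Return the (x, y) centroid of a binary mask, shifted by one."""
    h, w = len(mask), len(mask[0])
    points = [(r, c) for r in range(h) for c in range(w) if mask[r][c]]
    if not points:
        # Empty mask: fall back to the image centre.
        return round(w / 2) + 1, round(h / 2) + 1
    y = round(sum(r for r, _ in points) / len(points))
    x = round(sum(c for _, c in points) / len(points))
    return x + 1, y + 1


def quadrant_weights(h, w, x, y):
    """Area fractions of the LT, RT, LB, RB blocks split at column x, row y."""
    area = h * w
    w1 = x * y / area
    w2 = y * (w - x) / area
    w3 = (h - y) * x / area
    w4 = 1 - w1 - w2 - w3
    return w1, w2, w3, w4


# A 4x4 mask with a single foreground pixel at row 1, column 2.
mask = [
    [0, 0, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]
x, y = centroid_1based(mask)
weights = quadrant_weights(4, 4, x, y)
print((x, y), weights)
```

# The four weights always sum to one, so the region score stays a convex
# combination of the per-block SSIM values.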
import torch
import os
from PIL import Image
@register_vis_metric
def mask_dice_iou(frame_pred, dataset_meta, **kwargs):
video_id = frame_pred['video_id']
frame_name = frame_pred['frame_name']
masks = frame_pred['masks'] # nq h w
get_frames_gt_mask_fn = dataset_meta.get('get_frames_gt_mask_fn')
scores = torch.tensor(frame_pred['classes']) # nq c
foreground_scores = scores[:, :-1].sum(-1) # nq
max_idx = foreground_scores.argmax()
pred_mask = masks[max_idx].int() # h w
gt_mask, _ = get_frames_gt_mask_fn(video_id=video_id, frames=[frame_name]) # 1 h w
gt_mask = gt_mask[0].int() # h w
inter, union = (pred_mask*gt_mask).sum(), (pred_mask+gt_mask).sum()
dice = (2*inter+1)/(union+1)
iou = (inter+1)/(union-inter+1)
return {'dice': dice, 'iou': iou}
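# Illustrative sketch, not part of the repository: mask_dice_iou adds 1 to
# numerator and denominator, so two empty masks score 1 instead of dividing
# by zero. The variable called `union` above is the sum of both mask areas,
# and the true union is recovered by subtracting the intersection. The helper
# name smoothed_dice_iou and the nested-list masks are assumptions made for
# this example only.

```python
def smoothed_dice_iou(pred, gt, smooth=1):
    """Smoothed Dice and IoU for two binary masks given as nested lists."""
    pairs = [(p, g) for pred_row, gt_row in zip(pred, gt)
             for p, g in zip(pred_row, gt_row)]
    inter = sum(p * g for p, g in pairs)   # true positives
    total = sum(p + g for p, g in pairs)   # |pred| + |gt| = 2*tp + fp + fn
    dice = (2 * inter + smooth) / (total + smooth)
    iou = (inter + smooth) / (total - inter + smooth)
    return dice, iou


pred = [[1, 1], [0, 0]]
gt = [[1, 0], [1, 0]]
dice, iou = smoothed_dice_iou(pred, gt)
print(dice, iou)

# Two empty masks: the smoothing term makes both scores equal to 1.
empty = [[0, 0], [0, 0]]
print(smoothed_dice_iou(empty, empty))
```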
@register_vis_metric
def mask_dice_iou_sen_mae_smeasure(frame_pred, dataset_meta, **kwargs):
video_id = frame_pred['video_id']
frame_name = frame_pred['frame_name']
masks = frame_pred['masks'] # nq h w
get_frames_gt_mask_fn = dataset_meta.get('get_frames_gt_mask_fn')
scores = torch.tensor(frame_pred['classes']) # nq c
foreground_scores = scores[:, :-1].sum(-1) # nq
max_idx = foreground_scores.argmax()
pred_mask = masks[max_idx].int() # h w
gt_mask, _ = get_frames_gt_mask_fn(video_id=video_id, frames=[frame_name]) # 1 h w
gt_mask = gt_mask[0].int() # h w
# inter = tp; union = 2*tp + fp + fn (sum of both mask areas, not the set union)
inter, union = (pred_mask*gt_mask).sum(), (pred_mask+gt_mask).sum()
dice = (2*inter+1)/(union+1) # smoothed 2*tp / (2*tp + fp + fn)
iou = (inter+1)/(union-inter+1) # smoothed tp / (tp + fp + fn)
tp = (pred_mask * gt_mask).sum().float()
fp = (pred_mask.sum() - tp).float()
fn = (gt_mask.sum() - tp).float()
tn = (pred_mask.shape[0] * pred_mask.shape[1] - (tp + fp + fn)).float()
their_dice = tp * 2 / (tp + fp + fn + tp)
their_iou = tp / (tp + fp + fn)
# their_spe = tn / (tn + fp)
their_sen = tp / (tp + fn)
their_mae = (pred_mask.float() - gt_mask.float()).abs().mean()
Np = gt_mask.sum()
Nn = gt_mask.shape[0] * gt_mask.shape[1] - Np
null = Smeasure(length=1, alpha=0.5)
null.step(pred=(pred_mask.float() * 255 ).numpy(), gt=(gt_mask.float() * 255).numpy(), idx=None)
their_smeasure = torch.tensor(null.get_results()['Smeasure']).float()
return {'dice': dice, 'iou': iou,
'their_dice': their_dice,
'their_iou': their_iou,
'their_sen': their_sen,
'their_mae_abs': their_mae,
'their_smeasure': their_smeasure,
'tp': tp, # true positive
'fp': fp, # false positive
'fn': fn, # false negative
'tn': tn, # true negative
'Np': Np, # positive accumulation
'Nn': Nn} # negative accumulation
@register_vis_metric
def web(frame_pred, output_dir, **kwargs):
os.makedirs(os.path.join(output_dir, 'web'), exist_ok=True)
video_id = frame_pred['video_id']
frame_name = frame_pred['frame_name']
masks = frame_pred['masks'] # nq h w
scores = torch.tensor(frame_pred['classes']) # nq c
foreground_scores = scores[:, :-1].sum(-1) # nq
max_idx = foreground_scores.argmax()
pred_mask = masks[max_idx].int() # h w
mask = Image.fromarray(255 * pred_mask.int().numpy()).convert('L')
save_path = os.path.join(output_dir, 'web', video_id)
os.makedirs(save_path, exist_ok=True)
png_path = os.path.join(save_path, f'{frame_name}.png')
if os.path.exists(png_path):
os.remove(png_path)
mask.save(png_path)
return {}
================================================
FILE: data_schedule/vis/fibroid/__init__.py
================================================
# Register the fibroid dataset
from . import fibroid_dataset
# Register the fibroid evaluation metrics
from . import evals
================================================
FILE: data_schedule/vis/fibroid/evals.py
================================================
from data_schedule.vis.evaluator_utils import register_vis_metric
import os
from glob import glob
from tqdm import tqdm
import shutil
from functools import partial
from PIL import Image
import numpy as np
import torch
import detectron2.utils.comm as comm
import logging
import pycocotools.mask as mask_util
from pycocotools.mask import decode as decode_rle
import data_schedule.vis.fibroid.metrics as metrics
@register_vis_metric
def fibroid_other_medi(model_preds,
dataset_meta,
**kwargs):
assert comm.is_main_process()
iou_by_test_sample = []
dice_by_test_sample = []
preds_by_test_sample = []
gt_by_test_sample = []
get_frames_gt_mask_fn = dataset_meta.get('get_frames_gt_mask_fn')
for pred in model_preds:
video_id = pred['video_id'] # str
frame_name = pred['frame_name'] # list[str], t'
masks = pred['masks']# list[rle], nq
scores = pred['scores'] # nq
max_idx = torch.tensor(scores).argmax()
pred_mask = masks[max_idx] # rle
pred_mask = decode_rle(pred_mask)
pred_mask = torch.as_tensor(pred_mask, dtype=torch.uint8).contiguous() # h w
gt_mask, _ = get_frames_gt_mask_fn(video_id=video_id, frames=[frame_name]) # 1 h w
gt_mask = gt_mask[0].int() # 0/1
preds_by_test_sample.append(pred_mask)
gt_by_test_sample.append(gt_mask)
tp, fp, fn, tn = metrics.get_stats(pred_mask[None, None, ...], gt_mask[None, None, ...],
mode='binary')
iou_score = metrics.iou_score(tp, fp, fn, tn, reduction='micro')
dice = metrics.dice(tp, fp, fn, tn, reduction='micro')
iou_by_test_sample.append(iou_score)
dice_by_test_sample.append(dice)
mean_iou = torch.tensor(iou_by_test_sample).mean()
mean_dice = torch.tensor(dice_by_test_sample).mean()
preds_by_test_sample = torch.stack(preds_by_test_sample, dim=0).unsqueeze(1) # N 1 h w
gt_by_test_sample = torch.stack(gt_by_test_sample, dim=0).unsqueeze(1) # N 1 h w
tp, fp, fn, tn = metrics.get_stats(preds_by_test_sample, gt_by_test_sample,
mode='binary')
overall_iou = metrics.iou_score(tp, fp, fn, tn, reduction='micro')
recall = metrics.recall(tp, fp, fn, tn, reduction='micro-imagewise')
precision = metrics.precision(tp, fp, fn, tn, reduction='micro-imagewise')
all_medi = {
'mean_iou': mean_iou,
'dice': mean_dice,
'overall_iou': overall_iou, # J/overallIoU
'recall': recall,
'precision': precision,
'F': 2 * precision * recall / (precision + recall)
}
return all_medi
from collections import defaultdict
# by_vid, by_frame
iou_dict = defaultdict(dict)
@register_vis_metric
def fibroid_mask_dice_iou(frame_pred, dataset_meta, **kwargs):
video_id = frame_pred['video_id']
frame_name = frame_pred['frame_name']
masks = frame_pred['masks'] # nq h w
get_frames_gt_mask_fn = dataset_meta.get('get_frames_gt_mask_fn')
scores = torch.tensor(frame_pred['classes']) # nq c; c must be 2
foreground_scores = scores[:, :-1].sum(-1) # nq
max_idx = foreground_scores.argmax()
pred_mask = masks[max_idx].int() # h w
gt_mask, _ = get_frames_gt_mask_fn(video_id=video_id, frames=[frame_name]) # 1 h w
gt_mask = gt_mask[0].int() # h w
inter, union = (pred_mask*gt_mask).sum(), (pred_mask+gt_mask).sum()
dice = (2*inter+1)/(union+1)
iou = (inter+1)/(union-inter+1)
iou_dict[video_id][frame_name] = iou
if iou > 0.6:
print(f'video_id: {video_id}, frame: {frame_name}: dice {dice}, iou {iou}')
return {'dice': dice, 'iou': iou}
@register_vis_metric
def fibroid_metric_aggregator(metrics_by_vid_frame, dataset_meta, eval_meta_keys, **kwargs):
# output: eval_metrics
# video: frame_name: metric/ vid_metrics
eval_metrics = {}
# video, frame_name
# perframe metrics
metric_names = metrics_by_vid_frame[list(eval_meta_keys.keys())[0]][eval_meta_keys[list(eval_meta_keys.keys())[0]][0]]
for metric_name in metric_names:
eval_metrics[metric_name] = torch.tensor([metrics_by_vid_frame[video][frame][metric_name] for video in eval_meta_keys.keys() for frame in eval_meta_keys[video]]).mean()
# metrics by each video
mean_iou_by_each_video = {}
mean_dice_by_each_video = {}
for video in eval_meta_keys:
mean_iou_by_each_video[video] = torch.tensor([metrics_by_vid_frame[video][fname]['iou'] for fname in eval_meta_keys[video]]).mean()
mean_dice_by_each_video[video] = torch.tensor([metrics_by_vid_frame[video][fname]['dice'] for fname in eval_meta_keys[video]]).mean()
mean_iou_by_each_video = dict(sorted(mean_iou_by_each_video.items(), key=lambda x: x[1]))
mean_dice_by_each_video = dict(sorted(mean_dice_by_each_video.items(), key=lambda x: x[1]))
logging.debug(f'mean_iou_by_each_video: {mean_iou_by_each_video}')
logging.debug(f'mean_dice_by_each_video: {mean_dice_by_each_video}')
return eval_metrics
================================================
FILE: data_schedule/vis/fibroid/fibroid_dataset.py
================================================
from typing import Optional, Union
import json
import os
from functools import partial
import numpy as np
import torch
import logging
from tqdm import tqdm
import copy
from detectron2.data import DatasetCatalog, MetadataCatalog
from collections import defaultdict
from data_schedule.vis.apis import VIS_Dataset
from .fibroid_utils import get_frames, get_frames_mask, SET_NAME_TO_DIR,\
SET_NAME, SET_NAME_TO_NUM_VIDEOS, SET_NAME_TO_MODE, SET_NAME_TO_PREFIX, SET_NAME_TO_GT_TYPE
def fibroid_train(step_size, # none / int; 0, 6, 13, 19 ...
split_dataset_name,
video_ids,
video_to_frames):
logging.debug(f'{split_dataset_name} Generating metas...')
metas = []
for vid_id in tqdm(video_ids):
all_frames = sorted(video_to_frames[vid_id])
if step_size is None:
metas.append({
'video_id': vid_id,
'all_frames' : all_frames,
'meta_idx': len(metas),
'all_objs': {1: {'class_label': 0,}} # semantic segmentation
})
else:
for frame_idx in range(0, len(all_frames), step_size):
metas.append({
'video_id': vid_id,
'frame_idx': frame_idx,
'all_frames': all_frames,
'all_objs': {1: {'class_label': 0,}},
'meta_idx': len(metas)
})
logging.debug(f'{split_dataset_name} Total metas: [{len(metas)}]')
return metas
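# Illustrative sketch, not part of the repository: fibroid_train emits one
# meta per video when step_size is None, and one meta per sampled frame index
# otherwise. The simplified build_metas helper and the video id 'case_a' are
# hypothetical and only reproduce the sampling logic, without the dataset
# catalog or the 'all_objs' annotation.

```python
def build_metas(video_to_frames, step_size=None):
    """Build one meta per video, or one per every step_size-th frame."""
    metas = []
    for video_id, frames in video_to_frames.items():
        all_frames = sorted(frames)
        if step_size is None:
            metas.append({'video_id': video_id,
                          'all_frames': all_frames,
                          'meta_idx': len(metas)})
        else:
            for frame_idx in range(0, len(all_frames), step_size):
                metas.append({'video_id': video_id,
                              'frame_idx': frame_idx,
                              'all_frames': all_frames,
                              'meta_idx': len(metas)})
    return metas


# A hypothetical video with 13 frames named 00000 .. 00012.
video_to_frames = {'case_a': [f'{i:05d}' for i in range(13)]}

per_video = build_metas(video_to_frames)
per_step = build_metas(video_to_frames, step_size=6)
print(len(per_video), [m['frame_idx'] for m in per_step])
```

# With step_size=6 the 13-frame video is sampled at frame indices 0, 6 and 12.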
def fibroid_evaluate(eval_video_ids,
split_dataset_name,
step_size,
video_to_frames):
if (step_size is not None) and (step_size > 1):
logging.warning('Why is step_size greater than 1 during evaluation?')
raise ValueError(f'step_size must be None or 1 for evaluation, got {step_size}')
metas = []
for video_id in eval_video_ids:
all_frames = sorted(video_to_frames[video_id])
if step_size is None:
metas.append({
'video_id': video_id,
'all_frames': all_frames,
'meta_idx': len(metas)
})
else:
for frame_idx in range(0, len(all_frames), step_size):
metas.append({
'video_id': video_id,
'frame_idx': frame_idx,
'all_frames': all_frames,
'meta_idx': len(metas)
})
logging.debug(f'{split_dataset_name} Total metas: [{len(metas)}]')
return metas
_root = os.getenv('DATASET_PATH')
root = os.path.join(_root, 'uterus_myoma/Dataset')
visualize_meta_idxs = defaultdict(list)
visualize_meta_idxs['fibroid_train_step[6]'] = []
visualize_meta_idxs['fibroid_train'] = []
visualize_meta_idxs['fibroid_train_step[1]'] = []
visualize_meta_idxs['fibroid_validate'] = []
visualize_meta_idxs['fibroid_validate_step[1]'] = []
visualize_meta_idxs['weakPolyP_fibroid_validate_step[1]'] = []
fibroid_meta = {
'thing_classes': ['tumor', 'not tumor'],
'thing_colors': [(255., 140., 0.), (0., 255., 0.)],
}
for name in SET_NAME:
set_dir = SET_NAME_TO_DIR[name]
set_dir = os.path.join(root, set_dir)
num_videos = SET_NAME_TO_NUM_VIDEOS[name]
video_ids = os.listdir(os.path.join(set_dir, 'Frame'))
assert len(video_ids) == num_videos
video_to_frames = {
vid: sorted([png[:-4] for png in os.listdir(os.path.join(set_dir, 'Frame', vid)) if png.endswith('.png')])\
for vid in video_ids
}
mode = SET_NAME_TO_MODE[name]
prefix = SET_NAME_TO_PREFIX[name]
if mode == 'train':
train_meta = copy.deepcopy(fibroid_meta)
gt_type = SET_NAME_TO_GT_TYPE[name]
train_meta.update({
'mode': 'train',
'get_frames_fn': partial(get_frames, frames_path=os.path.join(set_dir, 'Frame')),
'get_frames_mask_fn': partial(get_frames_mask, mask_path=os.path.join(set_dir, gt_type),),
'get_frames_gt_mask_fn': partial(get_frames_mask, mask_path=os.path.join(root, os.path.join(set_dir, 'GT')),),
})
# train
for step_size in [1, 6, None]:
step_identifer = '' if step_size is None else f'_step[{step_size}]'
split_name = f'{prefix}{step_identifer}'
train_meta.update({'name': split_name})
DatasetCatalog.register(split_name, partial(fibroid_train,
video_ids=video_ids,
split_dataset_name=split_name,
step_size=step_size,
video_to_frames=video_to_frames,))
MetadataCatalog.get(split_name).set(**train_meta,
step_size=step_size,
visualize_meta_idxs=visualize_meta_idxs[split_name])
elif mode == 'evaluate':
prefix = SET_NAME_TO_PREFIX[name]
validate_meta = copy.deepcopy(fibroid_meta)
validate_meta.update({
'mode': 'evaluate',
'get_frames_fn': partial(get_frames, frames_path=os.path.join(root, os.path.join(set_dir, 'Frame'))),
'eval_set_name': SET_NAME_TO_DIR[name],
'get_frames_gt_mask_fn': partial(get_frames_mask, mask_path=os.path.join(root, os.path.join(set_dir, 'GT')),),
'eval_meta_keys': video_to_frames
})
# validate
for step_size in [1, None,]:
step_identifer = '' if step_size is None else f'_step[{step_size}]'
split_name = f'{prefix}{step_identifer}'
validate_meta.update({'name': split_name})
DatasetCatalog.register(split_name, partial(fibroid_evaluate,
eval_video_ids=video_ids,
split_dataset_name=split_name,
step_size=step_size,
video_to_frames=video_to_frames))
MetadataCatalog.get(split_name).set(**validate_meta, step_size=step_size,
visualize_meta_idxs=visualize_meta_idxs[split_name])
================================================
FILE: data_schedule/vis/fibroid/fibroid_utils.py
================================================
import wandb
import plotly.express as px
import logging
import os
import numpy as np
import torch
import json
from joblib import Parallel, delayed
import multiprocessing
import torch.distributed as dist
import detectron2.utils.comm as comm
import pycocotools.mask as mask_util
from pycocotools.mask import encode, area
from data_schedule.utils.segmentation import bounding_box_from_mask
from data_schedule.utils.video_clips import generate_windows_of_video
from glob import glob
from PIL import Image
def get_frames(frames_path, video_id, frames):
return [Image.open(os.path.join(frames_path, video_id, f'{f}.png'),).convert('RGB') for f in frames]
# t' h w, int, obj_ids ; has_ann t
def get_frames_mask(mask_path, video_id, frames):
masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.png')).convert('L') for f in frames]
masks = [np.array(mk) for mk in masks]
masks = torch.stack([torch.from_numpy(mk) for mk in masks], dim=0) # t h w
masks = (masks > 0).int()
return masks, torch.ones(len(frames)).bool()
SET_NAME = [
'fibroid_train',
'fibroid_validate',
'weakpolyp_train',
'fibroid_validate_temp7',
'fibroid_train_temp7',
# 'weakpolyp_fibroid_train_temp7',
'fibroid_validate_temp8',
'fibroid_train_temp8',
'weakpolyp_fibroid_train_temp8'
]
SET_NAME_TO_DIR = {
'fibroid_train': 'temp/train',
'fibroid_validate': 'temp/test',
'weakpolyp_train': 'temp/uterus_myoma_WeakPolyP_temp/train',
'fibroid_validate_temp7': 'temp7/test',
'fibroid_train_temp7': 'temp7/train',
'weakpolyp_fibroid_train_temp7': 'temp7/uterus_myoma_WeakPolyP_temp7/train',
'fibroid_validate_temp8': 'temp8/test',
'fibroid_train_temp8': 'temp8/train',
'weakpolyp_fibroid_train_temp8': 'temp8/uterus_myoma_WeakPolyP_temp8/train',
}
SET_NAME_TO_NUM_VIDEOS = {
'fibroid_train': 80,
'fibroid_validate': 20,
'weakpolyp_train': 80,
'fibroid_train_temp7': 85,
'fibroid_validate_temp7': 15,
'weakpolyp_fibroid_train_temp7': 85 ,
'fibroid_train_temp8': 83,
'fibroid_validate_temp8': 17,
'weakpolyp_fibroid_train_temp8': 83
}
SET_NAME_TO_MODE = {
'fibroid_train': 'train',
'fibroid_validate': 'evaluate',
'weakpolyp_train': 'train',
'fibroid_train_temp7': 'train',
'fibroid_validate_temp7': 'evaluate',
'weakpolyp_fibroid_train_temp7': 'train',
'fibroid_train_temp8': 'train',
'fibroid_validate_temp8': 'evaluate',
'weakpolyp_fibroid_train_temp8': 'train'
}
SET_NAME_TO_PREFIX = {
'fibroid_train': 'fibroid_train',
'fibroid_validate': 'fibroid_validate',
'weakpolyp_train': 'weakpolyp_fibroid_train',
'fibroid_train_temp7': 'fibroid_train_temp7',
'fibroid_validate_temp7': 'fibroid_validate_temp7',
'weakpolyp_fibroid_train_temp7': 'weakpolyp_fibroid_train_temp7' ,
'fibroid_train_temp8': 'fibroid_train_temp8',
'fibroid_validate_temp8': 'fibroid_validate_temp8',
'weakpolyp_fibroid_train_temp8': 'weakpolyp_fibroid_train_temp8'
}
SET_NAME_TO_GT_TYPE = {
'fibroid_train': 'GT',
'fibroid_validate': 'GT',
'weakpolyp_train': 'Box',
'fibroid_train_temp7': 'GT',
'fibroid_validate_temp7': 'GT',
'weakpolyp_fibroid_train_temp7': 'Box',
'fibroid_train_temp8': 'GT',
'fibroid_validate_temp8': 'GT',
'weakpolyp_fibroid_train_temp8': 'Box'
}
================================================
FILE: data_schedule/vis/fibroid/metrics.py
================================================
import warnings
from typing import Optional, List, Tuple, Union
import torch
"""Various metrics based on Type I and Type II errors.
References:
https://en.wikipedia.org/wiki/Confusion_matrix
Example:
.. code-block:: python
import segmentation_models_pytorch as smp
# lets assume we have multilabel prediction for 3 classes
output = torch.rand([10, 3, 256, 256])
target = torch.rand([10, 3, 256, 256]).round().long()
# first compute statistics for true positives, false positives, false negative and
# true negative "pixels"
tp, fp, fn, tn = smp.metrics.get_stats(output, target, mode='multilabel', threshold=0.5)
# then compute metrics with required reduction (see metric docs)
iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
"""
__all__ = [
"get_stats",
"fbeta_score",
"f1_score",
"iou_score",
"accuracy",
"precision",
"recall",
"sensitivity",
"specificity",
"balanced_accuracy",
"positive_predictive_value",
"negative_predictive_value",
"false_negative_rate",
"false_positive_rate",
"false_discovery_rate",
"false_omission_rate",
"positive_likelihood_ratio",
"negative_likelihood_ratio",
]
###################################################################################################
# Statistics computation (true positives, false positives, false negatives, true negatives)
###################################################################################################
def get_stats(
output: Union[torch.LongTensor, torch.FloatTensor],
target: torch.LongTensor,
mode: str,
ignore_index: Optional[int] = None,
threshold: Optional[Union[float, List[float]]] = None,
num_classes: Optional[int] = None,
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
"""Compute true positive, false positive, false negative, true negative 'pixels'
for each image and each class.
Args:
output (Union[torch.LongTensor, torch.FloatTensor]): Model output with following
shapes and types depending on the specified ``mode``:
'binary'
shape (N, 1, ...) and ``torch.LongTensor`` or ``torch.FloatTensor``
'multilabel'
shape (N, C, ...) and ``torch.LongTensor`` or ``torch.FloatTensor``
'multiclass'
shape (N, ...) and ``torch.LongTensor``
target (torch.LongTensor): Targets with following shapes depending on the specified ``mode``:
'binary'
shape (N, 1, ...)
'multilabel'
shape (N, C, ...)
'multiclass'
shape (N, ...)
mode (str): One of ``'binary'`` | ``'multilabel'`` | ``'multiclass'``
ignore_index (Optional[int]): Label to ignore on for metric computation.
**Not** supported for ``'binary'`` and ``'multilabel'`` modes. Defaults to None.
threshold (Optional[float, List[float]]): Binarization threshold for
``output`` in case of ``'binary'`` or ``'multilabel'`` modes. Defaults to None.
num_classes (Optional[int]): Number of classes, necessary attribute
only for ``'multiclass'`` mode. Class values should be in range 0..(num_classes - 1).
If ``ignore_index`` is specified it should be outside the classes range, e.g. ``-1`` or
``255``.
Raises:
ValueError: in case of misconfiguration.
Returns:
Tuple[torch.LongTensor]: true_positive, false_positive, false_negative,
true_negative tensors (N, C) shape each.
"""
if torch.is_floating_point(target):
raise ValueError(f"Target should be one of the integer types, got {target.dtype}.")
if torch.is_floating_point(output) and threshold is None:
raise ValueError(
f"Output should be one of the integer types if ``threshold`` is None, got {output.dtype}."
)
if torch.is_floating_point(output) and mode == "multiclass":
raise ValueError(f"For ``multiclass`` mode ``output`` should be one of the integer types, got {output.dtype}.")
if mode not in {"binary", "multiclass", "multilabel"}:
raise ValueError(f"``mode`` should be in ['binary', 'multiclass', 'multilabel'], got mode={mode}.")
if mode == "multiclass" and threshold is not None:
raise ValueError("``threshold`` parameter is not supported for 'multiclass' mode")
if output.shape != target.shape:
raise ValueError(
"Dimensions should match, but ``output`` shape is not equal to ``target`` "
+ f"shape, {output.shape} != {target.shape}"
)
if mode != "multiclass" and ignore_index is not None:
raise ValueError(f"``ignore_index`` parameter is not supported for '{mode}' mode")
if mode == "multiclass" and num_classes is None:
raise ValueError("``num_classes`` attribute should be not ``None`` for 'multiclass' mode.")
if ignore_index is not None and 0 <= ignore_index <= num_classes - 1:
raise ValueError(
f"``ignore_index`` should be outside the class values range, but got class values in range "
f"0..{num_classes - 1} and ``ignore_index={ignore_index}``. Hint: if you have ``ignore_index = 0``, "
f"consider subtracting ``1`` from your target and model output to make ``ignore_index = -1`` "
f"and relevant class values start from ``0``."
)
if mode == "multiclass":
tp, fp, fn, tn = _get_stats_multiclass(output, target, num_classes, ignore_index)
else:
if threshold is not None:
output = torch.where(output >= threshold, 1, 0)
target = torch.where(target >= threshold, 1, 0)
tp, fp, fn, tn = _get_stats_multilabel(output, target)
return tp, fp, fn, tn
@torch.no_grad()
def _get_stats_multiclass(
output: torch.LongTensor,
target: torch.LongTensor,
num_classes: int,
ignore_index: Optional[int],
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
batch_size, *dims = output.shape
num_elements = torch.prod(torch.tensor(dims)).long()
if ignore_index is not None:
ignore = target == ignore_index
output = torch.where(ignore, -1, output)
target = torch.where(ignore, -1, target)
ignore_per_sample = ignore.view(batch_size, -1).sum(1)
tp_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
fp_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
fn_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
tn_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
for i in range(batch_size):
target_i = target[i]
output_i = output[i]
mask = output_i == target_i
matched = torch.where(mask, target_i, -1)
tp = torch.histc(matched.float(), bins=num_classes, min=0, max=num_classes - 1)
fp = torch.histc(output_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp
fn = torch.histc(target_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp
tn = num_elements - tp - fp - fn
if ignore_index is not None:
tn = tn - ignore_per_sample[i]
tp_count[i] = tp.long()
fp_count[i] = fp.long()
fn_count[i] = fn.long()
tn_count[i] = tn.long()
return tp_count, fp_count, fn_count, tn_count
@torch.no_grad()
def _get_stats_multilabel(
output: torch.LongTensor,
target: torch.LongTensor,
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
batch_size, num_classes, *dims = target.shape
output = output.view(batch_size, num_classes, -1)
target = target.view(batch_size, num_classes, -1)
tp = (output * target).sum(2)
fp = output.sum(2) - tp
fn = target.sum(2) - tp
tn = torch.prod(torch.tensor(dims)) - (tp + fp + fn)
return tp, fp, fn, tn
###################################################################################################
# Metrics computation
###################################################################################################
def _handle_zero_division(x, zero_division):
nans = torch.isnan(x)
if torch.any(nans) and zero_division == "warn":
warnings.warn("Zero division in metric calculation!")
value = zero_division if zero_division != "warn" else 0
value = torch.tensor(value, dtype=x.dtype).to(x.device)
x = torch.where(nans, value, x)
return x
def _compute_metric(
metric_fn,
tp,
fp,
fn,
tn,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division="warn",
**metric_kwargs,
) -> float:
if class_weights is None and reduction is not None and "weighted" in reduction:
raise ValueError(f"Class weights should be provided for `{reduction}` reduction")
class_weights = class_weights if class_weights is not None else 1.0
class_weights = torch.tensor(class_weights).to(tp.device)
class_weights = class_weights / class_weights.sum()
if reduction == "micro":
tp = tp.sum()
fp = fp.sum()
fn = fn.sum()
tn = tn.sum()
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
elif reduction == "macro":
tp = tp.sum(0)
fp = fp.sum(0)
fn = fn.sum(0)
tn = tn.sum(0)
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
score = _handle_zero_division(score, zero_division)
score = (score * class_weights).mean()
elif reduction == "weighted":
tp = tp.sum(0)
fp = fp.sum(0)
fn = fn.sum(0)
tn = tn.sum(0)
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
score = _handle_zero_division(score, zero_division)
score = (score * class_weights).sum()
elif reduction == "micro-imagewise":
tp = tp.sum(1)
fp = fp.sum(1)
fn = fn.sum(1)
tn = tn.sum(1)
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
score = _handle_zero_division(score, zero_division)
score = score.mean()
elif reduction == "macro-imagewise" or reduction == "weighted-imagewise":
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
score = _handle_zero_division(score, zero_division)
score = (score.mean(0) * class_weights).mean()
elif reduction == "none" or reduction is None:
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
score = _handle_zero_division(score, zero_division)
else:
raise ValueError(
"`reduction` should be in [micro, macro, weighted, micro-imagewise, "
+ "macro-imagewise, weighted-imagewise, none, None]"
)
return score
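# Illustrative sketch, not part of the repository: the "micro" reduction pools
# tp/fp/fn over all images before computing the metric, whereas
# "micro-imagewise" computes the metric per image and then averages. The
# pure-Python helpers below are hypothetical, assume a single class, and use
# made-up counts for two images. They ignore the zero-division handling of
# _compute_metric.

```python
def iou(tp, fp, fn):
    """Intersection over union from confusion counts."""
    return tp / (tp + fp + fn)


def micro(stats):
    """Pool the counts over all images, then compute one score."""
    tp = sum(s[0] for s in stats)
    fp = sum(s[1] for s in stats)
    fn = sum(s[2] for s in stats)
    return iou(tp, fp, fn)


def micro_imagewise(stats):
    """Compute one score per image, then average the scores."""
    scores = [iou(tp, fp, fn) for tp, fp, fn in stats]
    return sum(scores) / len(scores)


# (tp, fp, fn) for a small lesion and a large lesion.
stats = [(1, 1, 0), (8, 0, 2)]
print(micro(stats), micro_imagewise(stats))
```

# Pooling lets large lesions dominate the score, while the image-wise variant
# gives every frame the same weight regardless of lesion size.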
# Logic for metric computation, all metrics are with the same interface
def _fbeta_score(tp, fp, fn, tn, beta=1):
beta_tp = (1 + beta**2) * tp
beta_fn = (beta**2) * fn
score = beta_tp / (beta_tp + beta_fn + fp)
return score
def _iou_score(tp, fp, fn, tn):
return tp / (tp + fp + fn)
def _accuracy(tp, fp, fn, tn):
return (tp + tn) / (tp + fp + fn + tn)
def _sensitivity(tp, fp, fn, tn):
return tp / (tp + fn)
def _specificity(tp, fp, fn, tn):
return tn / (tn + fp)
def _balanced_accuracy(tp, fp, fn, tn):
return (_sensitivity(tp, fp, fn, tn) + _specificity(tp, fp, fn, tn)) / 2
def _dice(tp, fp, fn, tn):
return tp * 2 / (tp + fp + fn + tp)
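# Illustrative sketch, not part of the repository: with beta = 1 the F-beta
# formula above reduces to the Dice coefficient, and a larger beta weights
# false negatives more heavily than false positives. The scalar helpers and
# the counts below are assumptions chosen for this example.

```python
def fbeta(tp, fp, fn, beta=1.0):
    """F-beta score from confusion counts."""
    beta_tp = (1 + beta ** 2) * tp
    beta_fn = (beta ** 2) * fn
    return beta_tp / (beta_tp + beta_fn + fp)


def dice(tp, fp, fn):
    """Dice coefficient from confusion counts."""
    return 2 * tp / (2 * tp + fp + fn)


tp, fp, fn = 8, 2, 4
f1 = fbeta(tp, fp, fn, beta=1.0)
f2 = fbeta(tp, fp, fn, beta=2.0)
print(f1, dice(tp, fp, fn), f2)
```

# Here recall (8/12) is lower than precision (8/10), so F2 falls below F1.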
def _positive_predictive_value(tp, fp, fn, tn):
return tp / (tp + fp)
def _negative_predictive_value(tp, fp, fn, tn):
return tn / (tn + fn)
def _false_negative_rate(tp, fp, fn, tn):
return fn / (fn + tp)
def _false_positive_rate(tp, fp, fn, tn):
return fp / (fp + tn)
def _false_discovery_rate(tp, fp, fn, tn):
return 1 - _positive_predictive_value(tp, fp, fn, tn)
def _false_omission_rate(tp, fp, fn, tn):
return 1 - _negative_predictive_value(tp, fp, fn, tn)
def _positive_likelihood_ratio(tp, fp, fn, tn):
return _sensitivity(tp, fp, fn, tn) / _false_positive_rate(tp, fp, fn, tn)
def _negative_likelihood_ratio(tp, fp, fn, tn):
return _false_negative_rate(tp, fp, fn, tn) / _specificity(tp, fp, fn, tn)
def fbeta_score(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
beta: float = 1.0,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""F beta score"""
return _compute_metric(
_fbeta_score,
tp,
fp,
fn,
tn,
beta=beta,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def f1_score(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""F1 score"""
return _compute_metric(
_fbeta_score,
tp,
fp,
fn,
tn,
beta=1.0,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def dice(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Dice coefficient (Sørensen-Dice index)""" # noqa
return _compute_metric(
_dice,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def iou_score(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""IoU score or Jaccard index""" # noqa
return _compute_metric(
_iou_score,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def accuracy(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Accuracy"""
return _compute_metric(
_accuracy,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def sensitivity(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Sensitivity, recall, hit rate, or true positive rate (TPR)"""
return _compute_metric(
_sensitivity,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def specificity(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Specificity, selectivity or true negative rate (TNR)"""
return _compute_metric(
_specificity,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def balanced_accuracy(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Balanced accuracy"""
return _compute_metric(
_balanced_accuracy,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def positive_predictive_value(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Precision or positive predictive value (PPV)"""
return _compute_metric(
_positive_predictive_value,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def negative_predictive_value(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Negative predictive value (NPV)"""
return _compute_metric(
_negative_predictive_value,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def false_negative_rate(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Miss rate or false negative rate (FNR)"""
return _compute_metric(
_false_negative_rate,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def false_positive_rate(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Fall-out or false positive rate (FPR)"""
return _compute_metric(
_false_positive_rate,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def false_discovery_rate(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""False discovery rate (FDR)""" # noqa
return _compute_metric(
_false_discovery_rate,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def false_omission_rate(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""False omission rate (FOR)""" # noqa
return _compute_metric(
_false_omission_rate,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def positive_likelihood_ratio(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Positive likelihood ratio (LR+)"""
return _compute_metric(
_positive_likelihood_ratio,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
def negative_likelihood_ratio(
tp: torch.LongTensor,
fp: torch.LongTensor,
fn: torch.LongTensor,
tn: torch.LongTensor,
reduction: Optional[str] = None,
class_weights: Optional[List[float]] = None,
zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
"""Negative likelihood ratio (LR-)"""
return _compute_metric(
_negative_likelihood_ratio,
tp,
fp,
fn,
tn,
reduction=reduction,
class_weights=class_weights,
zero_division=zero_division,
)
_doc = """
Args:
tp (torch.LongTensor): tensor of shape (N, C), true positive cases
fp (torch.LongTensor): tensor of shape (N, C), false positive cases
fn (torch.LongTensor): tensor of shape (N, C), false negative cases
tn (torch.LongTensor): tensor of shape (N, C), true negative cases
reduction (Optional[str]): Define how to aggregate metric between classes and images:
- 'micro'
Sum true positive, false positive, false negative and true negative pixels over
all images and all classes and then compute score.
- 'macro'
Sum true positive, false positive, false negative and true negative pixels over
all images for each label, then compute score for each label separately and average labels scores.
This does not take label imbalance into account.
- 'weighted'
Sum true positive, false positive, false negative and true negative pixels over
all images for each label, then compute score for each label separately and average
weighted labels scores.
- 'micro-imagewise'
Sum true positive, false positive, false negative and true negative pixels for **each image**,
then compute score for **each image** and average scores over dataset. All images contribute equally
to final score, however it takes into account class imbalance for each image.
- 'macro-imagewise'
Compute score for each image and for each class on that image separately, then compute average score
on each image over labels and average image scores over dataset. Does not take into account label
imbalance on each image.
- 'weighted-imagewise'
Compute score for each image and for each class on that image separately, then compute weighted average
score on each image over labels and average image scores over dataset.
- 'none' or ``None``
Same as ``'macro-imagewise'``, but without any reduction.
For ``'binary'`` case ``'micro' = 'macro' = 'weighted'`` and
``'micro-imagewise' = 'macro-imagewise' = 'weighted-imagewise'``.
Prefixes ``'micro'``, ``'macro'`` and ``'weighted'`` define how the scores for classes will be aggregated,
while postfix ``'imagewise'`` defines how scores between the images will be aggregated.
class_weights (Optional[List[float]]): list of class weights for metric
aggregation, in case of `weighted*` reduction is chosen. Defaults to None.
zero_division (Union[str, float]): Sets the value to return when there is a zero division,
i.e. when all predictions and labels are negative. If set to ``'warn'``, this acts as 0,
but warnings are also raised. Defaults to 1.
Returns:
torch.Tensor: if ``'reduction'`` is not ``None`` or ``'none'`` returns scalar metric,
else returns tensor of shape (N, C)
References:
https://en.wikipedia.org/wiki/Confusion_matrix
"""
fbeta_score.__doc__ += _doc
f1_score.__doc__ += _doc
iou_score.__doc__ += _doc
accuracy.__doc__ += _doc
sensitivity.__doc__ += _doc
specificity.__doc__ += _doc
balanced_accuracy.__doc__ += _doc
positive_predictive_value.__doc__ += _doc
negative_predictive_value.__doc__ += _doc
false_negative_rate.__doc__ += _doc
false_positive_rate.__doc__ += _doc
false_discovery_rate.__doc__ += _doc
false_omission_rate.__doc__ += _doc
positive_likelihood_ratio.__doc__ += _doc
negative_likelihood_ratio.__doc__ += _doc
precision = positive_predictive_value
recall = sensitivity
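# Illustration only (not part of the library): a minimal plain-Python sketch of how four of the
# reduction modes described in `_doc` differ, using IoU = tp / (tp + fp + fn).
# The helper names `iou` and `reduce_iou` and all counts are hypothetical. The real functions
# operate on torch tensors through `_compute_metric`; the 'weighted' variants are omitted here.

```python
def iou(tp, fp, fn, zero_division=1.0):
    """IoU for one set of counts; falls back to zero_division when the denominator is 0."""
    denom = tp + fp + fn
    return tp / denom if denom else zero_division


def reduce_iou(tp, fp, fn, reduction):
    """tp/fp/fn are N x C nested lists (rows: images, columns: classes)."""
    n_img, n_cls = len(tp), len(tp[0])

    def total(m):
        return sum(sum(row) for row in m)

    def col(m, c):
        return sum(row[c] for row in m)

    if reduction == "micro":
        # Pool counts over all images and classes, then score once.
        return iou(total(tp), total(fp), total(fn))
    if reduction == "macro":
        # Pool counts over images per class, score each class, average classes.
        return sum(iou(col(tp, c), col(fp, c), col(fn, c)) for c in range(n_cls)) / n_cls
    if reduction == "micro-imagewise":
        # Pool counts over classes per image, score each image, average images.
        return sum(iou(sum(tp[i]), sum(fp[i]), sum(fn[i])) for i in range(n_img)) / n_img
    if reduction == "macro-imagewise":
        # Score every (image, class) pair, average classes per image, then average images.
        per_image = [
            sum(iou(tp[i][c], fp[i][c], fn[i][c]) for c in range(n_cls)) / n_cls
            for i in range(n_img)
        ]
        return sum(per_image) / n_img
    raise ValueError(f"unsupported reduction: {reduction!r}")


# Two images, two classes; class 1 is absent from the second image.
tp = [[8, 1], [2, 0]]
fp = [[2, 1], [0, 0]]
fn = [[0, 2], [2, 0]]
modes = ("micro", "macro", "micro-imagewise", "macro-imagewise")
scores = {mode: reduce_iou(tp, fp, fn, mode) for mode in modes}
```

# The same counts give a different score under each mode. The empty (image, class) pair only
# affects 'macro-imagewise', where it is scored with the `zero_division` value.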
================================================
FILE: data_schedule/vis/mapper.py
================================================
import json
import os
from typing import List
import copy
from functools import partial
import random
import numpy as np
import torch
import logging
from einops import rearrange
from detectron2.data import MetadataCatalog
from data_schedule.registry import MAPPER_REGISTRY
from .mapper_utils import VIS_TrainMapper, VIS_EvalMapper
from .vis_frame_sampler import VIS_FRAMES_SAMPLER_REGISTRY
from data_schedule.vis.apis import VIS_Dataset, VIS_Aug_CallbackAPI,\
VIS_TrainAPI_clipped_video, VIS_EvalAPI_clipped_video_request_ann
@MAPPER_REGISTRY.register()
class VIS_Video_EvalMapper(VIS_EvalMapper):
def __init__(self,
configs,
dataset_name,
mode,
meta_idx_shift,
):
assert mode == 'evaluate'
dataset_meta = MetadataCatalog.get(dataset_name)
assert dataset_meta.get('step_size') is None
mapper_config = configs['data'][mode][dataset_name]['mapper']
super().__init__(meta_idx_shift=meta_idx_shift,
dataset_meta=dataset_meta,
mapper_config=mapper_config)
def _call(self, data_dict):
VIS_Dataset
video_id, all_frames = data_dict['video_id'], data_dict['all_frames']
video_frames = self.get_frames_fn(video_id=video_id, frames=all_frames)
aug_ret = {
'video': video_frames,
'callback_fns': []
}
VIS_Aug_CallbackAPI
aug_ret = self.augmentation(aug_ret)
video = aug_ret.pop('video')
callback_fns = aug_ret.pop('callback_fns')[::-1]
VIS_EvalAPI_clipped_video_request_ann
return {
'video_dict': {'video': video},
'meta': {
'video_id': video_id,
'frames': all_frames,
'request_ann': torch.ones(len(all_frames)).bool(),
'callback_fns': callback_fns
}
}
@MAPPER_REGISTRY.register()
class VIS_Video_or_Step_To_Clip_TrainMapper(VIS_TrainMapper):
def __init__(self,
dataset_name,
configs,
mode,
meta_idx_shift,
):
assert mode == 'train'
dataset_meta = MetadataCatalog.get(dataset_name)
assert dataset_meta.get('name') == dataset_name
mapper_config = configs['data'][mode][dataset_name]['mapper']
super().__init__(meta_idx_shift=meta_idx_shift,
dataset_meta=dataset_meta,
mapper_config=mapper_config)
self.frames_sampler = VIS_FRAMES_SAMPLER_REGISTRY.get(\
mapper_config['frames_sampler']['name'])(sampler_configs=mapper_config['frames_sampler'],
dataset_meta=dataset_meta)
def _call(self, data_dict):
VIS_Dataset
video_id, all_frames, all_objs = data_dict['video_id'], data_dict['all_frames'], data_dict['all_objs']
frame_idx = data_dict['frame_idx'] if 'frame_idx' in data_dict else None
all_obj_ids = list(all_objs.keys()) # [1, 2, 5, 4]
assert len(list(set(all_obj_ids))) == len(all_obj_ids)
class_labels = torch.tensor([all_objs[key]['class_label'] for key in all_obj_ids]) # [8, 10, 20 34]
re_sample = True
sampled_counts = 0
while re_sample:
sampled_frames = self.frames_sampler(all_frames=all_frames, frame_idx=frame_idx, video_id=video_id)
# t' h w, int, obj_ids ; has_ann t
frames_mask, has_ann = self.get_frames_mask_fn(video_id=video_id, frames=sampled_frames)
appear_objs = frames_mask.unique() # [0, 1, 2]
assert set(appear_objs.tolist()).issubset(set([0] + all_obj_ids))
re_sample = (len(list(set(appear_objs.tolist()) & set(all_obj_ids))) == 0)
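# It is enough for at least one annotated object to appear in the sampled clip.
sampled_counts += 1
if sampled_counts > 2:
logging.error('sampled too many times')
raise RuntimeError(f'no annotated object found in video {video_id} after {sampled_counts} sampling attempts')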
frames_mask = torch.stack([frames_mask == obj_id for obj_id in all_obj_ids], dim=0) # N t' h w, bool
video_frames = self.get_frames_fn(video_id=video_id, frames=sampled_frames)
width, height = video_frames[0].size
aug_ret = {
'video': video_frames,
'masks': frames_mask, # N t' h w
'has_ann': has_ann, # t
'classes': class_labels, # N
}
VIS_Aug_CallbackAPI
aug_ret = self.augmentation(aug_ret)
video = aug_ret.pop('video')
frame_targets = self.map_to_frame_targets(aug_ret)
if self.clip_global_targets_map_to_local_targets:
aug_ret = self.map_global_targets_to_local_targets(aug_ret)
VIS_TrainAPI_clipped_video
ret = {}
ret['video_dict'] = {'video': video}
ret['targets'] = aug_ret
ret['frame_targets'] = frame_targets
return ret
================================================
FILE: data_schedule/vis/mapper_utils.py
================================================
from .vis_aug_utils import VIS_EVAL_AUG_REGISTRY, VIS_TRAIN_AUG_REGISTRY
import torch
from copy import deepcopy as dcopy
from data_schedule.registry import Mapper
import copy
from data_schedule.vis.apis import VIS_TrainAPI_clipped_video
class VIS_Mapper(Mapper):
def __init__(self,
meta_idx_shift,
dataset_meta,) -> None:
super().__init__(meta_idx_shift=meta_idx_shift, dataset_meta=dataset_meta)
self.get_frames_fn = dataset_meta.get('get_frames_fn')
class VIS_TrainMapper(VIS_Mapper):
def __init__(self,
meta_idx_shift,
dataset_meta,
mapper_config) -> None:
super().__init__(meta_idx_shift, dataset_meta)
self.get_frames_mask_fn = dataset_meta.get('get_frames_mask_fn')
self.clip_global_targets_map_to_local_targets = mapper_config['clip_global_targets_map_to_local_targets']
self.augmentation = VIS_TRAIN_AUG_REGISTRY.get(mapper_config['augmentation']['name'])(mapper_config['augmentation'])
def map_to_frame_targets(self, clip_targets):
VIS_TrainAPI_clipped_video
clip_rets = copy.deepcopy(clip_targets)
masks = clip_rets['masks'].transpose(0, 1).contiguous() # t' N h w
class_labels = clip_rets['classes'] # [10, 32, 10, 4]
has_box = 'boxes' in clip_rets
if has_box:
boxes = clip_rets['boxes'].transpose(0, 1).contiguous() # t' N 4
assert len(masks) == len(boxes)
ret = []
for idx, frame_mk in enumerate(masks):
frame_targets = {
'masks': frame_mk.unsqueeze(1), # N 1 h w
'classes': class_labels, # N
}
if has_box:
frame_targets.update({'boxes': boxes[idx].unsqueeze(1)}) # N 1 4
if self.clip_global_targets_map_to_local_targets:
frame_targets = self.map_global_targets_to_local_targets(frame_targets)
frame_targets['masks'] = frame_targets['masks'].squeeze(1)
if has_box:
frame_targets['boxes'] = frame_targets['boxes'].squeeze(1)
ret.append(frame_targets)
return ret
def map_global_targets_to_local_targets(self, ret):
VIS_TrainAPI_clipped_video
masks = ret['masks'] # N t' h w
global_obj_appear = masks.flatten(1).any(-1) # N [True, False, True, False, False, False, True]
ret['masks'] = ret['masks'][global_obj_appear]
ret['classes'] = ret['classes'][global_obj_appear]
if 'boxes' in ret:
ret['boxes'] = ret['boxes'][global_obj_appear] # n t' 4
return ret
class VIS_EvalMapper(VIS_Mapper):
def __init__(self,
meta_idx_shift,
dataset_meta,
mapper_config) -> None:
super().__init__(meta_idx_shift, dataset_meta)
assert mapper_config['augmentation']['name'] in ['WeakPolyP_EvalAug', 'Visha_EvalAug']
self.augmentation = VIS_EVAL_AUG_REGISTRY.get(mapper_config['augmentation']['name'])(mapper_config['augmentation'])
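# Illustration only: a plain-Python sketch of the filtering done by
# `map_global_targets_to_local_targets`, which drops objects whose mask is empty in every
# frame of the clip. The helper name `filter_absent_objects` and the toy masks are
# hypothetical; the real method indexes torch tensors with a boolean vector.

```python
def filter_absent_objects(masks, classes):
    """masks: N objects x T frames x H x W nested lists of 0/1; classes: N labels."""
    # An object "appears" if any pixel in any frame of the clip is set.
    appears = [
        any(px for frame in obj for row in frame for px in row)
        for obj in masks
    ]
    kept_masks = [m for m, keep in zip(masks, appears) if keep]
    kept_classes = [c for c, keep in zip(classes, appears) if keep]
    return kept_masks, kept_classes


empty = [[0, 0], [0, 0]]
blob = [[0, 1], [0, 0]]
# N=3 objects, T=2 frames, H=W=2. The second object never appears in the clip.
masks = [[blob, empty], [empty, empty], [empty, blob]]
classes = [10, 32, 4]
kept_masks, kept_classes = filter_absent_objects(masks, classes)
```

# Masks and class labels are filtered with the same selection, so they stay aligned.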
================================================
FILE: data_schedule/vis/polyp/__init__.py
================================================
from . import polyp_dataset
from . import evals
================================================
FILE: data_schedule/vis/polyp/evals.py
================================================
from data_schedule.vis.evaluator_utils import register_vis_metric
import os
import torch
import detectron2.utils.comm as comm
import logging
import subprocess
@register_vis_metric
def polyp_metric_aggregator(metrics_by_vid_frame, dataset_meta, eval_meta_keys, **kwargs):
    # metrics_by_vid_frame: {video_id: {frame_name: {metric_name: value}}}
    # eval_meta_keys: {video_id: [frame_name, ...]}
    # output: eval_metrics, {metric_name: mean over all evaluated frames}
    eval_metrics = {}
    # Read the metric names from the first frame of the first video.
    first_video = list(eval_meta_keys.keys())[0]
    metric_names = metrics_by_vid_frame[first_video][eval_meta_keys[first_video][0]]
    # Per-frame metrics, averaged over every frame of every video.
    for metric_name in metric_names:
        eval_metrics[metric_name] = torch.tensor([
            metrics_by_vid_frame[video][frame][metric_name]
            for video in eval_meta_keys.keys() for frame in eval_meta_keys[video]
        ]).mean()
    # Per-video means, sorted in ascending order; these are only logged, not returned.
    mean_iou_by_each_video = {}
    mean_dice_by_each_video = {}
    for video_id in eval_meta_keys:
        mean_iou_by_each_video[video_id] = torch.tensor(
            [metrics_by_vid_frame[video_id][fname]['iou'] for fname in eval_meta_keys[video_id]]).mean()
        mean_dice_by_each_video[video_id] = torch.tensor(
            [metrics_by_vid_frame[video_id][fname]['dice'] for fname in eval_meta_keys[video_id]]).mean()
    mean_iou_by_each_video = dict(sorted(mean_iou_by_each_video.items(), key=lambda x: x[1]))
    mean_dice_by_each_video = dict(sorted(mean_dice_by_each_video.items(), key=lambda x: x[1]))
    logging.debug(f'mean_iou_by_each_video: {mean_iou_by_each_video}')
    logging.debug(f'mean_dice_by_each_video: {mean_dice_by_each_video}')
    return eval_metrics
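# Illustration only: a plain-Python sketch of the aggregation above, without torch. It shows
# that the returned score is a frame-weighted mean over all videos, while the per-video means
# are a separate, sorted diagnostic. The function name `aggregate` and the numbers are
# hypothetical.

```python
def aggregate(metrics_by_vid_frame, eval_meta_keys):
    """Return (overall per-metric means over all frames, per-video mean IoU sorted ascending)."""
    first_video = next(iter(eval_meta_keys))
    metric_names = metrics_by_vid_frame[first_video][eval_meta_keys[first_video][0]]

    overall = {}
    for name in metric_names:
        values = [
            metrics_by_vid_frame[video][frame][name]
            for video in eval_meta_keys
            for frame in eval_meta_keys[video]
        ]
        overall[name] = sum(values) / len(values)

    per_video_iou = {}
    for video, frames in eval_meta_keys.items():
        ious = [metrics_by_vid_frame[video][frame]['iou'] for frame in frames]
        per_video_iou[video] = sum(ious) / len(ious)
    per_video_iou = dict(sorted(per_video_iou.items(), key=lambda kv: kv[1]))
    return overall, per_video_iou


metrics = {
    'a': {'f1': {'iou': 0.5, 'dice': 0.6}, 'f2': {'iou': 0.7, 'dice': 0.8}},
    'b': {'f1': {'iou': 0.9, 'dice': 0.95}},
}
keys = {'a': ['f1', 'f2'], 'b': ['f1']}
overall, per_video_iou = aggregate(metrics, keys)
```

# Video 'a' contributes two frames and video 'b' one, so the overall IoU is pulled towards 'a'
# compared with a plain average of the two per-video means.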
================================================
FILE: data_schedule/vis/polyp/polyp_dataset.py
================================================
from typing import Optional, Union
import json
import os
from functools import partial
import numpy as np
import torch
import logging
from tqdm import tqdm
import copy
from detectron2.data import DatasetCatalog, MetadataCatalog
from collections import defaultdict
from .polyp_utils import get_frames, get_frames_mask, SET_NAME_TO_DIR, SET_NAME, SET_NAME_TO_NUM_VIDEOS, SET_NAME_TO_MODE, SET_NAME_TO_PREFIX
def polyp_train(step_size,
split_dataset_name,
video_ids,
video_to_frames,
root_path):
logging.debug(f'{split_dataset_name} Generating metas...')
metas = []
for vid_id in tqdm(video_ids):
all_frames = sorted(video_to_frames[vid_id])
poly_class = 0
if step_size is None:
metas.append({
'video_id': vid_id,
'all_frames' : all_frames,
'all_objs': { 1: {'class_label': poly_class} },
'meta_idx': len(metas)
})
else:
for frame_idx in range(0, len(all_frames), step_size):
metas.append({
'video_id': vid_id,
'frame_idx': frame_idx,
'all_frames': all_frames,
'all_objs': { 1: {'class_label': poly_class} },
'meta_idx': len(metas)
})
logging.debug(f'{split_dataset_name} Total metas: [{len(metas)}]')
return metas
def polyp_evaluate(eval_video_ids,
split_dataset_name,
step_size,
video_to_frames):
if (step_size is not None) and (step_size > 1):
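logging.error(f'{split_dataset_name}: evaluation expects step_size of None or 1, got {step_size}')
raise ValueError(f'unsupported evaluation step_size: {step_size}')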
metas = []
for video_id in eval_video_ids:
all_frames = sorted(video_to_frames[video_id])
if step_size is None:
metas.append({
'video_id': video_id,
'all_frames': all_frames,
'meta_idx': len(metas)
})
else:
for frame_idx in range(0, len(all_frames), step_size):
metas.append({
'video_id': video_id,
'frame_idx': frame_idx,
'all_frames': all_frames,
'meta_idx': len(metas)
})
logging.debug(f'{split_dataset_name} Total metas: [{len(metas)}]')
return metas
_root = os.getenv('DATASET_PATH')
root = os.path.join(_root, 'SUN/SUN-SEG2')
visualize_meta_idxs = defaultdict(list)
visualize_meta_idxs['polyp_train_step[6]'] = []
visualize_meta_idxs['polyp_train'] = []
visualize_meta_idxs['polyp_hard_unseen'] = []
visualize_meta_idxs['polyp_hard_seen'] = []
visualize_meta_idxs['polyp_easy_unseen'] = []
visualize_meta_idxs['polyp_easy_seen'] = []
polyp_meta = {
'thing_classes': ['polyp', 'not polyp'],
'thing_colors': [(255., 140., 0.), (0., 255., 0.)],
'root': root
}
for name in SET_NAME:
set_dir = SET_NAME_TO_DIR[name]
set_dir = os.path.join(root, set_dir)
num_videos = SET_NAME_TO_NUM_VIDEOS[name]
video_ids = os.listdir(os.path.join(set_dir, 'Frame'))
assert len(video_ids) == num_videos
video_to_frames = {
vid: sorted([fname[:-4] for fname in os.listdir(os.path.join(set_dir, 'Frame', vid)) if fname.endswith('.jpg')])\
for vid in video_ids
}
mode = SET_NAME_TO_MODE[name]
if mode == 'train':
prefix = SET_NAME_TO_PREFIX[name]
train_meta = copy.deepcopy(polyp_meta)
train_meta.update({
'mode': 'train',
'get_frames_fn': partial(get_frames, frames_path=os.path.join(set_dir, 'Frame')),
'get_frames_mask_fn': partial(get_frames_mask, mask_path=os.path.join(set_dir, 'GT'),),
})
# train
for step_size in [1, 3, 6, 9, 12, None]:
step_identifier = '' if step_size is None else f'_step[{step_size}]'
split_name = f'{prefix}{step_identifier}'
train_meta.update({'name': split_name})
DatasetCatalog.register(split_name, partial(polyp_train,
video_ids=video_ids,
split_dataset_name=split_name,
step_size=step_size,
video_to_frames=video_to_frames,
root_path=set_dir))
MetadataCatalog.get(split_name).set(**train_meta,
step_size=step_size,
visualize_meta_idxs=visualize_meta_idxs[split_name])
elif mode == 'evaluate':
prefix = SET_NAME_TO_PREFIX[name]
validate_meta = copy.deepcopy(polyp_meta)
validate_meta.update({
'mode': 'evaluate',
'get_frames_fn': partial(get_frames, frames_path=os.path.join(set_dir, 'Frame')),  # set_dir already includes root
'eval_set_name': SET_NAME_TO_DIR[name],
'get_frames_gt_mask_fn': partial(get_frames_mask, mask_path=os.path.join(set_dir, 'GT'),),
'eval_meta_keys': video_to_frames
})
for step_size in [1, None,]:
step_identifier = '' if step_size is None else f'_step[{step_size}]'
split_name = f'{prefix}{step_identifier}'
validate_meta.update({'name': split_name})
DatasetCatalog.register(split_name, partial(polyp_evaluate,
eval_video_ids=video_ids,
split_dataset_name=split_name,
step_size=step_size,
video_to_frames=video_to_frames))
MetadataCatalog.get(split_name).set(**validate_meta, step_size=step_size,
visualize_meta_idxs=visualize_meta_idxs[split_name])
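# Illustration only: a simplified plain-Python sketch of how `step_size` controls the number of
# metas and the registered split name. With `step_size=None` there is one meta per video;
# otherwise there is one meta per start index in `range(0, len(all_frames), step_size)`.
# `build_train_metas` and `make_split_name` are hypothetical helpers; `all_objs` is left out and
# videos are sorted here only to make the result deterministic.

```python
def build_train_metas(video_to_frames, step_size):
    """One meta per video (step_size=None) or one per stepped start frame."""
    metas = []
    for vid in sorted(video_to_frames):
        frames = sorted(video_to_frames[vid])
        starts = [None] if step_size is None else range(0, len(frames), step_size)
        for frame_idx in starts:
            meta = {'video_id': vid, 'all_frames': frames, 'meta_idx': len(metas)}
            if frame_idx is not None:
                meta['frame_idx'] = frame_idx
            metas.append(meta)
    return metas


def make_split_name(prefix, step_size):
    """Mirror the naming used at registration time, e.g. 'polyp_train_step[6]'."""
    return prefix if step_size is None else f'{prefix}_step[{step_size}]'


videos = {'case1': [f'{i:03d}' for i in range(7)], 'case2': ['000', '001']}
per_video = build_train_metas(videos, None)
stepped = build_train_metas(videos, 3)
```

# A 7-frame video with step 3 yields start indices 0, 3 and 6; a 2-frame video yields only 0.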
================================================
FILE: data_schedule/vis/polyp/polyp_utils.py
================================================
import os
import numpy as np
import torch
from PIL import Image
SET_NAME = ['polyp_train',
'polyp_hard_seen_validate',
'polyp_hard_unseen_validate',
'polyp_easy_seen_validate',
'polyp_easy_unseen_validate',
'polyp_hard_validate',
'polyp_easy_validate',
'Kvasir-train',
'Mayo-train',
'300-train',
'612-train',
'300-tv',
'612-test',
'612-val'
]
SET_NAME_TO_DIR = {
'polyp_train': 'TrainDataset',
'polyp_hard_seen_validate': 'TestHardDataset/Seen',
'polyp_hard_unseen_validate': 'TestHardDataset/Unseen',
'polyp_easy_seen_validate': 'TestEasyDataset/Seen',
'polyp_easy_unseen_validate': 'TestEasyDataset/Unseen',
'polyp_hard_validate': 'TestHardDataset/Combine',
'polyp_easy_validate': 'TestEasyDataset/Combine',
'Kvasir-train': 'MICCAI-VPS-dataset/Kvasir-SEG',
'Mayo-train': 'MICCAI-VPS-dataset/VPS-TrainSet/ASU-Mayo_Clinic/Train',
'300-train': 'MICCAI-VPS-dataset/VPS-TrainSet/CVC-ColonDB-300/Train',
'612-train': 'MICCAI-VPS-dataset/VPS-TrainSet/CVC-ClinicDB-612/Train',
'300-tv': 'MICCAI-VPS-dataset/VPS-TestSet/CVC-ColonDB-300',
'612-test': 'MICCAI-VPS-dataset/VPS-TestSet/CVC-ClinicDB-612-Test',
'612-val': 'MICCAI-VPS-dataset/VPS-TestSet/CVC-ClinicDB-612-Valid'
}
SET_NAME_TO_NUM_VIDEOS = {
'polyp_train': 112,
'polyp_hard_seen_validate': 17,
'polyp_hard_unseen_validate': 37,
'polyp_easy_seen_validate': 33,
'polyp_easy_unseen_validate': 86,
'polyp_hard_validate': 54,
'polyp_easy_validate': 119,
'Kvasir-train': 1,
'Mayo-train': 10,
'300-train': 6,
'612-train': 18,
'300-tv': 6,
'612-test': 5,
'612-val': 5
}
SET_NAME_TO_MODE = {
'polyp_train': 'train',
'polyp_hard_seen_validate': 'evaluate',
'polyp_hard_unseen_validate': 'evaluate',
'polyp_easy_seen_validate': 'evaluate',
'polyp_easy_unseen_validate': 'evaluate',
'polyp_hard_validate': 'evaluate',
'polyp_easy_validate': 'evaluate',
'Kvasir-train': 'train',
'Mayo-train': 'train',
'300-train': 'train',
'612-train': 'train',
'300-tv': 'evaluate',
'612-test': 'evaluate',
'612-val': 'evaluate'
}
SET_NAME_TO_PREFIX = {
'polyp_train': 'polyp_train',
'polyp_hard_seen_validate': 'polyp_hard_seen_validate',
'polyp_hard_unseen_validate': 'polyp_hard_unseen_validate',
'polyp_easy_seen_validate': 'polyp_easy_seen_validate',
'polyp_easy_unseen_validate': 'polyp_easy_unseen_validate',
'polyp_hard_validate': 'polyp_hard_validate',
'polyp_easy_validate': 'polyp_easy_validate',
'Kvasir-train': 'Kvasir-train',
'Mayo-train': 'Mayo-train',
'300-train': '300-train',
'612-train': '612-train',
'300-tv': '300-tv',
'612-test': '612-test',
'612-val': '612-val'
}
CLASS_TO_ID = {
'high_grade_adenoma':0,
'hyperplastic_polyp':1,
'invasive_cancer':2,
'low_grade_adenoma':3,
'sessile_serrated_lesion':4,
'traditional_serrated_adenoma':5
}
def get_frames(frames_path, video_id, frames):
return [Image.open(os.path.join(frames_path, video_id, f'{f}.jpg')).convert('RGB') for f in frames]
def get_frames_mask(mask_path, video_id, frames):
# masks = [cv2.imread(os.path.join(mask_path, video_id, f'{f}.jpg')) for f in frames]
if os.path.exists(os.path.join(mask_path, video_id, f'{frames[0]}.png')):
masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.png')).convert('L') for f in frames]
elif os.path.exists(os.path.join(mask_path, video_id, f'{frames[0]}.jpg')):
masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.jpg')).convert('L') for f in frames]
else:
raise ValueError(f'no .png or .jpg mask found for frame {frames[0]} of video {video_id} under {mask_path}')
masks = [np.array(mk) for mk in masks]
masks = torch.stack([torch.from_numpy(mk) for mk in masks], dim=0) # t h w
# assert set(masks.unique().tolist()) == set([0, 255]), f'{masks.unique().tolist()}'
masks = (masks > 0).int()
return masks, torch.ones(len(frames)).bool()
================================================
FILE: data_schedule/vis/vis_aug_eval.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import random
import torch
import torchvision.transforms.functional as F
from data_schedule.vis.apis import VIS_Aug_CallbackAPI
from .vis_aug_utils import get_tgt_size
from .vis_aug_utils import VIS_EVAL_AUG_REGISTRY
class RandomResize:
def __init__(self, sizes, max_size=None):
assert isinstance(sizes, (list, tuple))
self.sizes = sizes
self.max_size = max_size
def __call__(self, ret):
video = ret['video']
orig_size = video[0].size # w h
tgt_size = get_tgt_size(video[0].size, random.choice(self.sizes), self.max_size) # h w
resized_video = [F.resize(frame, tgt_size) for frame in video]
ratio_width, ratio_height = tuple(float(s) / float(s_orig) for s, s_orig in zip(tgt_size[::-1], orig_size))
ret['video'] = resized_video
if 'callback_fns' in ret:
VIS_Aug_CallbackAPI
ret['callback_fns'].append(RandomResize(sizes=[orig_size], max_size=None))
if "pred_masks" in ret:
assert (len(self.sizes) == 1) and (self.max_size is None)
VIS_Aug_CallbackAPI
pred_masks = ret['pred_masks'] # list[nt h w], t
pred_masks = [torch.nn.functional.interpolate(mk.unsqueeze(0).float(), tgt_size, mode='nearest')[0].bool()
for mk in pred_masks]
ret['pred_masks'] = pred_masks # list[nt h w], t
if "pred_boxes" in ret:
VIS_Aug_CallbackAPI
pred_boxes = ret["pred_boxes"] # list[nt 4], t
scaled_boxes = [bx * (torch.tensor([ratio_width, ratio_height, ratio_width, ratio_height])[None, :])
for bx in pred_boxes]
ret["pred_boxes"] = scaled_boxes
return ret
class VideoToPIL:
def __call__(self, ret):
video = ret['video'] # t 3 h w ->
assert video.dtype == torch.float and (video.max() <= 1) and (video.min() >=0)
pil_video = [F.to_pil_image(frame, mode='RGB') for frame in video] # 3 h w, float, 0-1
ret['video'] = pil_video
assert 'callback_fns' not in ret
return ret
class VideoToTensor:
def __call__(self, ret):
video = ret['video']
tensor_video = torch.stack([F.to_tensor(frame) for frame in video], dim=0) # t 3 h w, float, 0-1
ret['video'] = tensor_video
if 'callback_fns' in ret:
VIS_Aug_CallbackAPI
ret['callback_fns'].append(VideoToPIL())
return ret
@VIS_EVAL_AUG_REGISTRY.register()
class WeakPolyP_EvalAug:
def __init__(self, configs) -> None:
self.resize = RandomResize(
sizes=[[352, 352]],
)
self.tensor_video = VideoToTensor()
def __call__(self, ret):
VIS_Aug_CallbackAPI
ret = self.resize(ret)
ret = self.tensor_video(ret)
return ret
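# Illustration only: a plain-Python sketch of the callback pattern used by the eval
# augmentations. Each transform appends its own inverse to `callback_fns`, and the eval mapper
# reverses that list (`[::-1]`) so the inverses run last-in, first-out on the predictions.
# `Scale` and `Shift` are hypothetical stand-ins for transforms such as `RandomResize` and
# `VideoToTensor`, acting on a single number instead of a video.

```python
class Scale:
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, ret):
        ret['value'] = ret['value'] * self.factor
        if 'callback_fns' in ret:
            # Register the inverse so the result can be mapped back later.
            ret['callback_fns'].append(Scale(1 / self.factor))
        return ret


class Shift:
    def __init__(self, offset):
        self.offset = offset

    def __call__(self, ret):
        ret['value'] = ret['value'] + self.offset
        if 'callback_fns' in ret:
            ret['callback_fns'].append(Shift(-self.offset))
        return ret


# Forward pass: scale, then shift, collecting inverses along the way.
ret = {'value': 3.0, 'callback_fns': []}
for aug in (Scale(2.0), Shift(5.0)):
    ret = aug(ret)
augmented = ret['value']

# Undo: apply the inverses in reverse order of registration.
callbacks = ret.pop('callback_fns')[::-1]
restored = {'value': augmented}
for fn in callbacks:
    restored = fn(restored)
```

# Applying the callbacks in their original order would scale before un-shifting and would not
# recover the input, which is why the list is reversed.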
================================================
FILE: data_schedule/vis/vis_aug_train.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import random
from PIL import Image
import torch
import torchvision.transforms.functional as F
from einops import rearrange
from copy import deepcopy as dcopy
from data_schedule.vis.apis import VIS_Aug_CallbackAPI
import albumentations as A
import numpy as np
from data_schedule.utils.segmentation import bounding_box_from_mask
from .vis_aug_utils import VIS_TRAIN_AUG_REGISTRY, pil_torch_to_numpy, numpy_to_pil_torch
import copy
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
import imgaug
from datetime import datetime
class RandomRotate90:
    def __init__(self) -> None:
        self.album_aug = A.ReplayCompose(
            [A.RandomRotate90(p=0.5)]
        )

    def __call__(self, ret):
        video = ret['video']
        masks = ret['masks']
        has_ann = ret['has_ann']
        # list[PIL], n t' h w ->
        # list[h w 3, 255rgb], t
        # list[list[h w, 01uint8]] t
        video, masks = pil_torch_to_numpy(video=video, masks=masks, has_ann=has_ann)
        # Sample the rotation once on the first frame, then replay it on every frame.
        replay = self.album_aug(image=video[0], masks=[masks[0][0]])['replay']
        auged_video = []
        auged_mask = []
        for vid, mk in zip(video, masks):
            # Use a separate name so the input dict `ret` is not overwritten.
            auged_each_frame = self.album_aug.replay(replay, image=vid, masks=mk)
            auged_video.append(auged_each_frame['image'])
            auged_mask.append(auged_each_frame['masks'])
        auged_video, auged_mask = numpy_to_pil_torch(video=auged_video, masks=auged_mask, has_ann=has_ann)
        ret['video'] = auged_video
        ret['masks'] = auged_mask
        return ret
class ComputeBox:
def __call__(self, ret):
W, H = ret['video'][0].size
N, T = ret['masks'].shape[:2] # n t' h w
boxes = torch.stack([bounding_box_from_mask(mask) for mask in copy.deepcopy(ret['masks']).flatten(0, 1)], dim=0) # Nt' 4
boxes = rearrange(boxes, '(N T) c -> N T c', N=N, T=T)
boxes[:, :, 0::2].clamp_(min=0, max=W)
boxes[:, :, 1::2].clamp_(min=0, max=H)
ret['boxes'] = boxes
return ret
class VideoToTensor:
def __call__(self, ret):
video = ret['video']
tensor_video = torch.stack([F.to_tensor(frame) for frame in video], dim=0) # t 3 h w, float, 0-1
ret['video'] = tensor_video
return ret
class Compose:
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, ret):
for t in self.transforms:
ret = t(ret)
return ret
def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string
@VIS_TRAIN_AUG_REGISTRY.register()
class WeakPolyP_TrainAug:
def __init__(self, configs) -> None:
self.transform = A.ReplayCompose([
A.Resize(352, 352),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.RandomRotate90(p=0.5),
])
self.tensor_video = VideoToTensor()
self.add_box = ComputeBox()
def __call__(self, ret):
VIS_Aug_CallbackAPI
video = ret['video']
masks = ret['masks'] # n t' h w
has_ann = ret['has_ann'] # t
# list[PIL] -> list[h w 3, 0-1float], t
# n t' h w -> list[list[h w, 01uint8]], t; frames without annotation have empty boxes
video, masks = pil_torch_to_numpy(video=video, masks=masks, has_ann=has_ann)
replay = self.transform(image=video[0], masks=[masks[0][0]])['replay']
auged_video = []
auged_mask = []
for vid, mk in zip(video, masks):
auged_each_frame = self.transform.replay(replay, image=vid, masks=mk)
auged_video.append(auged_each_frame['image'])
auged_mask.append(auged_each_frame['masks']) # list[h w, 01uint8]
auged_video, auged_mask = numpy_to_pil_torch(video=auged_video, masks=auged_mask, has_ann=has_ann)
ret['video'] = auged_video
ret['masks'] = auged_mask
ret = self.add_box(ret)
ret = self.tensor_video(ret)
return ret
@VIS_TRAIN_AUG_REGISTRY.register()
class WeakPolyP_TrainAug_RotateImageToClip:
def __init__(self, configs) -> None:
self.ImageToSeqAugmenter = ImageToSeqAugmenter(perspective=True, affine=True, motion_blur=True,
rotation_range=(-20, 20), perspective_magnitude=0.08,
hue_saturation_range=(-5, 5), brightness_range=(-40, 40),
motion_blur_prob=0.25, motion_blur_kernel_sizes=(9, 11),
translate_range=(-0.1, 0.1))
self.num_frames = configs['num_frames']
self.transform = A.ReplayCompose([
A.Resize(352, 352),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.RandomRotate90(p=0.5),
])
self.tensor_video = VideoToTensor()
self.add_box = ComputeBox()
def apply_random_sequence_shuffle(self, images, instance_masks):
perm = list(range(self.num_frames))
random.shuffle(perm)
images = [images[i] for i in perm]
instance_masks = [instance_masks[i] for i in perm]
return images, instance_masks
def __call__(self, ret):
VIS_Aug_CallbackAPI
video = ret['video'] # list[pil], t
masks = ret['masks'] # n t' h w
has_ann = ret['has_ann'] # t
# list[PIL] -> list[h w 3, uint8], t
# n t' h w -> list[list[h w], n, uint8], t
seq_images, seq_instance_masks = pil_torch_to_numpy(video=video, masks=masks, has_ann=has_ann, float_image=False)
assert len(seq_images) == 1 and len(seq_instance_masks) == 1
static_img, static_mask = seq_images[0], seq_instance_masks[0]
for t in range(self.num_frames - 1):
im_trafo, instance_masks_trafo = self.ImageToSeqAugmenter(static_img, static_mask) # h w 3, uint8; list[h w], n, uint8
seq_images.append(np.uint8(im_trafo))
seq_instance_masks.append(instance_masks_trafo)
# list[h w 3], t ; # list[list[h w, 01uint8]] t
seq_images, seq_instance_masks = self.apply_random_sequence_shuffle(seq_images, seq_instance_masks)
has_ann = torch.ones(self.num_frames).bool() # T
seq_images = [np.float32(img) / 255.0 for img in seq_images] # list[h w 3, 0-1float], t
replay = self.transform(image=seq_images[0], masks=[seq_instance_masks[0][0]])['replay']
auged_video = []
auged_mask = []
for vid, mk in zip(seq_images, seq_instance_masks):
auged_each_frame = self.transform.replay(replay, image=vid, masks=mk)
auged_video.append(auged_each_frame['image'])
auged_mask.append(auged_each_frame['masks']) # list[h w, 01uint8]
auged_video, auged_mask = numpy_to_pil_torch(video=auged_video, masks=auged_mask, has_ann=has_ann) # n t h w
# [haosen.save(f'./test{idx}.png') for idx, haosen in enumerate(auged_video)]
# import matplotlib.pyplot as plt
# [plt.imsave( f'./mask{idx}.png', auged_mask[0][idx].float().numpy()) for idx in range(len(auged_mask[0]))]
ret['video'] = auged_video
ret['masks'] = auged_mask
ret['has_ann'] = has_ann
ret = self.add_box(ret)
ret = self.tensor_video(ret)
return ret
class ImageToSeqAugmenter(object):
def __init__(self, perspective=True, affine=True, motion_blur=True,
brightness_range=(-50, 50), hue_saturation_range=(-15, 15), perspective_magnitude=0.12,
scale_range=1.0, translate_range={"x": (-0.15, 0.15), "y": (-0.15, 0.15)}, rotation_range=(-20, 20),
motion_blur_kernel_sizes=(7, 9), motion_blur_prob=0.5, seed=2024):
self.basic_augmenter = iaa.SomeOf((1, None), [
iaa.Add(brightness_range),
iaa.AddToHueAndSaturation(hue_saturation_range)
]
)
transforms = []
if perspective:
transforms.append(iaa.PerspectiveTransform(perspective_magnitude))
if affine:
transforms.append(iaa.Affine(scale=scale_range,
translate_percent=translate_range,
rotate=rotation_range,
order=1, # cv2.INTER_LINEAR
backend='auto'))
transforms = iaa.Sequential(transforms)
transforms = [transforms]
if motion_blur:
blur = iaa.Sometimes(motion_blur_prob, iaa.OneOf(
[
iaa.MotionBlur(ksize)
for ksize in motion_blur_kernel_sizes
]
))
transforms.append(blur)
self.frame_shift_augmenter = iaa.Sequential(transforms)
self.seed = seed
@staticmethod
def condense_masks(instance_masks):
condensed_mask = np.zeros_like(instance_masks[0], dtype=np.int8)
for instance_id, mask in enumerate(instance_masks, 1):
condensed_mask = np.where(mask, instance_id, condensed_mask)
return condensed_mask
@staticmethod
def expand_masks(condensed_mask, num_instances):
return [(condensed_mask == instance_id).astype(np.uint8) for instance_id in range(1, num_instances + 1)]
def __call__(self, image, masks=None, boxes=None): # n h w
det_augmenter = self.frame_shift_augmenter.to_deterministic()
if masks is not None:
masks_np, is_binary_mask = [], []
boxs_np = []
for mask in masks:
if isinstance(mask, np.ndarray):
masks_np.append(mask.astype(np.bool_))
is_binary_mask.append(False)
else:
raise ValueError("Invalid mask type: {}".format(type(mask)))
num_instances = len(masks_np)
masks_np = SegmentationMapsOnImage(self.condense_masks(masks_np), shape=image.shape[:2])
# boxs_np = BoundingBoxesOnImage(boxs_np, shape=image.shape[:2])
seed = int(datetime.now().strftime('%M%S%f')[-8:])
imgaug.seed(seed)
aug_image, aug_masks = det_augmenter(image=self.basic_augmenter(image=image) , segmentation_maps=masks_np)
imgaug.seed(seed)
invalid_pts_mask = det_augmenter(image=np.ones(image.shape[:2] + (1,), np.uint8)).squeeze(2)
aug_masks = self.expand_masks(aug_masks.get_arr(), num_instances)
# aug_boxes = aug_boxes.remove_out_of_image().clip_out_of_image()
aug_masks = [mask for mask, is_bm in zip(aug_masks, is_binary_mask)]
return aug_image, aug_masks #, aug_boxes.to_xyxy_array()
else:
masks = [SegmentationMapsOnImage(np.ones(image.shape[:2], np.bool_), shape=image.shape[:2])]  # np.bool was removed in NumPy 1.24
aug_image, invalid_pts_mask = det_augmenter(image=image, segmentation_maps=masks)
return aug_image, invalid_pts_mask.get_arr() == 0
================================================
FILE: data_schedule/vis/vis_aug_utils.py
================================================
from detectron2.utils.registry import Registry
import torch
import numpy as np
import torchvision.transforms.functional as F
from PIL import Image
from einops import rearrange, reduce, repeat
VIS_EVAL_AUG_REGISTRY = Registry('VIS_EVAL_AUG')
VIS_TRAIN_AUG_REGISTRY = Registry('VIS_TRAIN_AUG')
def get_size_with_aspect_ratio(image_size, size, max_size=None):
w, h = image_size
if max_size is not None:
min_original_size = float(min((w, h)))
max_original_size = float(max((w, h)))
if max_original_size / min_original_size * size > max_size:
size = int(round(max_size * min_original_size / max_original_size))
if (w <= h and w == size) or (h <= w and h == size):
return (h, w)
if w < h:
ow = size
oh = int(size * h / w)
else:
oh = size
ow = int(size * w / h)
return (oh, ow)
def get_tgt_size(image_size, size, max_size=None):
if isinstance(size, (list, tuple)):
return size[::-1]
else:
return get_size_with_aspect_ratio(image_size, size, max_size)
def pil_torch_to_numpy(video, masks, has_ann, float_image=True):
# n t' h w
# list[pil_image, rgb], t
# t
N, T = masks.shape[:2]
has_ann_idx = torch.where(has_ann)[0] # time_idx
# list[Image], t -> list[h w 3, 255uint8], t
masks = masks.permute(1, 0, 2, 3).contiguous().unbind(0) # list[n h w] t'
numpy_masks = [[] for _ in range(len(has_ann))] # list[list[h w, 01_uint8], n], t; one independent list per frame
assert len(has_ann_idx) == len(masks)
for fmask, taylor in zip(masks, has_ann_idx): # n h w
fnumpy_masks = []
for mk in fmask.unbind(0): # h w
fnumpy_masks.append(mk.numpy().astype(np.uint8))
numpy_masks[taylor] = fnumpy_masks
if float_image:
# list[h w 3, 0-1float], t
video = [F.to_tensor(frame).permute(1,2,0).numpy() for frame in video]
else:
# uint8
video = [np.array(frame) for frame in video]
return video, numpy_masks
def numpy_to_pil_torch(video, masks, has_ann):
# numpy, numpy -> torch, torch
# list[h w 3, 0-1float], t
H, W = video[0].shape[:2]
T = int(has_ann.int().sum()) # plain int, used as an axis length in rearrange below
video = [Image.fromarray(np.uint8(aug_vid * 255), mode="RGB") for aug_vid in video]
# t'n h w
torch_masks = torch.stack([torch.from_numpy(obj_mk).bool() for frame_mk in masks for obj_mk in frame_mk], dim=0)
torch_masks = rearrange(torch_masks, '(T N) h w -> N T h w', T=T) # n t' h w
return video, torch_masks
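# Illustrative sketch (not part of the original module): get_size_with_aspect_ratio above
# resizes so that the shorter side equals `size`, and shrinks `size` first when the longer
# side would otherwise exceed `max_size`. The standalone copy below restores the indentation
# and adds two worked calls. The name `size_with_aspect_ratio` is an assumption for this example.

```python
def size_with_aspect_ratio(image_size, size, max_size=None):
    """Return (out_h, out_w) so that the shorter side equals `size`, capped by `max_size`."""
    w, h = image_size
    if max_size is not None:
        short_side, long_side = float(min(w, h)), float(max(w, h))
        # Shrink the target if the longer side would exceed max_size.
        if long_side / short_side * size > max_size:
            size = int(round(max_size * short_side / long_side))
    # Nothing to do when the shorter side already matches the target.
    if (w <= h and w == size) or (h <= w and h == size):
        return (h, w)
    if w < h:
        return (int(size * h / w), size)
    return (size, int(size * w / h))


print(size_with_aspect_ratio((640, 480), 320))                # (320, 426)
print(size_with_aspect_ratio((640, 480), 320, max_size=400))  # (300, 400)
```

# Note that the input is (w, h) while the result is (h, w), which matches the `size[::-1]`
# flip in get_tgt_size.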
================================================
FILE: data_schedule/vis/vis_frame_sampler.py
================================================
from detectron2.utils.registry import Registry
import random
import numpy as np
import torch
import logging
from detectron2.utils import comm
VIS_FRAMES_SAMPLER_REGISTRY = Registry('VIS_FRAMES_SAMPLER')
@VIS_FRAMES_SAMPLER_REGISTRY.register()
class Naive_ReferenceFrame_FrameSampler:
def __init__(self, sampler_configs, dataset_meta, **kwargs):
self.reference_frame_step_size = dataset_meta.get('step_size')
self.clip_sizes = list(sampler_configs['clip_sizes']) # list[int]
self.clip_distribute = sampler_configs['clip_distribute'] # dense, sparse, local_global
self.clip_position = sampler_configs['clip_position'] # former, center, latter
if max(self.clip_sizes) > self.reference_frame_step_size:
if comm.is_main_process():
logging.warning('the largest clip size exceeds the reference frame step size')
def __call__(self,
frame_idx=None,
all_frames=None, # list[str]
**kwargs):
random_clip_size = random.choice(self.clip_sizes)
video_len = len(all_frames)
sample_indx = [frame_idx]
if (self.clip_position == 'center') and (self.clip_distribute == 'local_global'):
if random_clip_size != 1:
sample_id_before = random.randint(1, 3)
sample_id_after = random.randint(1, 3)
local_indx = [max(0, frame_idx - sample_id_before), min(video_len - 1, frame_idx + sample_id_after)]
sample_indx.extend(local_indx)
if random_clip_size > 3:
all_inds = list(range(video_len))
global_inds = all_inds[:min(sample_indx)] + all_inds[max(sample_indx):]
global_n = random_clip_size - len(sample_indx)
if len(global_inds) > global_n:
select_id = random.sample(range(len(global_inds)), global_n)
for s_id in select_id:
sample_indx.append(global_inds[s_id])
elif video_len >= global_n:
select_id = random.sample(range(video_len), global_n)
for s_id in select_id:
sample_indx.append(all_inds[s_id])
else:
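# sample with replacement: global_n - video_len may exceed video_len, which random.sample rejects
select_id = random.choices(range(video_len), k=global_n - video_len) + list(range(video_len))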
for s_id in select_id:
sample_indx.append(all_inds[s_id])
elif (self.clip_position == 'center') and (self.clip_distribute == 'dense'):
half_size = (random_clip_size - 1) // 2
sample_indx += list(range(frame_idx - half_size, frame_idx))
sample_indx += list(range(frame_idx+1, half_size + frame_idx + 1))
if len(sample_indx) < random_clip_size:
sample_indx = [min(sample_indx)] + sample_indx
assert len(sample_indx) == random_clip_size
sample_indx = torch.tensor(sample_indx)
sample_indx = sample_indx.clamp_(min=0, max=video_len-1)
sample_indx = sample_indx.tolist()
else:
raise ValueError()
sample_indx.sort()
sampled_frames = [all_frames[idx] for idx in sample_indx]
return sampled_frames
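# Illustrative sketch (not part of the original module): the 'center' + 'dense' branch above
# takes a window of consecutive frames around the reference frame and clamps it to the video
# bounds. The pure-Python version below mirrors that branch without torch. The function name
# `dense_center_clip` is an assumption for this example.

```python
def dense_center_clip(frame_idx, clip_size, video_len):
    """Return sorted frame indices of a dense window centred on frame_idx."""
    half = (clip_size - 1) // 2
    idx = [frame_idx]
    idx += list(range(frame_idx - half, frame_idx))          # frames before
    idx += list(range(frame_idx + 1, frame_idx + half + 1))  # frames after
    if len(idx) < clip_size:
        # Even clip sizes leave one slot open; as in the sampler above,
        # it is filled by repeating the earliest index.
        idx = [min(idx)] + idx
    assert len(idx) == clip_size
    # Clamp to [0, video_len - 1]; frames near the borders are repeated.
    return sorted(min(max(i, 0), video_len - 1) for i in idx)


print(dense_center_clip(5, 5, 100))  # [3, 4, 5, 6, 7]
print(dense_center_clip(0, 5, 100))  # [0, 0, 0, 1, 2]
print(dense_center_clip(5, 4, 100))  # [4, 4, 5, 6]
```

# Clamping keeps the clip length fixed, so clips at the start or end of a video contain
# duplicated border frames rather than fewer frames.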
================================================
FILE: handle_vps.py
================================================
import cv2
import numpy as np
import os
import shutil
from PIL import Image
import torch
from tqdm import tqdm
dataset_root = os.getenv('DATASET_PATH')
# the original IVPS is the union of Kvasir and per-frame Mayo/CVC
all_images = os.listdir(f'{dataset_root}/MICCAI-VPS-dataset/IVPS-TrainSet/Frame')
ka_images = [b for b in all_images if b.startswith('K')]
assert len(ka_images) == 1000
all_gts = os.listdir(f'{dataset_root}/MICCAI-VPS-dataset/IVPS-TrainSet/GT')
ka_gts = [b for b in all_gts if b.startswith('K')]
assert len(ka_gts) == 1000
os.makedirs(os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/Kvasir-SEG/Frame/1'),exist_ok=True)
os.makedirs(os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/Kvasir-SEG/GT/1'),exist_ok=True)
for image_id in tqdm(ka_images):
shutil.copy(os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/IVPS-TrainSet/Frame', f'{image_id}'),
os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/Kvasir-SEG/Frame/1', f'{image_id}'),)
for image_id in tqdm(ka_gts):
shutil.copy(os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/IVPS-TrainSet/GT', f'{image_id}'),
os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/Kvasir-SEG/GT/1', f'{image_id}'),)
# normalize train directory
for base_path in [f'{dataset_root}/MICCAI-VPS-dataset/VPS-TrainSet/CVC-ColonDB-300/Train',
f'{dataset_root}/MICCAI-VPS-dataset/VPS-TrainSet/ASU-Mayo_Clinic/Train',
f'{dataset_root}/MICCAI-VPS-dataset/VPS-TrainSet/CVC-ClinicDB-612/Train']:
video_ids = os.listdir(base_path)
frame_path = os.path.join(base_path, 'Frame')
gt_path = os.path.join(base_path, 'GT')
os.makedirs(frame_path, exist_ok=True)
os.makedirs(gt_path, exist_ok=True)
# Iterate through each video ID directory
for vid in video_ids:
shutil.copytree(os.path.join(base_path, vid, 'Frame'), os.path.join(frame_path, vid))
shutil.copytree(os.path.join(base_path, vid, 'GT'), os.path.join(gt_path, vid))
# WARNING: destructive step. It deletes files in place; skip it if you want to keep frames without masks.
# remove the frames of each training set whose mask has no foreground
SET_NAME = [
'Kvasir-train',
'Mayo-train',
'300-train',
'612-train',
]
SET_NAME_TO_DIR = {
'Kvasir-train': 'MICCAI-VPS-dataset/Kvasir-SEG',
'Mayo-train': 'MICCAI-VPS-dataset/VPS-TrainSet/ASU-Mayo_Clinic/Train',
'300-train': 'MICCAI-VPS-dataset/VPS-TrainSet/CVC-ColonDB-300/Train',
'612-train': 'MICCAI-VPS-dataset/VPS-TrainSet/CVC-ClinicDB-612/Train',
}
SET_NAME_TO_NUM_VIDEOS = {
'Kvasir-train': 1,
'Mayo-train': 10,
'300-train': 6,
'612-train': 18,
'300-tv': 6,
'612-test': 5,
'612-val': 5
}
SET_NAME_TO_PREFIX = {
'Kvasir-train': 'Kvasir-train',
'Mayo-train': 'Mayo-train',
'300-train': '300-train',
'612-train': '612-train',
}
root = os.getenv('DATASET_PATH')
def get_frames_mask(mask_path, video_id, frames):
# masks = [cv2.imread(os.path.join(mask_path, video_id, f'{f}.jpg')) for f in frames]
if os.path.exists(os.path.join(mask_path, video_id, f'{frames[0]}.png')):
masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.png')).convert('L') for f in frames]
elif os.path.exists(os.path.join(mask_path, video_id, f'{frames[0]}.jpg')):
masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.jpg')).convert('L') for f in frames]
else:
raise ValueError()
masks = [np.array(mk) for mk in masks]
masks = torch.stack([torch.from_numpy(mk) for mk in masks], dim=0) # t h w
# assert set(masks.unique().tolist()) == set([0, 255]), f'{masks.unique().tolist()}'
masks = (masks > 0).int()
return masks, torch.ones(len(frames)).bool()
num_delted_frames = 0
for train_set_name in SET_NAME:
set_dir = SET_NAME_TO_DIR[train_set_name]
frames_dir = os.path.join(root, set_dir, 'Frame')
mask_dir = os.path.join(root, set_dir, 'GT')
video_ids = os.listdir(frames_dir)
for vid in tqdm(video_ids):
frames = [haosen[:-4] for haosen in os.listdir(os.path.join(frames_dir, vid))]
frame_has_fore = [get_frames_mask(mask_dir, vid, [haosen])[0].any() for haosen in tqdm(frames)] # list[t]
assert len(frame_has_fore) == len(frames)
num_delted_frames += (~ torch.tensor(frame_has_fore)).int().sum()
for haosen, frame_name in tqdm(zip(frame_has_fore, frames)):
if not haosen:
os.remove(os.path.join(frames_dir, vid, f'{frame_name}.jpg'))
if os.path.exists(os.path.join(mask_dir, vid, f'{frame_name}.jpg')):
os.remove(os.path.join(mask_dir, vid, f'{frame_name}.jpg'))
elif os.path.exists(os.path.join(mask_dir, vid, f'{frame_name}.png')):
os.remove(os.path.join(mask_dir, vid, f'{frame_name}.png'))
else:
raise ValueError()
print(f'deleted {num_delted_frames} frames (expected 1546).') # should be 1546
================================================
FILE: main.py
================================================
import os
import argparse
import logging
import importlib
from trainers import task_to_trainer
import detectron2.utils.comm as comm
from termcolor import colored
import yaml
import torch
from utils.misc import setup_for_distributed
def _highlight(code, filename):
try:
import pygments
except ImportError:
return code
from pygments.lexers import Python3Lexer, YamlLexer
from pygments.formatters import Terminal256Formatter
lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
return code
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
self._root_name = kwargs.pop("root_name") + "."
self._abbrev_name = kwargs.pop("abbrev_name", "")
if len(self._abbrev_name):
self._abbrev_name = self._abbrev_name + "."
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
record.name = record.name.replace(self._root_name, self._abbrev_name)
message = record.message
# message, asctime, name, filename = record.message, record.asctime, record.name, record.filename
log = super(_ColorfulFormatter, self).formatMessage(record)
if (record.levelno == logging.WARNING) or (record.levelno == logging.ERROR) or (record.levelno == logging.CRITICAL):
colored_message = colored(message, "red", attrs=["blink", "underline"])
elif record.levelno == logging.DEBUG:
colored_message = colored(message, "yellow", attrs=["blink", "underline"])
else: # INFO/NOTSET
colored_message = colored(message, "white")
return log + colored_message
def set_logging_file(output_dir, file_name, mode='a'):
handler1 = logging.StreamHandler()
handler2 = logging.FileHandler(os.path.join(output_dir, file_name), mode=mode)
formatter = _ColorfulFormatter(
colored("[%(asctime)s %(name)s %(filename)s]: ", "green"),
datefmt="%m/%d %H:%M:%S",
root_name=os.path.join(output_dir, file_name),
abbrev_name=str('grey'),
)
handler1.setFormatter(formatter)
handler2.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(handler1)
logger.addHandler(handler2)
logger.setLevel(logging.DEBUG)
def init_process_group_and_set_device(world_size, process_id, device_id):
"""
This function needs to be called on each spawned process to initiate learning using DistributedDataParallel.
The function initiates the process' process group and assigns it a single GPU to use during training.
"""
torch.cuda.set_device(device_id)
device = torch.device(f'cuda:{device_id}')
if world_size > 1:
torch.distributed.init_process_group(
torch.distributed.Backend.NCCL,
world_size=world_size,
rank=process_id
)
comm.create_local_process_group(world_size)
torch.distributed.barrier(device_ids=[device_id])
setup_for_distributed(process_id == 0)
return device
def run(rank, configs, world_size):
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['PYDEVD_WARN_SLOW_RESOLVE_TIMEOUT'] = "4"
os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = "1"
os.environ["DGLBACKEND"] = "pytorch"
logging.getLogger('penman').setLevel(logging.WARNING)
logging.getLogger('PIL').setLevel(logging.WARNING)
logging.getLogger('PIL.PngImagePlugin').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('urllib3').setLevel(logging.WARNING)
logging.getLogger('h5py').setLevel(logging.WARNING)
init_process_group_and_set_device(world_size, process_id=rank, device_id=rank)
if comm.is_main_process():
mode = configs['trainer_mode']
out_dir = configs['out_dir']
if mode == 'eval':
num_of_eval_times = len([eval_txt for eval_txt in os.listdir(out_dir) if eval_txt.endswith('eval.txt')])
set_logging_file(out_dir, f"eval.txt", mode='a')
path = os.path.join(out_dir, f"config_eval.yaml")
else:
num_of_train_times = len([train_txt for train_txt in os.listdir(out_dir) if train_txt.endswith('train.txt')])
if 'resume' in mode:
set_logging_file(out_dir, f"train.txt", mode='a')
else:
set_logging_file(out_dir, f"train.txt", mode='w')
path = os.path.join(out_dir, f"config_train.yaml")
logging.debug("Running with full config:\n{}".format(_highlight(yaml.dump(configs, default_flow_style=False), ".yaml")))
with open(path, "w") as f:
f.write(yaml.dump(configs, default_flow_style=False))
logging.debug("Full config saved to {}".format(path))
comm.synchronize()
trainer = task_to_trainer[configs['task']](configs=configs)
comm.synchronize()
if configs['trainer_mode'] == 'eval':
eval_ckpts = configs['eval_ckpts']
for lunch in eval_ckpts:
trainer.load_ckpt(lunch, load_model=True, load_schedule=True, load_random=False, load_optimize=False)
trainer.evaluate()
else:
if configs['trainer_mode'] == 'train_resume':
ckpt_dirs = os.listdir(configs['out_dir'])
ckpt_dirs = sorted([a for a in ckpt_dirs if a.startswith('epc')], key=lambda x:int(x.split('sap[')[-1][:-1]))
trainer_ckpt = '/'.join([configs['out_dir'], ckpt_dirs[-1], 'ckpt.pth.tar'])
trainer.load_ckpt(trainer_ckpt, load_model=True, load_schedule=True, load_random=True, load_optimize=True)
trainer.train()
if __name__=="__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--config_file', type=str, required=True)
parser.add_argument('--trainer_mode', type=str, default='train_attmpt')
parser.add_argument('--eval_path', type=str, default='')
args = parser.parse_args()
task, group, config, config2 = args.config_file.split('/')[-4:]
assert config == config2[:-3]
config_file = '.'.join(['output', task, group, config, config])
configs = importlib.import_module(config_file).trainer_configs
configs['task'], configs['group'], configs['config'] = task, group, config
configs['out_dir'] = os.path.join('./', 'output', task, group, config)
configs['trainer_mode'] = args.trainer_mode
if configs['trainer_mode'] == 'eval':
eval_ckpts = []
eval_path = args.eval_path
assert eval_path != '', '--eval_path must be set when --trainer_mode is eval'
if os.path.isfile(eval_path):
eval_ckpts.append(eval_path)
elif os.path.isdir(eval_path):
ckpt_dirs = os.listdir(eval_path)
ckpt_dirs = [taylor for taylor in ckpt_dirs if os.path.isdir(os.path.join(eval_path, taylor))]
# epc[1]_iter[5000]_sap[60009]
ckpt_dirs = sorted([billie for billie in ckpt_dirs if billie.startswith('epc')], key=lambda x:int(x.split('sap[')[-1][:-1]))
eval_ckpts = [os.path.join(eval_path, cd, f'ckpt.pth.tar') for cd in ckpt_dirs]
eval_ckpts = [eval_c for eval_c in eval_ckpts if os.path.exists(eval_c)]
else:
raise ValueError()
configs['eval_ckpts'] = eval_ckpts
else:
pass
gpu_ids = list(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
assert len(set(gpu_ids)) == len(gpu_ids)
gpu_ids = list(range(len(gpu_ids)))
if len(gpu_ids) > 1:
torch.multiprocessing.spawn(run, nprocs=len(gpu_ids), args=(configs, len(gpu_ids)))
elif len(gpu_ids) == 1:
run(rank=0, configs=configs, world_size=len(gpu_ids))
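# Illustrative sketch (not part of the original script): both the resume path and the eval
# path above order checkpoint directories by the number inside `sap[...]`, following the
# `epc[1]_iter[5000]_sap[60009]` naming noted in the comment. The helper names below are
# assumptions for this example, and the directory names are made up.

```python
def sample_count(ckpt_dir_name):
    """'epc[1]_iter[5000]_sap[60009]' -> 60009"""
    return int(ckpt_dir_name.split('sap[')[-1][:-1])


def order_checkpoints(dir_names):
    """Keep only 'epc*' directories and sort them by their sample count."""
    return sorted((d for d in dir_names if d.startswith('epc')), key=sample_count)


names = [
    'epc[10]_iter[50000]_sap[600090]',
    'epc[2]_iter[10000]_sap[120018]',
    'logs',
    'epc[1]_iter[5000]_sap[60009]',
]
for name in order_checkpoints(names):
    print(name)
```

# A numeric key is needed here because a plain string sort would not keep epoch 10 after
# epoch 2; the last element of the ordered list is the one `train_resume` loads.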
================================================
FILE: models/VIS/BackboneEncoderDecoder_WithScaleConsistency.py
================================================
import matplotlib.pyplot as plt
from typing import Any, Optional, List, Dict, Set, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from data_schedule import build_schedule
from torch import Tensor
from einops import repeat, rearrange, reduce
from functools import partial
from einops.layers.torch import Rearrange
from torch import einsum
import numpy as np
import logging
from data_schedule.vis.apis import VIS_TrainAPI_clipped_video, VIS_Aug_CallbackAPI
from data_schedule.vis.apis import VIS_EvalAPI_clipped_video_request_ann
import torchvision.transforms.functional as Trans_F
import copy
from models.registry import register_model
from models.optimization.optimizer import get_optimizer
from models.optimization.scheduler import build_scheduler
from models.backbone.utils import VideoMultiscale_Shape
from detectron2.modeling import BACKBONE_REGISTRY, META_ARCH_REGISTRY
class BackboneEncoderDecoder_WithScaleConsistency(nn.Module):
def __init__(
self,
configs,
pixel_mean = [0.485, 0.456, 0.406],
pixel_std = [0.229, 0.224, 0.225],):
super().__init__()
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) # 3 1 1
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
self.loss_weight = configs['model']['loss_weight']
video_backbone_configs = configs['model']['video_backbone']
video_backbone_cls = BACKBONE_REGISTRY.get(video_backbone_configs['name'])
self.video_backbone = video_backbone_cls(video_backbone_configs)
self.max_stride = self.video_backbone.max_stride
self.fusion_encoder = META_ARCH_REGISTRY.get(configs['model']['fusion']['name'])(configs['model']['fusion'],
multiscale_shapes=self.video_backbone.multiscale_shapes)
same_dim_multiscale_shapes = VideoMultiscale_Shape.set_multiscale_same_dim(shape_by_dim=self.video_backbone.multiscale_shapes,
same_dim=configs['model']['fusion']['d_model'])
self.decoder = META_ARCH_REGISTRY.get(configs['model']['decoder']['name'])(configs['model']['decoder'],
multiscale_shapes=same_dim_multiscale_shapes)
if configs['model']['fusion']['name'] == 'Video_Deform2D_DividedTemporal_MultiscaleEncoder_v2':
self.fusion_encoder.hack_ref(query_norm=self.decoder.temporal_query_norm, mask_mlp=self.decoder.query_mask)
self.test_clip_size = configs['model']['test_clip_size']
@property
def device(self):
return self.pixel_mean.device
def model_preds(self, videos, video_aux_dict,):
if (not self.training) and (self.test_clip_size is not None):
nf = videos.shape[2]
clip_outputs = [] # list[dict]
for start_idx in range(0, nf, self.test_clip_size):
multiscales = self.video_backbone(x=videos[:, :, start_idx:(start_idx + self.test_clip_size)]) # b c t h w
multiscales = self.fusion_encoder(multiscales, video_aux_dict=video_aux_dict)
clip_outputs.append(self.decoder(multiscales, video_aux_dict=video_aux_dict)[-1]) # b t nq h w
return [{
'pred_masks': torch.cat([haosen['pred_masks'] for haosen in clip_outputs], dim=1), # b t n h w
'pred_class': torch.cat([haosen['pred_class'] for haosen in clip_outputs], dim=1),
}]
# b 3 t h w -> b 3 t h w
multiscales = self.video_backbone(x=videos) # b c t h w
multiscales = self.fusion_encoder(multiscales, video_aux_dict=video_aux_dict)
return self.decoder(multiscales, video_aux_dict=video_aux_dict)
def forward(self, batch_dict):
assert self.training
VIS_TrainAPI_clipped_video
videos = batch_dict['video_dict']['videos']
targets = batch_dict['targets']
batch_size, nf = videos.shape[:2]
videos = (videos - self.pixel_mean) / self.pixel_std
size1 = np.random.choice([256, 288, 320, 352, 384, 416, 448])
vid_1 = F.interpolate(videos.flatten(0, 1), size=size1, mode='bilinear')
vid_1 = rearrange(vid_1, '(b T) c h w -> b c T h w',b=batch_size, T=nf)
pred1 = self.model_preds(vid_1, video_aux_dict=batch_dict['video_dict']) # {pred_masks: b 1 t h w}
pred1_loss = self.decoder.compute_loss(pred1, targets=targets, frame_targets=batch_dict['frame_targets'],
video_aux_dict=batch_dict['video_dict'])
loss_value_dict = {key: pred1_loss[key] for key in list(self.loss_weight.keys())}
return loss_value_dict, self.loss_weight
@torch.no_grad()
def sample(self, batch_dict):
assert not self.training
VIS_EvalAPI_clipped_video_request_ann
videos = batch_dict['video_dict']['videos'] # b t 3 h w, 0-1
orig_t, _, orig_h, orig_w = batch_dict['video_dict']['orig_sizes'][0]
videos = (videos - self.pixel_mean) / self.pixel_std
assert videos.shape[0] == 1
batch_size, T, _, H, W = videos.shape
videos = videos.permute(0, 2, 1,3,4) # b c t h w
decoder_output = self.model_preds(videos, video_aux_dict=batch_dict['video_dict']) # {pred_masks: b 1 t h w}
if isinstance(decoder_output, list):
decoder_output = decoder_output[-1]
pred_masks = decoder_output['pred_masks'][0] # T n h w
pred_masks = F.interpolate(pred_masks, size=(H, W), mode='bilinear') > 0 # T n h w
pred_masks = pred_masks[:orig_t, :, :orig_h, :orig_w] # T n h w
#
pred_classes = decoder_output['pred_class'][0][:orig_t, :,:] # T n c, probability
pred_classes = pred_classes.cpu().unbind(0) # list[n c], T
pred_masks = pred_masks.cpu().unbind(0) # list[n h w], T
VIS_Aug_CallbackAPI
orig_video = videos[0][:, :orig_t, :orig_h, :orig_w].permute(1,0,2,3) # T 3 h w
orig_video = Trans_F.normalize(orig_video, [0, 0, 0], 1 / self.pixel_std)
orig_video = Trans_F.normalize(orig_video, -self.pixel_mean, [1, 1, 1]).cpu()
return {
'video': [orig_video], # [t 3 h w], 1
'pred_masks': [pred_masks], # [list[n h w], t, bool], 1
'pred_class': [pred_classes], # [list[n c], t, probability], 1
}
@staticmethod
def get_optim_params_group(model, configs):
weight_decay_norm = configs['optim']['weight_decay_norm']
weight_decay_embed = configs['optim']['weight_decay_embed']
defaults = {}
defaults['lr'] = configs['optim']['base_lr']
defaults['weight_decay'] = configs['optim']['weight_decay']
norm_module_types = (
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.SyncBatchNorm,
# NaiveSyncBatchNorm inherits from BatchNorm2d
torch.nn.GroupNorm,
torch.nn.InstanceNorm1d,
torch.nn.InstanceNorm2d,
torch.nn.InstanceNorm3d,
torch.nn.LayerNorm,
torch.nn.LocalResponseNorm,
)
params: List[Dict[str, Any]] = []
memo: Set[torch.nn.parameter.Parameter] = set()
log_lr_group_idx = {'backbone':None, 'base':None}
for module_name, module in model.named_modules():
for module_param_name, value in module.named_parameters(recurse=False):
if not value.requires_grad:
continue
# Avoid duplicating parameters
if value in memo:
continue
memo.add(value)
hyperparams = copy.copy(defaults)
if "video_backbone" in module_name:
hyperparams["lr"] = hyperparams["lr"] * configs['optim']['backbone_lr_multiplier']
if log_lr_group_idx['backbone'] is None:
log_lr_group_idx['backbone'] = len(params)
else:
if log_lr_group_idx['base'] is None:
log_lr_group_idx['base'] = len(params)
# weight decay is handled separately for pos_embed, norm, and embedding parameters
if (
"relative_position_bias_table" in module_param_name
or "absolute_pos_embed" in module_param_name
):
logging.debug(f'setting weight decay of {module_name}.{module_param_name} to zero')
hyperparams["weight_decay"] = 0.0
if isinstance(module, norm_module_types):
hyperparams["weight_decay"] = weight_decay_norm
if isinstance(module, torch.nn.Embedding):
hyperparams["weight_decay"] = weight_decay_embed
params.append({"params": [value], **hyperparams})
return params, log_lr_group_idx
@register_model
def backbone_encoder_decoder_withScaleConsistency(configs, device):
from .aux_mapper import AUXMapper_v1
model = BackboneEncoderDecoder_WithScaleConsistency(configs)
model.to(device)
params_group, log_lr_group_idx = BackboneEncoderDecoder_WithScaleConsistency.get_optim_params_group(model=model, configs=configs)
to_train_num_parameters = len([n for n, p in model.named_parameters() if p.requires_grad])
assert len(params_group) == to_train_num_parameters, 'every trainable parameter must map to exactly one param group'
optimizer = get_optimizer(params_group, configs)
scheduler = build_scheduler(configs=configs, optimizer=optimizer)
model_input_mapper = AUXMapper_v1(configs['model']['input_aux'])
train_samplers, train_loaders, eval_function = build_schedule(configs,
model_input_mapper.mapper,
partial(model_input_mapper.collate, max_stride=model.max_stride))
return model, optimizer, scheduler, train_samplers, train_loaders, log_lr_group_idx, eval_function
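# Illustrative sketch (not part of the original module): `sample` above undoes the input
# normalization with two normalize calls, first with mean 0 and std 1/pixel_std (multiply by
# std), then with mean -pixel_mean and std 1 (add the mean). The pure-Python check below works
# on a single RGB pixel; the helper names are assumptions for this example.

```python
PIXEL_MEAN = [0.485, 0.456, 0.406]
PIXEL_STD = [0.229, 0.224, 0.225]


def normalize(pixel, mean, std):
    """Per-channel (p - mean) / std, the same arithmetic as torchvision's normalize."""
    return [(p - m) / s for p, m, s in zip(pixel, mean, std)]


def denormalize(pixel):
    # Step 1: dividing by 1/std multiplies by std.
    scaled = normalize(pixel, [0.0, 0.0, 0.0], [1.0 / s for s in PIXEL_STD])
    # Step 2: subtracting -mean adds the mean back.
    return normalize(scaled, [-m for m in PIXEL_MEAN], [1.0, 1.0, 1.0])


rgb = [0.2, 0.5, 0.9]
restored = denormalize(normalize(rgb, PIXEL_MEAN, PIXEL_STD))
print(restored)
```

# Up to floating-point rounding, `restored` equals the original 0-1 pixel values, which is
# why the returned 'video' can be saved or visualized directly.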
================================================
FILE: models/VIS/__init__.py
================================================
from . import BackboneEncoderDecoder_WithScaleConsistency
from .. import modality_input_mappers
from .. import backbone
from .. import decoder
from .. import encoder
================================================
FILE: models/VIS/aux_mapper.py
================================================
import torch
from torch.nn import functional as F
from models.registry import register_model
from data_schedule.utils.box_ops import box_xyxy_to_cxcywh
from models.registry import MODELITY_INPUT_MAPPER_REGISTRY
from data_schedule.vis.apis import VIS_TrainAPI_clipped_video
from data_schedule.vis.apis import VIS_EvalAPI_clipped_video_request_ann
from utils.misc import nested_tensor_from_videos_list_with_stride
class AUXMapper_v1:
def __init__(self, aux_configs):
video_auxes = aux_configs['video_auxes']
video_auxes_names = [config['name'] for config in video_auxes]
assert len(list(set(video_auxes_names))) == len(video_auxes_names), 'each aux must have a unique name'
self.video_auxes_names = video_auxes_names
self.video_auxes = [MODELITY_INPUT_MAPPER_REGISTRY.get(config['name'])(config) for config in video_auxes]
self.targets_auxes = None
def mapper(self, data_dict, mode,):
if mode == 'train':
VIS_TrainAPI_clipped_video
video = data_dict['video_dict']['video']
for aux, aux_name in zip(self.video_auxes, self.video_auxes_names):
data_dict['video_dict'][aux_name] = aux.mapper(video)
elif mode == 'evaluate':
VIS_EvalAPI_clipped_video_request_ann
video = data_dict['video_dict']['video']
for aux, aux_name in zip(self.video_auxes, self.video_auxes_names):
data_dict['video_dict'][aux_name] = aux.mapper(video)
else:
raise ValueError()
return data_dict
def collate(self, batch_dict, mode, max_stride):
if mode == 'train':
VIS_TrainAPI_clipped_video
video_dict = self.collate_video_dict(batch_dict, max_stride=max_stride)
targets = [sample['targets'] for sample in batch_dict]
frame_has_ann = [clip_tgt['has_ann'] for clip_tgt in targets] # list[t], b
frame_targets = [sample['frame_targets'] for sample in batch_dict]
_, pad_T, _, pad_H, pad_W = video_dict['videos'].shape
targets = self.collate_targets(targets=targets, pad_H=pad_H, pad_W=pad_W, pad_T=pad_T)
frame_targets = self.collate_frame_targets(frame_targets=frame_targets,
frame_has_ann=frame_has_ann,
pad_H=pad_H, pad_W=pad_W, pad_T=pad_T)
ret = {
'video_dict': video_dict,
'targets': targets,
'frame_targets': frame_targets,
'meta_idxs': [sample['meta_idx'] for sample in batch_dict],
'visualize': [sample['visualize'] for sample in batch_dict],
}
elif mode == 'evaluate':
VIS_EvalAPI_clipped_video_request_ann
assert len(batch_dict) == 1
video_dict = self.collate_video_dict(batch_dict, max_stride=max_stride) # no padding
metas = [sample['meta'] for sample in batch_dict]
collated_metas = {}
for key in metas[0].keys():
collated_metas[key] = [mt[key] for mt in metas]
ret = {
'video_dict': video_dict,
'metas': collated_metas,
'meta_idxs': [sample['meta_idx'] for sample in batch_dict],
'visualize': [sample['visualize'] for sample in batch_dict],
}
debug_data = False
if debug_data:
self.visualize_input_target_for_debug_data(ret) # ./test.png
return ret
def collate_video_dict(self, batch_dict, max_stride):
videos = [sample['video_dict']['video'] for sample in batch_dict] # list[ti 3 hi wi] -> b T 3 H W
orig_sizes = [list(vid.shape) for vid in videos] # t 3 h w
if type(max_stride) == int: # an int is the spatial max stride; the temporal max stride is 1
pad_stride = [1, max_stride]
if (type(max_stride) == list) and (len(max_stride) == 2):
pad_stride = max_stride
videos = nested_tensor_from_videos_list_with_stride(videos, max_stride=pad_stride).tensors # b t c h w
video_dicts = {'videos': videos, 'orig_sizes': orig_sizes}
for aux_name, aux in zip(self.video_auxes_names, self.video_auxes):
auxes = [sample['video_dict'][aux_name] for sample in batch_dict] # list[dict] / list[tensor]
collated_auxes = aux.collate(auxes, batch_videos=videos) # list[dict] / tensor
if isinstance(auxes[0], dict):
keys = collated_auxes.keys()
for key in keys:
assert key not in video_dicts
video_dicts[key] = collated_auxes[key]
else:
video_dicts[aux_name] = collated_auxes
return video_dicts
def collate_frame_targets(self, frame_targets, frame_has_ann, pad_H, pad_W, pad_T): #
VIS_TrainAPI_clipped_video
ret = {}
has_ann = torch.stack([F.pad(ha.float(), pad=(0, pad_T - len(ha)), value=0.).bool() for ha in frame_has_ann], dim=0).flatten() # bT
ret['has_ann'] = has_ann
masks = [ftarget['masks'] for sample in frame_targets for ftarget in sample] # list[ni h w], bt'
masks = [F.pad(m.float(), pad=(0, pad_W-m.shape[-1], 0, pad_H-m.shape[-2])).bool() for m in masks] # list[ni H W], bt'
ret['masks'] = masks # list[ni h w], bt'
classes = [ftarget['classes'] for sample in frame_targets for ftarget in sample] # list[ni], bt'
ret['classes'] = classes
if 'boxes' in frame_targets[0][0]:
boxes = [ftarget['boxes'] for sample in frame_targets for ftarget in sample] # list[ni 4], x1y1x2y2, bt'
boxes = [box_xyxy_to_cxcywh(bx) for bx in boxes]
boxes = [bx / torch.tensor([pad_W, pad_H, pad_W, pad_H], dtype=bx.dtype) for bx in boxes] # 0-1
ret['boxes'] = boxes # list[ni 4], bt'
return ret
def collate_targets(self, targets, pad_H, pad_W, pad_T):
VIS_TrainAPI_clipped_video
has_ann = [sample['has_ann'] for sample in targets] # list[t], bool
has_ann = torch.stack([F.pad(ha.float(), pad=(0, pad_T - len(ha)), value=0.).bool() for ha in has_ann], dim=0) # b T
masks = [sample['masks'] for sample in targets]
masks = [F.pad(m.float(), pad=(0, pad_W-m.shape[-1], 0, pad_H-m.shape[-2]), value=0.).bool() \
for m in masks] # list[ni T' H W]
classes = [sample['classes'] for sample in targets]
ret = {
'masks': masks, # list[ni T' h w]
'has_ann': has_ann, # b T
'classes': classes, # list[ni], b
}
if 'boxes' in targets[0]:
boxes = [sample['boxes'] for sample in targets] # list[ni T' 4], x1y1x2y2
boxes = [box_xyxy_to_cxcywh(bx) for bx in boxes]
boxes = [bx / torch.tensor([pad_W, pad_H, pad_W, pad_H], dtype=torch.float) for bx in boxes] # 0-1
ret.update({'boxes': boxes,})
return ret
def visualize_input_target_for_debug_data(self, ret):
videos = ret['video_dict']['videos'] # b T 3 H W
pass
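The collate functions above pad every clip to a shared size and divide box coordinates by the padded width and height. The sketch below illustrates that arithmetic in plain Python (no torch). It rests on one assumption: that `nested_tensor_from_videos_list_with_stride`, which is not shown here, rounds H and W up to a multiple of the spatial stride. `padded_size` and `xyxy_to_normalized_cxcywh` are hypothetical helper names, not part of the repository.

```python
import math


def padded_size(h, w, spatial_stride):
    """Round h and w up to the next multiple of spatial_stride (assumed padding rule)."""
    pad_h = math.ceil(h / spatial_stride) * spatial_stride
    pad_w = math.ceil(w / spatial_stride) * spatial_stride
    return pad_h, pad_w


def xyxy_to_normalized_cxcywh(box, pad_h, pad_w):
    """Convert an (x1, y1, x2, y2) box to (cx, cy, w, h) scaled to [0, 1] by the padded size."""
    x1, y1, x2, y2 = box
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    w, h = x2 - x1, y2 - y1
    # Same divisor order as the source: [pad_W, pad_H, pad_W, pad_H].
    return (cx / pad_w, cy / pad_h, w / pad_w, h / pad_h)


if __name__ == "__main__":
    pad_h, pad_w = padded_size(352, 350, 32)
    print(pad_h, pad_w)
    print(xyxy_to_normalized_cxcywh((0, 0, 176, 352), pad_h, pad_w))
```

Normalizing by the padded size rather than the original size keeps box coordinates aligned with the padded masks that the model actually sees.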
================================================
FILE: models/__init__.py
================================================
import os
from .registry import model_entrypoint
if os.getenv('CURRENT_TASK') == 'VIS':
from . import VIS
else:
raise ValueError("CURRENT_TASK must be set to 'VIS'")
================================================
FILE: models/backbone/__init__.py
================================================
from . import res2net, pvtv2
================================================
FILE: models/backbone/pvtv2.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
from timm.models.registry import register_model
class DWConv(nn.Module):
def __init__(self, dim):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
def forward(self, x, H, W):
B,N,C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features):
super().__init__()
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.fc2 = nn.Linear(hidden_features, in_features)
def forward(self, x, H, W):
x = self.fc1(x)
x = F.gelu(self.dwconv(x, H, W))
x = self.fc2(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads, sr_ratio):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}."
self.num_heads = num_heads
self.scale = (dim//num_heads)**(-0.5)
self.q = nn.Linear(dim, dim)
self.kv = nn.Linear(dim, dim*2)
self.proj = nn.Linear(dim, dim)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C//self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio, drop_path, sr_ratio):
super().__init__()
self.norm1 = nn.LayerNorm(dim, eps=1e-6)
self.attn = Attention(dim, num_heads=num_heads, sr_ratio=sr_ratio)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = nn.LayerNorm(dim, eps=1e-6)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio))
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
class OverlapPatchEmbed(nn.Module):
def __init__(self, patch_size, stride, in_chans, embed_dim):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size//2, patch_size//2))
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
x = self.proj(x)
B,C,H,W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class PVT(nn.Module):
def __init__(self, embed_dims, mlp_ratios, depths, snapshot, sr_ratios=(8, 4, 2, 1)):
super().__init__()
self.depths = depths
self.snapshot = snapshot
# patch_embed
self.patch_embed1 = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=3, embed_dim=embed_dims[0])
self.patch_embed2 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
self.patch_embed3 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
self.patch_embed4 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
# transformer encoder
dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))] # stochastic depth decay rule
cur = 0
self.block1 = nn.ModuleList([Block(dim=embed_dims[0], num_heads=1, mlp_ratio=mlp_ratios[0], drop_path=dpr[cur + i], sr_ratio=sr_ratios[0]) for i in range(depths[0])])
self.norm1 = nn.LayerNorm(embed_dims[0], eps=1e-6)
cur += depths[0]
self.block2 = nn.ModuleList([Block(dim=embed_dims[1], num_heads=2, mlp_ratio=mlp_ratios[1], drop_path=dpr[cur + i], sr_ratio=sr_ratios[1]) for i in range(depths[1])])
self.norm2 = nn.LayerNorm(embed_dims[1], eps=1e-6)
cur += depths[1]
self.block3 = nn.ModuleList([Block(dim=embed_dims[2], num_heads=5, mlp_ratio=mlp_ratios[2], drop_path=dpr[cur + i], sr_ratio=sr_ratios[2]) for i in range(depths[2])])
self.norm3 = nn.LayerNorm(embed_dims[2], eps=1e-6)
cur += depths[2]
self.block4 = nn.ModuleList([Block(dim=embed_dims[3], num_heads=8, mlp_ratio=mlp_ratios[3], drop_path=dpr[cur + i], sr_ratio=sr_ratios[3]) for i in range(depths[3])])
self.norm4 = nn.LayerNorm(embed_dims[3], eps=1e-6)
state_dict:dict = torch.load(self.snapshot, map_location='cpu')
state_dict.pop("head.weight")
state_dict.pop("head.bias")
self.load_state_dict(state_dict, strict=True)
del state_dict
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
def forward(self, x):
B = x.shape[0]
# stage 1
out1, H, W = self.patch_embed1(x)
for i, blk in enumerate(self.block1):
out1 = blk(out1, H, W)
out1 = self.norm1(out1).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# stage 2
out2, H, W = self.patch_embed2(out1)
for i, blk in enumerate(self.block2):
out2 = blk(out2, H, W)
out2 = self.norm2(out2).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# stage 3
out3, H, W = self.patch_embed3(out2)
for i, blk in enumerate(self.block3):
out3 = blk(out3, H, W)
out3 = self.norm3(out3).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# stage 4
out4, H, W = self.patch_embed4(out3)
for i, blk in enumerate(self.block4):
out4 = blk(out4, H, W)
out4 = self.norm4(out4).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
return out1, out2, out3, out4
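The `Attention` module above shrinks the key/value sequence with a strided convolution (`sr_ratio`) while queries keep full resolution. The following dependency-free sketch counts the resulting tokens per stage, using the patch-embedding and `sr_ratios` settings from `PVT`. The 352x352 input size and the helper names `conv_out` and `stage_token_counts` are assumptions made for illustration only.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_token_counts(height, width, sr_ratios=(8, 4, 2, 1)):
    """Return (query_tokens, key_value_tokens) for each of the four stages."""
    patch_cfgs = [(7, 4), (3, 2), (3, 2), (3, 2)]  # (patch_size, stride) per stage
    counts = []
    h, w = height, width
    for (patch, stride), sr in zip(patch_cfgs, sr_ratios):
        # OverlapPatchEmbed pads by patch_size // 2 on each side.
        h = conv_out(h, patch, stride, patch // 2)
        w = conv_out(w, patch, stride, patch // 2)
        if sr > 1:
            # The spatial-reduction conv uses kernel_size == stride == sr, no padding.
            kh, kw = conv_out(h, sr, sr, 0), conv_out(w, sr, sr, 0)
        else:
            kh, kw = h, w
        counts.append((h * w, kh * kw))
    return counts


if __name__ == "__main__":
    for stage, (nq, nkv) in enumerate(stage_token_counts(352, 352), start=1):
        print(f"stage {stage}: {nq} queries, {nkv} keys/values")
```

With these settings the key/value length stays constant across stages, so the attention cost grows only linearly with the number of query tokens.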
from detectron2.modeling import BACKBONE_REGISTRY
from einops import rearrange, reduce, repeat
from .utils import VideoMultiscale_Shape, ImageMultiscale_Shape
import os
import time
@BACKBONE_REGISTRY.register()
class PVT_V2(nn.Module):
def __init__(self, configs) -> None:
super().__init__()
pt_path = os.getenv('PT_PATH')
pvt_v2 = PVT(embed_dims=[64, 128, 320, 512], mlp_ratios=[8, 8, 4, 4],
depths=[3, 4, 6, 3], snapshot=os.path.join(pt_path, 'pvt_v2/pvt_v2_b2.pth'))
self.pvt_v2 = pvt_v2
freeze = configs['freeze']
if freeze:
for p in self.parameters():
p.requires_grad_(False)
self.multiscale_shapes = {}
for name, spatial_stride, dim in zip(['res2', 'res3', 'res4', 'res5'],
[4, 8, 16, 32],
[64, 128, 320, 512]):
self.multiscale_shapes[name] = ImageMultiscale_Shape(spatial_stride=spatial_stride, dim=dim)
self.max_stride = 32
def forward(self, x):
if not self.training:
batch_feats = []
for haosen in x:
feats = self.pvt_v2(haosen.unsqueeze(0))
batch_feats.append(feats)
batch_feats = list(zip(*batch_feats)) # 4
batch_feats = [torch.cat(haosen, dim=0) for haosen in batch_feats] # list[bt c h w]
ret = {}
names = ['res2', 'res3', 'res4', 'res5']
for name, feat in zip(names, batch_feats):
ret[name] = feat
return ret
else:
layer_outputs = self.pvt_v2(x)
ret = {}
names = ['res2', 'res3', 'res4', 'res5']
for name, feat in zip(names, layer_outputs):
ret[name] = feat
return ret
def num_parameters(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
@BACKBONE_REGISTRY.register()
class Video2D_PVT_V2(nn.Module):
def __init__(self, configs) -> None:
super().__init__()
self.image_homo = PVT_V2(configs=configs)
self.multiscale_shapes = {}
for name, temporal_stride, spatial_stride, dim in zip(['res2', 'res3', 'res4', 'res5'],
[1, 1, 1, 1],
[4, 8, 16, 32],
[64, 128, 320, 512]):
self.multiscale_shapes[name] = VideoMultiscale_Shape(temporal_stride=temporal_stride,
spatial_stride=spatial_stride, dim=dim)
self.max_stride = [1, 32]
def forward(self, x):
batch_size, _, T = x.shape[:3]
x = rearrange(x, 'b c t h w -> (b t) c h w').contiguous()
layer_outputs = self.image_homo(x)
layer_outputs = {key: rearrange(value.contiguous(), '(b t) c h w -> b c t h w',b=batch_size, t=T).contiguous() \
for key, value in layer_outputs.items()}
return layer_outputs
def num_parameters(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
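`Video2D_PVT_V2` runs a 2D backbone on video by folding the time axis into the batch axis and unfolding it afterwards. A minimal plain-Python sketch of that round trip is shown here. It treats each frame as an opaque object and ignores the channel axis that the `rearrange` patterns above also move; `fold_time_into_batch` and `unfold_time_from_batch` are hypothetical names.

```python
def fold_time_into_batch(video):
    """video[b][t] -> frames[b * T + t], mirroring 'b t ... -> (b t) ...'."""
    return [frame for clip in video for frame in clip]


def unfold_time_from_batch(frames, batch_size, num_frames):
    """Inverse of fold_time_into_batch, mirroring '(b t) ... -> b t ...'."""
    assert len(frames) == batch_size * num_frames
    return [frames[b * num_frames:(b + 1) * num_frames] for b in range(batch_size)]


if __name__ == "__main__":
    video = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]  # 2 clips, 3 frames each
    frames = fold_time_into_batch(video)
    print(frames)
    print(unfold_time_from_batch(frames, batch_size=2, num_frames=3) == video)
```

Because the batch index varies slowest in `(b t)`, frames of the same clip stay contiguous, which is what makes the unfold step a simple slice.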
================================================
FILE: models/backbone/res2net.py
================================================
import math
import torch
import torch.nn as nn
import os
import torch.nn.functional as F
from detectron2.modeling import BACKBONE_REGISTRY
from einops import rearrange, reduce, repeat
from .utils import VideoMultiscale_Shape
class Bottle2neck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'):
super(Bottle2neck, self).__init__()
width = int(math.floor(planes*(baseWidth/64.0)))
self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(width*scale)
self.nums = 1 if scale == 1 else scale - 1
if stype == 'stage':
self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
convs, bns = [], []
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.downsample = downsample
self.stype = stype
self.scale = scale
self.width = width
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)), inplace=True)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
sp = spx[i] if i == 0 or self.stype == 'stage' else sp + spx[i]
sp = self.convs[i](sp)
sp = F.relu(self.bns[i](sp), inplace=True)
out = sp if i == 0 else torch.cat((out, sp), 1)
if self.scale != 1 and self.stype == 'normal':
out = torch.cat((out, spx[self.nums]), 1)
elif self.scale != 1 and self.stype == 'stage':
out = torch.cat((out, self.pool(spx[self.nums])), 1)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
x = self.downsample(x)
return F.relu(out+x, inplace=True)
class Res2Net(nn.Module):
def __init__(self, layers, snapshot, baseWidth=26, scale=4):
super(Res2Net, self).__init__()
self.inplanes = 64
self.snapshot = snapshot
self.baseWidth = baseWidth
self.scale = scale
self.conv1 = nn.Sequential(
nn.Conv2d(3, 32, 3, 2, 1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, 3, 1, 1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, 3, 1, 1, bias=False)
)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(Bottle2neck, 64, layers[0])
self.layer2 = self._make_layer(Bottle2neck, 128, layers[1], stride=2)
self.layer3 = self._make_layer(Bottle2neck, 256, layers[2], stride=2)
self.layer4 = self._make_layer(Bottle2neck, 512, layers[3], stride=2)
state_dict:dict = torch.load(self.snapshot, map_location='cpu')
self.load_state_dict(state_dict, strict=False)
del state_dict
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False),
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = [block(self.inplanes, planes, stride, downsample=downsample, stype='stage', baseWidth=self.baseWidth, scale=self.scale)]
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale))
return nn.Sequential(*layers)
def forward(self, x):
out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1)
out2 = self.layer1(out1)
out3 = self.layer2(out2)
out4 = self.layer3(out3)
out5 = self.layer4(out4)
return out2, out3, out4, out5
def initialize(self):
self.load_state_dict(torch.load(self.snapshot), strict=False)
@BACKBONE_REGISTRY.register()
class Res2Net_50_EachFrame(nn.Module):
def __init__(self, configs) -> None:
super().__init__()
pt_path = os.getenv('PT_PATH')
res2net = Res2Net([3, 4, 6, 3], os.path.join(pt_path, 'res2net/res2net50_v1b_26w_4s-3cf99910.pth'))
self.res2net = res2net
freeze = configs['freeze']
if freeze:
for p in self.parameters():
p.requires_grad_(False)
self.multiscale_shapes = {}
for name, temporal_stride, spatial_stride, dim in zip(['res2', 'res3', 'res4', 'res5'],
[1, 1, 1, 1],
[4, 8, 16, 32],
[256, 512, 1024, 2048]):
self.multiscale_shapes[name] = VideoMultiscale_Shape(temporal_stride=temporal_stride,
spatial_stride=spatial_stride, dim=dim)
self.max_stride = [1, 32]
def forward(self, x):
batch_size, _, T = x.shape[:3]
x = rearrange(x, 'b c t h w -> (b t) c h w')
layer_outputs = self.res2net(x)
ret = {}
names = ['res2', 'res3', 'res4', 'res5']
for name, feat in zip(names, layer_outputs):
ret[name] = rearrange(feat.contiguous(), '(b t) c h w -> b c t h w',b=batch_size, t=T)
return ret
def num_parameters(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
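The channel bookkeeping in `Bottle2neck` can be checked by hand. This sketch reproduces the width formula from the constructor and shows that the block outputs for the four stages match the `dim` values registered in `Res2Net_50_EachFrame`. The function name `bottle2neck_channels` is hypothetical.

```python
import math


def bottle2neck_channels(planes, base_width=26, scale=4, expansion=4):
    """Channel counts inside one Bottle2neck block, following the constructor above."""
    width = int(math.floor(planes * (base_width / 64.0)))
    return {
        "width": width,                    # channels per split
        "conv1_out": width * scale,        # channels after the first 1x1 conv
        "block_out": planes * expansion,   # channels after the last 1x1 conv
    }


if __name__ == "__main__":
    for planes in (64, 128, 256, 512):
        print(planes, bottle2neck_channels(planes))
```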
================================================
FILE: models/backbone/utils.py
================================================
class VideoMultiscale_Shape:
def __init__(self, temporal_stride, spatial_stride, dim) -> None:
self.temporal_stride = temporal_stride
self.spatial_stride = spatial_stride
self.dim = dim
@staticmethod
def set_multiscale_same_dim(shape_by_dim, same_dim):
return {
key: VideoMultiscale_Shape(temporal_stride=value.temporal_stride,
spatial_stride=value.spatial_stride,
dim=same_dim) for key,value in shape_by_dim.items()
}
class ImageMultiscale_Shape:
def __init__(self, spatial_stride, dim) -> None:
self.spatial_stride = spatial_stride
self.dim = dim
================================================
FILE: models/decoder/__init__.py
================================================
from . import mask2former_video
================================================
FILE: models/decoder/mask2former_video.py
================================================
# multi-scale features, b c h w -> module -> obj queries, predictions, b nq c
import torch.nn as nn
from models.layers.decoder_layers import CrossAttentionLayer, SelfAttentionLayer, FFNLayer
from models.layers.anyc_trans import MLP
import torch.nn.functional as F
import torch
import copy
from models.layers.utils import zero_module, _get_clones
from models.layers.position_encoding import build_position_encoding
from einops import rearrange, reduce, repeat
from scipy.optimize import linear_sum_assignment
from models.layers.matching import batch_dice_loss, batch_sigmoid_ce_loss, batch_sigmoid_focal_loss, dice_loss, ce_mask_loss
from detectron2.modeling import META_ARCH_REGISTRY
import detectron2.utils.comm as comm
import data_schedule.utils.box_ops as box_ops
from utils.misc import is_dist_avail_and_initialized
from collections import defaultdict
from torch.cuda.amp import autocast
from detectron2.projects.point_rend.point_features import (
get_uncertain_point_coords_with_randomness,
point_sample,
)
def calculate_uncertainty(logits):
"""
We estimate uncertainty as the L1 distance between 0.0 and the logit prediction in `logits` for the
foreground class in `classes`.
Args:
logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
class-agnostic, where R is the total number of predicted masks in all images and C is
the number of foreground classes. The values are logits.
Returns:
scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
the most uncertain locations having the highest uncertainty score.
"""
assert logits.shape[1] == 1
gt_class_logits = logits.clone()
return -(torch.abs(gt_class_logits))
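`calculate_uncertainty` scores a location as more uncertain the closer its logit is to zero. A small plain-Python sketch of that ranking follows. It only shows the scoring and top-k selection; the oversampling and random mixing done by `get_uncertain_point_coords_with_randomness` are left out, and `most_uncertain_indices` is a hypothetical helper.

```python
def uncertainty(logits):
    """Negative absolute logit: values near 0 get the highest (least negative) score."""
    return [-abs(x) for x in logits]


def most_uncertain_indices(logits, k):
    """Indices of the k locations whose logits are closest to zero."""
    scores = uncertainty(logits)
    order = sorted(range(len(logits)), key=lambda i: scores[i], reverse=True)
    return order[:k]


if __name__ == "__main__":
    logits = [4.0, -0.1, 2.5, 0.3, -3.0]
    print(most_uncertain_indices(logits, 2))  # -> [1, 3]
```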
class Video_SetMatchingLoss(nn.Module):
def __init__(self,
loss_config,
num_classes,) -> None:
super().__init__()
self.num_classes = num_classes # n=1 / n=0 / n>1
self.matching_metrics = loss_config['matching_metrics'] # mask: mask/dice; point_sample_mask: ..
self.losses = loss_config['losses']
self.aux_layer_weights = loss_config['aux_layer_weights'] # int/list
empty_weight = torch.ones(self.num_classes + 1)
empty_weight[-1] = loss_config['background_cls_eos']
self.register_buffer('empty_weight', empty_weight)
# self.register_buffer('small_obj_weight', torch.tensor(loss_config['small_obj_weight']).float())
self._warmup_iters = 2000
self.register_buffer("_iter", torch.zeros([1]))
@property
def device(self,):
return self.empty_weight.device
def compute_loss(self,
model_outs,
targets,
video_aux_dict,
**kwargs):
# list[n t' h w], batch
if 'masks' in targets:
num_objs = sum([haosen.flatten(1).any(-1).int().sum().item() for haosen in targets['masks']])
# list[n t' 4], batch
elif 'boxes' in targets:
# n t' 2 -> n t -> n
num_objs = sum([(haosen[:, :, 2:] > 0).all(-1).any(-1).int().sum().item() for haosen in targets['boxes']])
else:
raise ValueError('targets contains neither boxes nor masks; one of them is needed to count the objects')
num_objs = torch.as_tensor([num_objs], dtype=torch.float, device=self.device)
if is_dist_avail_and_initialized():
torch.distributed.all_reduce(num_objs)
num_objs = torch.clamp(num_objs / comm.get_world_size(), min=1).item()
if isinstance(self.aux_layer_weights, list):
assert len(self.aux_layer_weights) == (len(model_outs) - 1)
else:
self.aux_layer_weights = [self.aux_layer_weights] * (len(model_outs) - 1)
layer_weights = self.aux_layer_weights + [1.]
loss_values = {
'mask_dice':0., 'mask_ce':0.,
'box_l1': 0., 'box_giou': 0.,
'class_ce':0.,
'mask_dice_smobj':0., 'mask_ce_smobj':0.,
'boxMask_dice':0., 'boxMask_ce':0.,
}
if ('mask_dice_ce' in self.matching_metrics) or ('mask_dice_ce' in self.losses):
# mask interpolate
tgt_mask_shape = targets['masks'][0].shape[-2:] # list[n t H W], b
for layer_idx in range(len(model_outs)):
# b t nq h w
batch_size, nf = model_outs[layer_idx]['pred_masks'].shape[:2]
model_outs[layer_idx]['pred_masks'] = rearrange(F.interpolate(model_outs[layer_idx]['pred_masks'].flatten(0, 1),
size=tgt_mask_shape, mode='bilinear', align_corners=False),
'(b t) n h w -> b t n h w',b=batch_size, t=nf)
for taylor, layer_out in zip(layer_weights, model_outs):
if taylor != 0:
matching_indices = self.matching(layer_out, targets)
for loss in self.losses:
loss_extra_param = self.losses[loss]
if loss == 'mask_dice_ce' :
loss_dict = self.loss_mask_dice_ce(layer_out, targets, matching_indices, num_objs,
loss_extra_param=loss_extra_param)
elif loss == 'class_ce':
loss_dict = self.loss_class_ce(layer_out, targets, matching_indices, num_objs,
loss_extra_param=loss_extra_param)
elif loss == 'point_mask_dice_ce':
loss_dict = self.loss_point_mask_dice_ce(layer_out, targets, matching_indices, num_objs,
loss_extra_param=loss_extra_param)
else:
raise ValueError(f'unknown loss: {loss}')
for key, value in loss_dict.items():
loss_values[key] = loss_values[key] + value
return loss_values
@torch.no_grad()
def matching(self, layer_out, targets):
batch_size = len(targets['masks']) if 'masks' in targets else len(targets['boxes'])
indices = []
has_ann = targets['has_ann']
for i in range(batch_size):
C = 0.
if 'class_prob' in self.matching_metrics:
out_cls = layer_out['pred_class'][i].softmax(-1) # nq c
tgt_cls = targets['classes'][i] # n
cost_class = - out_cls[:, tgt_cls] # nq n
C += self.matching_metrics['class_prob']['prob'] * cost_class
if 'mask_dice_ce' in self.matching_metrics:
out_mask = layer_out['pred_masks'][i][has_ann[i]].permute(1, 0, 2, 3).contiguous() # nq t' h w
tgt_mask = targets['masks'][i].to(out_mask) # ni t' H W
cost_mask = batch_sigmoid_ce_loss(out_mask.flatten(1), tgt_mask.flatten(1))
cost_dice = batch_dice_loss(out_mask.flatten(1), tgt_mask.flatten(1))
C += self.matching_metrics['mask_dice_ce']['ce'] * cost_mask + \
self.matching_metrics['mask_dice_ce']['dice'] * cost_dice
if 'point_mask_dice_ce' in self.matching_metrics:
out_mask = layer_out['pred_masks'][i][has_ann[i]].permute(1, 0, 2, 3).contiguous() # nq t' h w
tgt_mask = targets['masks'][i].to(out_mask)# ni t' H W
nf = out_mask.shape[1]
out_mask = out_mask.flatten(0, 1)[:, None]
tgt_mask = tgt_mask.flatten(0, 1)[:, None]
# all masks share the same set of points for efficient matching!
point_coords = torch.rand(1, self.matching_metrics['point_mask_dice_ce']['num_points'],
2, device=self.device)
# get gt labels
tgt_mask = point_sample(
tgt_mask,
point_coords.repeat(tgt_mask.shape[0], 1, 1),
align_corners=False,
).squeeze(1) # (ni t') s
tgt_mask = rearrange(tgt_mask, '(nq t) s -> nq t s',t=nf)
out_mask = point_sample(
out_mask,
point_coords.repeat(out_mask.shape[0], 1, 1),
align_corners=False,
).squeeze(1) # (nq t') s
out_mask = rearrange(out_mask, '(nq t) s -> nq t s',t=nf)
with autocast(enabled=False):
out_mask = out_mask.float().flatten(1) # nq num_points
tgt_mask = tgt_mask.float().flatten(1)
cost_mask = batch_sigmoid_ce_loss(out_mask, tgt_mask)
cost_dice = batch_dice_loss(out_mask, tgt_mask)
C += self.matching_metrics['point_mask_dice_ce']['ce'] * cost_mask + \
self.matching_metrics['point_mask_dice_ce']['dice'] * cost_dice
C = C.cpu()
indices.append(linear_sum_assignment(C))
return [
(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
for i, j in indices
]
def loss_mask_dice_ce(self, outputs, targets, indices, num_objs, loss_extra_param):
has_ann = targets['has_ann'] # b t
src_masks = outputs['pred_masks'].permute(0, 2, 1, 3, 4).contiguous() # b nq t h w
tgt_masks = targets['masks'] # list[n t' h w]
# list[nq t' h w] -> n_sigma t' h w
src_masks = torch.cat([t[J][:, haosen] for t, (J, _), haosen in zip(src_masks, indices, has_ann)],dim=0)
tgt_masks = torch.cat([t[J] for t, (_, J) in zip(tgt_masks, indices)], dim=0)
tgt_masks = tgt_masks.to(src_masks)
losses = {
"mask_ce": ce_mask_loss(src_masks.flatten(0, 1).flatten(1), tgt_masks.flatten(0, 1).flatten(1), num_boxes=num_objs),
"mask_dice": dice_loss(src_masks.flatten(0, 1).flatten(1), tgt_masks.flatten(0, 1).flatten(1), num_boxes=num_objs),
}
return losses
def loss_point_mask_dice_ce(self, outputs, targets, indices, num_objs, loss_extra_param):
has_ann = targets['has_ann'] # b t
src_masks = outputs['pred_masks'].permute(0, 2, 1, 3, 4).contiguous() # b nq t h w
tgt_masks = targets['masks'] # list[n t' h w]
# list[nq t' h w] -> n_sigma t' h w
src_masks = torch.cat([t[J][:, haosen] for t, (J, _), haosen in zip(src_masks, indices, has_ann)],dim=0)
tgt_masks = torch.cat([t[J] for t, (_, J) in zip(tgt_masks, indices)], dim=0)
tgt_masks = tgt_masks.to(src_masks)
nf = src_masks.shape[1]
# No need to upsample predictions as we are using normalized coordinates :)
# NT x 1 x H x W
src_masks = src_masks.flatten(0, 1).unsqueeze(1).contiguous() # nt' 1 h w
target_masks = tgt_masks.flatten(0, 1).unsqueeze(1).contiguous()
with torch.no_grad():
# sample point_coords
point_coords = get_uncertain_point_coords_with_randomness(
src_masks,
lambda logits: calculate_uncertainty(logits),
loss_extra_param['num_points'],
loss_extra_param['oversample_ratio'],
loss_extra_param['importance_sample_ratio'],
)
# get gt labels
point_labels = point_sample(
target_masks,
point_coords,
align_corners=False,
).squeeze(1) # nt' s
point_logits = point_sample(
src_masks,
point_coords,
align_corners=False,
).squeeze(1) # nt' s
# point_logits = rearrange(point_logits, '(n t) s -> n (t s)',t=nf)
# point_labels = rearrange(point_labels, '(n t) s -> n (t s)',t=nf)
losses = {
"mask_ce": ce_mask_loss(point_logits, point_labels, num_objs),
"mask_dice": dice_loss(point_logits, point_labels, num_objs),
}
del src_masks
del target_masks
return losses
def loss_class_ce(self, outputs, targets, indices, num_objs, loss_extra_param):
"""Classification loss (NLL)
targets must contain the key "classes": a list with one tensor of class indices of shape [n_i] per sample
"""
src_logits = outputs["pred_class"].float() # b nq c
idx = self._get_src_permutation_idx(indices)
# list[n], b -> bn
target_classes_o = torch.cat([t[J] for t, (_, J) in zip(targets['classes'], indices)])
target_classes = torch.full(
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=self.device
)
target_classes[idx] = target_classes_o
loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
losses = {"class_ce": loss_ce}
return losses
def loss_box_l1_giou(self, outputs, targets, indices, num_objs, loss_extra_param):
tgt_boxes = targets['boxes'] # list[n tl 4], b
has_ann = targets['has_ann'] # b t
src_boxes = outputs['pred_boxes'].sigmoid().permute(0, 2, 1, 3).contiguous() # b nq t 4
src_boxes = torch.cat([t[J][:, haosen] for t, (J, _), haosen in zip(src_boxes, indices, has_ann)], dim=0) # n_sum t' 4
tgt_boxes = torch.cat([t[J] for t, (_, J) in zip(tgt_boxes, indices)], dim=0) # n_sum t' 4
nf = tgt_boxes.shape[1]
loss_l1 = F.l1_loss(src_boxes, tgt_boxes, reduction='none').flatten(1) # n_sum t'4
loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
box_ops.box_cxcywh_to_xyxy(src_boxes.flatten(0, 1)),
box_ops.box_cxcywh_to_xyxy(tgt_boxes.flatten(0, 1)))) # n_sumt'
loss_giou = loss_giou.view(-1, nf).contiguous()
return {
'box_l1': loss_l1.sum(-1).sum() / num_objs,
'box_giou': loss_giou.sum(-1).sum() / num_objs
}
def _get_src_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
def _get_tgt_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
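The `matching` method builds a cost matrix between queries and ground-truth objects and solves a one-to-one assignment over it. Below is a toy sketch of that idea in plain Python. Brute-force search over permutations stands in for `scipy.optimize.linear_sum_assignment`, only a Dice cost is used, and the smoothing constant of 1 is an assumption because `batch_dice_loss` is not shown in this section. `dice_cost` and `match` are hypothetical names.

```python
from itertools import permutations


def dice_cost(pred_probs, target):
    """1 - Dice overlap between a soft prediction and a binary target (smoothing of 1 assumed)."""
    inter = sum(p * t for p, t in zip(pred_probs, target))
    denom = sum(pred_probs) + sum(target)
    return 1 - (2 * inter + 1) / (denom + 1)


def match(cost):
    """Minimum-cost one-to-one assignment for cost[query][target], with nq >= n.

    Returns (query_idx, target_idx) pairs sorted by query index.
    Brute force: only suitable for toy sizes.
    """
    nq, n = len(cost), len(cost[0])
    best = min(
        permutations(range(nq), n),
        key=lambda qs: sum(cost[q][j] for j, q in enumerate(qs)),
    )
    return sorted(zip(best, range(n)))


if __name__ == "__main__":
    preds = [[0.9, 0.9, 0.1, 0.1], [0.1, 0.1, 0.9, 0.9], [0.5, 0.5, 0.5, 0.5]]
    targets = [[0, 0, 1, 1], [1, 1, 0, 0]]
    cost = [[dice_cost(p, t) for t in targets] for p in preds]
    print(match(cost))  # -> [(0, 1), (1, 0)]
```

Queries left unmatched (query 2 here) are the ones that `loss_class_ce` assigns to the background class index `num_classes`.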
@META_ARCH_REGISTRY.register()
class Video_MaskedAttn_MultiscaleMaskDecoder_v3(nn.Module):
def __init__(self,
configs,
multiscale_shapes):
super().__init__()
d_model = configs['d_model']
attn_configs = configs['attn']
self.video_nqueries = configs['video_nqueries']
self.nlayers = configs['nlayers']
self.memory_scales = configs['memory_scales']
self.mask_scale = configs['mask_scale']
self.mask_spatial_stride = multiscale_shapes[self.mask_scale].spatial_stride
num_classes = configs['num_classes']
inputs_projs = configs['inputs_projs']
self.inputs_projs = nn.Sequential()
if inputs_projs is not None:
self.inputs_projs = META_ARCH_REGISTRY.get(inputs_projs['name'])(inputs_projs,
multiscale_shapes=multiscale_shapes,
out_dim=d_model)
self.level_embeds = nn.Embedding(len(self.memory_scales), d_model)
assert self.nlayers % len(self.memory_scales) == 0
self.cross_layers = _get_clones(CrossAttentionLayer(d_model=d_model,
nhead=attn_configs['nheads'],
dropout=0.0,
normalize_before=attn_configs['normalize_before']),
self.nlayers)
self.self_layers = _get_clones(SelfAttentionLayer(d_model=d_model,
nhead=attn_configs['nheads'],
dropout=0.0,
normalize_before=attn_configs['normalize_before']),
self.nlayers)
self.ffn_layers = _get_clones(FFNLayer(d_model=d_model,
dim_feedforward=attn_configs['dim_feedforward'],
dropout=0.0,
normalize_before=attn_configs['normalize_before']),
self.nlayers)
self.nheads = attn_configs['nheads']
self.temporal_query_poses = nn.Embedding(self.video_nqueries, d_model)
self.temporal_query_feats = nn.Embedding(self.video_nqueries, d_model)
self.temporal_query_norm = nn.LayerNorm(d_model)
self.pos_3d = build_position_encoding(hidden_dim=d_model, position_embedding_name='3d') # b t c h w
self.head_outputs = configs['head_outputs']
assert 'mask' in self.head_outputs
self.query_mask = MLP(d_model, d_model, d_model, 3)
if 'class' in self.head_outputs:
self.query_class = nn.Linear(d_model, num_classes + 1)
self.loss_module = Video_SetMatchingLoss(loss_config=configs['loss'], num_classes=num_classes)
@property
def device(self,):
return self.temporal_query_feats.weight.device
def get_memories_and_mask_features(self, multiscales):
# b c t h w
memories = [multiscales[scale] for scale in self.memory_scales]
size_list = [mem_feat.shape[-2:] for mem_feat in memories]
memories_poses = [self.pos_3d(mem.permute(0, 2, 1,3, 4)).permute(0, 2, 1, 3, 4) for mem in memories] # b c t h w
memories = [rearrange(mem, 'b c t h w -> (t h w) b c').contiguous() for mem in memories]
memories_poses = [rearrange(mem_pos, 'b c t h w -> (t h w) b c').contiguous() for mem_pos in memories_poses]
mask_features = multiscales[self.mask_scale] # b c t h w
return memories, memories_poses, mask_features, size_list
def forward(self,
multiscales, # b c t h w
video_aux_dict=None
):
multiscales = self.inputs_projs(multiscales[0])
# thw b c; b c t h w
memories, memories_poses, mask_features, size_list = self.get_memories_and_mask_features(multiscales)
memories = [mem_feat + self.level_embeds.weight[i][None, None, :] for i, mem_feat in enumerate(memories)]
batch_size, _, nf, *_ = mask_features.shape
# nq b c
temporal_query_poses = self.temporal_query_poses.weight.unsqueeze(1).repeat(1, batch_size, 1)
temporal_query_feats = self.temporal_query_feats.weight.unsqueeze(1).repeat(1, batch_size, 1)
vid_ret = []
# b nq class, b nq t h w; b*head nq thw
vid_class, vid_mask, attn_mask = \
self.forward_heads(temporal_query_feats=temporal_query_feats,
mask_features=mask_features, attn_mask_target_size=size_list[0])
vid_ret.append({'pred_class':vid_class, 'pred_masks': vid_mask})
for i in range(self.nlayers):
level_index = i % len(self.memory_scales)
attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False # a query whose keys are all masked (e.g. because of padding) attends to every key instead
temporal_query_feats = self.cross_layers[i](
tgt=temporal_query_feats, # nq b c
memory=memories[level_index], # thw b c
memory_mask=attn_mask, # b*h nq thw
memory_key_padding_mask=None,
pos=memories_poses[level_index], # thw b c
query_pos=temporal_query_poses, # nq b c
)
temporal_query_feats = self.self_layers[i](
temporal_query_feats,
tgt_mask=None,
tgt_key_padding_mask=None,
query_pos=temporal_query_poses,
)
temporal_query_feats = self.ffn_layers[i](
temporal_query_feats
)
# b nq class, b t nq h w
vid_class, vid_mask, attn_mask = \
self.forward_heads(temporal_query_feats=temporal_query_feats,
mask_features=mask_features, attn_mask_target_size=size_list[(i + 1) % len(self.memory_scales)])
vid_ret.append({'pred_class':vid_class, 'pred_masks': vid_mask})
return vid_ret
def forward_heads(self, temporal_query_feats, mask_features, attn_mask_target_size): # nq b c; b c t h w
batch_size, _, nf, *_ = mask_features.shape
temporal_query_feats = self.temporal_query_norm(temporal_query_feats) # nq b c
temporal_query_feats = temporal_query_feats.transpose(0, 1).contiguous() # b nq c
class_logits = self.query_class(temporal_query_feats) if 'class' in self.head_outputs else None # b n class+1
mask_embeds = self.query_mask(temporal_query_feats) # b n c
mask_logits = torch.einsum("bqc,bcthw->bqthw", mask_embeds, mask_features)
batch_size, nq, nf = mask_logits.shape[:3]
mask_logits = F.interpolate(mask_logits.flatten(0, 1), scale_factor=self.mask_spatial_stride, mode='bilinear', align_corners=False)
mask_logits = rearrange(mask_logits, '(b n) t h w -> b t n h w',b=batch_size, n=nq)
# bt nq h w
attn_mask = mask_logits.detach().clone().flatten(0, 1)
attn_mask = (F.interpolate(attn_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False) < 0.5).bool()
attn_mask = repeat(attn_mask, '(b t) nq h w -> (b head) nq (t h w)', b=batch_size, t=nf, head=self.nheads)
if self.training:
return class_logits, mask_logits, attn_mask
else:
return class_logits.softmax(-1).unsqueeze(1).repeat(1, nf, 1, 1) if class_logits is not None else None, mask_logits, attn_mask
def compute_loss(self, outputs, targets, video_aux_dict, **kwargs):
assert self.training
return self.loss_module.compute_loss(model_outs=outputs,
targets=targets,
video_aux_dict=video_aux_dict)
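The decoder loop above resets any query whose attention mask blocks every memory position, so that the query attends everywhere instead of producing an all-masked attention row. A minimal pure-Python sketch of that rule follows, using nested lists of booleans in place of the `b*head nq thw` tensor. The helper name `unblock_fully_masked_rows` is hypothetical and not part of this repository.

```python
def unblock_fully_masked_rows(attn_mask):
    """Return a copy in which rows that are entirely True become entirely False.

    attn_mask[i][j] is True when query i must NOT attend to position j.
    """
    result = []
    for row in attn_mask:
        if all(row):
            # Every position is blocked: let this query attend everywhere.
            result.append([False] * len(row))
        else:
            result.append(list(row))
    return result


mask = [
    [True, True, True],    # fully masked, will be reset
    [True, False, True],   # partially masked, left unchanged
]
print(unblock_fully_masked_rows(mask))
```

The tensor version in `forward` does the same thing in place by comparing the row sum with the row length.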
================================================
FILE: models/encoder/__init__.py
================================================
from . import localGlobal
from . import input_projs
from . import neighborhood_qk
================================================
FILE: models/encoder/input_projs.py
================================================
import torch.nn as nn
from detectron2.modeling import META_ARCH_REGISTRY
from models.layers.anyc_trans import Linear_NormAct
from models.layers.anyc_trans import Conv3d_NormAct, Conv2d_NormAct
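from einops import rearrange
# NOTE: AMRData is used in the text branches below but is not imported in this file; import it from the module that defines it.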
@META_ARCH_REGISTRY.register()
class VideoConv3d_TextLinear(nn.Module):
"""
If multiscale_shapes is None, the input dim of every scale is assumed to equal out_dim.
If multiscale_shapes is given, the input dims are taken from it.
"""
def __init__(self,
configs,
out_dim,
text_dim=None, # if None, assumed to equal out_dim
multiscale_shapes=None, # scale_name: (dim, [temporal_scale, spatial_scale])
) -> None:
super().__init__()
text_dim = out_dim if text_dim is None else text_dim
multiscale_projs_config = configs['video_multiscale_projs']
proj_names = multiscale_projs_config.keys() # list[str]
in_dims = {}
if multiscale_shapes is not None:
assert set(proj_names).issubset(set(list(multiscale_shapes.keys())))
for name in proj_names:
in_dims[name] = multiscale_shapes[name].dim
else:
for name in proj_names:
in_dims[name] = out_dim
projs = {}
for name, config in multiscale_projs_config.items():
projs[name] = Conv3d_NormAct(in_channels=in_dims[name],
out_channels=out_dim,
**config)
self.video_multiscale_projs = nn.ModuleDict(projs)
text_proj_config = configs['text_proj']
if text_proj_config is None:
self.text_proj = nn.Identity()
else:
self.text_proj = Linear_NormAct(in_features=text_dim, out_features=out_dim, **text_proj_config)
def forward(self, multiscales, text_dict):
ret = {}
for scale_name, scale_feat in multiscales.items(): # b c t h w
if scale_name in self.video_multiscale_projs:
scale_feat = self.video_multiscale_projs[scale_name](scale_feat)
ret[scale_name] = scale_feat
else:
ret[scale_name] = scale_feat
if isinstance(text_dict, AMRData):
text_dict.amr_feats = self.text_proj(text_dict.amr_feats)
text_dict.text_feats = self.text_proj(text_dict.text_feats)
else:
raise ValueError()
return ret, text_dict
@META_ARCH_REGISTRY.register()
class VideoConv2d_TextLinear(nn.Module):
"""
If multiscale_shapes is None, the input dim of every scale is assumed to equal out_dim.
If multiscale_shapes is given, the input dims are taken from it.
"""
def __init__(self,
configs,
out_dim,
text_dim=None, # if None, assumed to equal out_dim
multiscale_shapes=None, # scale_name: (dim, [temporal_scale, spatial_scale])
) -> None:
super().__init__()
text_dim = out_dim if text_dim is None else text_dim
multiscale_projs_config = configs['video_multiscale_projs']
proj_names = multiscale_projs_config.keys() # list[str]
in_dims = {}
if multiscale_shapes is not None:
assert set(proj_names).issubset(set(list(multiscale_shapes.keys())))
for name in proj_names:
in_dims[name] = multiscale_shapes[name].dim
else:
for name in proj_names:
in_dims[name] = out_dim
projs = {}
for name, config in multiscale_projs_config.items():
projs[name] = Conv2d_NormAct(in_channels=in_dims[name],
out_channels=out_dim,
**config)
self.video_multiscale_projs = nn.ModuleDict(projs)
text_proj_config = configs['text_proj']
if text_proj_config is None:
self.text_proj = nn.Identity()
else:
self.text_proj = Linear_NormAct(in_features=text_dim, out_features=out_dim, **text_proj_config)
def forward(self, multiscales, text_dict):
ret = {}
for scale_name, scale_feat in multiscales.items(): # b c t h w
if scale_name in self.video_multiscale_projs:
batch_size, _, nf = scale_feat.shape[:3]
scale_feat = rearrange(scale_feat, 'b c t h w -> (b t) c h w')
scale_feat = self.video_multiscale_projs[scale_name](scale_feat)
scale_feat = rearrange(scale_feat, '(b t) c h w -> b c t h w', b=batch_size, t=nf)
ret[scale_name] = scale_feat
else:
ret[scale_name] = scale_feat
if isinstance(text_dict, AMRData):
text_dict.amr_feats = self.text_proj(text_dict.amr_feats)
text_dict.text_feats = self.text_proj(text_dict.text_feats)
else:
raise ValueError()
return ret, text_dict
@META_ARCH_REGISTRY.register()
class ImageConv_MultiscaleProj(nn.Module):
def __init__(self,
configs,
out_dim,
multiscale_shapes=None,
) -> None:
"""
If multiscale_shapes is None, the input dim equals out_dim.
"""
super().__init__()
projs_configs = configs['projs']
proj_names = list(projs_configs.keys()) # list[str]
in_dims = {}
if multiscale_shapes is not None:
assert set(proj_names).issubset(set(list(multiscale_shapes.keys())))
for name in proj_names:
in_dims[name] = multiscale_shapes[name].dim
else:
for name in proj_names:
in_dims[name] = out_dim
projs = {}
for name, config in projs_configs.items():
projs[name] = Conv2d_NormAct(in_channels=in_dims[name], out_channels=out_dim,
**config)
self.multiscale_projs = nn.ModuleDict(projs)
def forward(self, multiscales):
ret = {}
for scale_name, scale_feat in multiscales.items():
if scale_name in self.multiscale_projs:
scale_feat = self.multiscale_projs[scale_name](scale_feat)
ret[scale_name] = scale_feat
else:
ret[scale_name] = scale_feat
return ret
@META_ARCH_REGISTRY.register()
class Video2D_ImageConv_MultiscaleProj(nn.Module):
def __init__(self,
configs,
out_dim,
multiscale_shapes=None, # scale_name: (dim, [temporal_scale, spatial_scale])
) -> None:
super().__init__()
self.image_homo = ImageConv_MultiscaleProj(configs=configs, out_dim=out_dim, multiscale_shapes=multiscale_shapes)
def forward(self, multiscales):
batch_size, _, nf = multiscales[list(multiscales.keys())[0]].shape[:3]
# b c t h w -> bt c h w
multiscales = {key: value.permute(0, 2, 1, 3, 4).flatten(0, 1).contiguous() for key,value in multiscales.items()}
multiscales = self.image_homo(multiscales)
multiscales = {key: rearrange(value, '(b t) c h w -> b c t h w',b=batch_size, t=nf).contiguous()\
for key,value in multiscales.items()}
return multiscales
@META_ARCH_REGISTRY.register()
class VideoConv_MultiscaleProj(nn.Module):
def __init__(self,
configs,
out_dim,
multiscale_shapes=None,
) -> None:
"""
If multiscale_shapes is None, the input dim equals out_dim.
"""
super().__init__()
projs_configs = configs['projs']
proj_names = list(projs_configs.keys()) # list[str]
in_dims = {}
if multiscale_shapes is not None:
assert set(proj_names).issubset(set(list(multiscale_shapes.keys())))
for name in proj_names:
in_dims[name] = multiscale_shapes[name].dim
else:
for name in proj_names:
in_dims[name] = out_dim
projs = {}
for name, config in projs_configs.items():
projs[name] = Conv3d_NormAct(in_channels=in_dims[name], out_channels=out_dim,
**config)
self.multiscale_projs = nn.ModuleDict(projs)
def forward(self, multiscales):
ret = {}
for scale_name, scale_feat in multiscales.items():
if scale_name in self.multiscale_projs:
scale_feat = self.multiscale_projs[scale_name](scale_feat)
ret[scale_name] = scale_feat
else:
ret[scale_name] = scale_feat
return ret
@META_ARCH_REGISTRY.register()
class FrameQueryLinear_TextLinear(nn.Module):
def __init__(self,
configs,
out_dim,
text_dim=None, # int
query_dim=None, # int
) -> None:
super().__init__()
query_proj_config = configs['query_proj']
query_dim = out_dim if query_dim is None else query_dim
text_dim = out_dim if text_dim is None else text_dim
self.query_proj = Linear_NormAct(in_features=query_dim, out_features=out_dim, **query_proj_config)
text_proj_config = configs['text_proj']
if text_proj_config is None:
self.text_proj = nn.Identity()
else:
self.text_proj = Linear_NormAct(in_features=text_dim, out_features=out_dim, **text_proj_config)
def forward(self, frame_query, text_dict):
# b T nqf c
# text_dict
frame_query = self.query_proj(frame_query)
if isinstance(text_dict, AMRData):
text_dict.amr_feats = self.text_proj(text_dict.amr_feats)
text_dict.text_feats = self.text_proj(text_dict.text_feats)
else:
raise ValueError()
return frame_query, text_dict
@META_ARCH_REGISTRY.register()
class VideoConv3d_FrameQueryLinear_TextLinear(nn.Module):
def __init__(self,
configs,
out_dim,
feat_dim=None,
text_dim=None, # int
query_dim=None, # int
) -> None:
super().__init__()
query_proj_config = configs['query_proj']
feat_proj_config = configs['feat_proj']
text_proj_config = configs['text_proj']
feat_dim = out_dim if feat_dim is None else feat_dim
query_dim = out_dim if query_dim is None else query_dim
text_dim = out_dim if text_dim is None else text_dim
self.query_proj = Linear_NormAct(in_features=query_dim, out_features=out_dim, **query_proj_config) if query_proj_config is not None else nn.Identity()
self.text_proj = Linear_NormAct(in_features=text_dim, out_features=out_dim, **text_proj_config) if text_proj_config is not None else nn.Identity()
self.feat_proj = Conv3d_NormAct(in_channels=feat_dim, out_channels=out_dim, **feat_proj_config)
def forward(self, mask_feat, frame_query, text_dict):
mask_feat = self.feat_proj(mask_feat)
frame_query = self.query_proj(frame_query)
if isinstance(text_dict, AMRData):
text_dict.amr_feats = self.text_proj(text_dict.amr_feats)
text_dict.text_feats = self.text_proj(text_dict.text_feats)
else:
raise ValueError()
return mask_feat, frame_query, text_dict
# Each module should project its inputs once, into its own feature space
@META_ARCH_REGISTRY.register()
class VideoConv3d_FrameQueryLinear(nn.Module):
"""
If multiscale_shapes is None, the input dim of every scale is assumed to equal out_dim.
If multiscale_shapes is given, the input dims are taken from it.
"""
def __init__(self,
configs,
out_dim,
query_dim=None, # if None, assumed to equal out_dim
multiscale_shapes=None, # scale_name: (dim, [temporal_scale, spatial_scale])
) -> None:
super().__init__()
query_dim = out_dim if query_dim is None else query_dim
multiscale_projs_config = configs['video_multiscale_projs']
proj_names = multiscale_projs_config.keys() # list[str]
in_dims = {}
if multiscale_shapes is not None:
assert set(proj_names).issubset(set(list(multiscale_shapes.keys())))
for name in proj_names:
in_dims[name] = multiscale_shapes[name].dim
else:
for name in proj_names:
in_dims[name] = out_dim
projs = {}
for name, config in multiscale_projs_config.items():
projs[name] = Conv3d_NormAct(in_channels=in_dims[name],
out_channels=out_dim,
**config)
self.video_multiscale_projs = nn.ModuleDict(projs)
query_proj_config = configs['query_proj']
if query_proj_config is None:
self.query_proj = nn.Identity()
else:
self.query_proj = Linear_NormAct(in_features=query_dim, out_features=out_dim, **query_proj_config)
def forward(self, multiscales, frame_queries):
# b t nq c
ret = {}
for scale_name, scale_feat in multiscales.items(): # b c t h w
if scale_name in self.video_multiscale_projs:
scale_feat = self.video_multiscale_projs[scale_name](scale_feat)
ret[scale_name] = scale_feat
else:
ret[scale_name] = scale_feat
frame_queries = self.query_proj(frame_queries)
return ret, frame_queries
@META_ARCH_REGISTRY.register()
class FrameQueryLinear(nn.Module):
def __init__(self,
configs,
out_dim,
query_dim=None, # int
) -> None:
super().__init__()
query_proj_config = configs['query_proj']
query_dim = out_dim if query_dim is None else query_dim
self.query_proj = Linear_NormAct(in_features=query_dim, out_features=out_dim, **query_proj_config)
def forward(self, frame_query):
# b T nqf c
frame_query = self.query_proj(frame_query)
return frame_query
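Most projection modules in this file repeat the same rule for choosing input channel counts: use the backbone's per-scale `dim` when `multiscale_shapes` is given, otherwise assume the input already has `out_dim` channels. The sketch below isolates that rule in plain Python. The helper `resolve_in_dims` and the `ScaleShape` tuple are hypothetical stand-ins for illustration, and the sketch raises `KeyError` where the original code uses an `assert`.

```python
from collections import namedtuple

# Hypothetical stand-in for the per-scale shape objects; only `.dim` is read here.
ScaleShape = namedtuple("ScaleShape", ["dim", "spatial_stride"])


def resolve_in_dims(proj_names, out_dim, multiscale_shapes=None):
    """Map each projected scale name to its input channel count."""
    proj_names = list(proj_names)
    if multiscale_shapes is None:
        # No shape information: every input is assumed to have out_dim channels.
        return {name: out_dim for name in proj_names}
    missing = set(proj_names) - set(multiscale_shapes)
    if missing:
        raise KeyError(f"no shape given for scales: {sorted(missing)}")
    return {name: multiscale_shapes[name].dim for name in proj_names}


shapes = {"res3": ScaleShape(128, 8), "res4": ScaleShape(320, 16)}
print(resolve_in_dims(["res3", "res4"], out_dim=256, multiscale_shapes=shapes))
print(resolve_in_dims(["res3", "res4"], out_dim=256))
```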
================================================
FILE: models/encoder/localGlobal.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import numpy as np
from typing import Callable, Dict, List, Optional, Tuple, Union
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
from torch.cuda.amp import autocast
from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import META_ARCH_REGISTRY
from models.layers.position_encoding import PositionEmbeddingSine
from models.layers.utils import _get_clones, _get_activation_fn
from .ops.modules import MSDeformAttn
# MSDeformAttn Transformer encoder in deformable detr
class MSDeformAttnTransformerEncoderOnly(nn.Module):
def __init__(self, d_model=256, nhead=8,
num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
activation="relu",
num_feature_levels=4, enc_n_points=4,
add_local = False,
add_global=False,
local_configs=None,
global_configs=None,
frame_nqueries=None,
):
super().__init__()
self.d_model = d_model
self.nhead = nhead
encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model = d_model,
d_ffn = dim_feedforward,
dropout = dropout,
activation = activation,
n_levels = num_feature_levels,
n_heads = nhead,
n_points = enc_n_points,
add_local = add_local,
add_global = add_global,
local_configs = local_configs,
global_configs = global_configs
)
self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers,
d_model=d_model,
frame_nqueries=frame_nqueries)
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
normal_(self.level_embed)
def get_valid_ratio(self, mask):
_, H, W = mask.shape # b h w
valid_H = torch.sum(~mask[:, :, 0], 1) # b
valid_W = torch.sum(~mask[:, 0, :], 1) # b
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) # b 2
return valid_ratio
def forward(self,
srcs=None,
pos_embeds=None,
video_aux_dict=None,
**kwargs):
masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs]
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) # b #scale 2
# encoder
memory, frame_feats, frame_poses = self.encoder(src=src_flatten, # bt hw_sigma c
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
pos=lvl_pos_embed_flatten,
padding_mask=mask_flatten,
video_aux_dict=video_aux_dict)
return memory, spatial_shapes, level_start_index, frame_feats, frame_poses
class MSDeformAttnTransformerEncoderLayer(nn.Module):
def __init__(self,
d_model=256, d_ffn=1024,
dropout=0.1, activation="relu",
n_levels=4, n_heads=8, n_points=4,
add_local=False,
add_global=False,
local_configs=None,
global_configs=None):
super().__init__()
# deform2d
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
self.add_local = add_local
if self.add_local:
from .neighborhood_qk import NA_qk_Layer
# self
self.local_cnp = NA_qk_Layer(d_model=d_model, configs=local_configs)
self.add_global = add_global
if self.add_global:
from models.layers.decoder_layers import CrossAttentionLayer
# cross
self.frame_query_cross_multiscale = CrossAttentionLayer(d_model=d_model, nhead=8, dropout=0.0,
activation="relu", normalize_before=False)
self.cross_num_heads = 8
self.global_add_attn_mask = global_configs['add_attn_mask'] if 'add_attn_mask' in global_configs else False
# self+ffn
from models.encoder.ops.modules.frame_query_ss2d import FrameQuery_SS2DLayer_hilbert
self.global_hiss = FrameQuery_SS2DLayer_hilbert(global_configs)
self.multiscale_cross_query = CrossAttentionLayer(d_model=d_model, nhead=8, dropout=0.0,
activation="relu", normalize_before=False)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
@torch.no_grad()
def get_attn_mask(self, frame_query_feats, src, spatial_shapes, level_start_index,):
# nq bt c
# bt hw_sigma c
assert len(spatial_shapes) == 3
frame_query_feats = frame_query_feats.permute(1, 0, 2) # bt nq c
feat = src[:, level_start_index[-1]: (level_start_index[-1] + spatial_shapes[-1][0] * spatial_shapes[-1][1])]
feat = rearrange(feat, 'b (h w) c -> b c h w',h=spatial_shapes[-1][0],w=spatial_shapes[-1][1])
mask = torch.einsum('bnc, bchw -> b n h w',frame_query_feats, feat)
mask_2 = F.interpolate(mask, size=spatial_shapes[0].tolist(), mode='bilinear',align_corners=False)
mask_3 = F.interpolate(mask, size=spatial_shapes[1].tolist(), mode='bilinear', align_corners=False)
attn_mask = torch.cat([mask_2.flatten(2), mask_3.flatten(2), mask.flatten(2)], dim=-1) #bt n hw_sigma
attn_mask = (attn_mask.unsqueeze(1).repeat(1, self.cross_num_heads, 1, 1).flatten(0, 1).sigmoid() < 0.5).bool()
return attn_mask
def forward(self,
src=None, pos=None,
reference_points=None, spatial_shapes=None, level_start_index=None, padding_mask=None,
video_aux_dict=None,
frame_query_feats=None, # nq bt c
frame_query_poses=None):
if self.add_local:
# local_self
src = self.local_cnp(tgt=src,
scale_shapes=spatial_shapes,
level_start_idxs=level_start_index,
nf=video_aux_dict['nf'])
if self.add_global:
if self.global_add_attn_mask:
attn_mask = self.get_attn_mask(frame_query_feats, src, spatial_shapes, level_start_index,) # bthead nq hw
attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False # rows that are fully masked attend everywhere instead (e.g. when there is padding)
else:
attn_mask = None
# cross
frame_query_feats = self.frame_query_cross_multiscale(
tgt=frame_query_feats, # nq bt c
memory=src.permute(1, 0, 2), # hw_sigma bt c
memory_mask=attn_mask,
memory_key_padding_mask=None,
pos= pos.permute(1,0,2),
query_pos=frame_query_poses,
)
# self+ffn
frame_query_feats = self.global_hiss(frame_query_feats=frame_query_feats,
frame_query_poses=frame_query_poses,
video_aux_dict=video_aux_dict)
# distribute: the multiscale features cross-attend to the updated frame queries
src = self.multiscale_cross_query(
tgt=src.permute(1, 0, 2), # hw_sigma bt c
memory=frame_query_feats, # nq bt c
memory_mask=None,
memory_key_padding_mask=None,
pos= frame_query_poses,
query_pos=pos.permute(1,0,2),
).permute(1, 0, 2)
# self attention
src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
return src, frame_query_feats
class MSDeformAttnTransformerEncoder(nn.Module):
def __init__(self,
encoder_layer=None,
num_layers=None,
d_model=None, frame_nqueries=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.frame_nqueries = frame_nqueries # 10
self.frame_query_feats = nn.Embedding(self.frame_nqueries, d_model)
self.frame_query_poses = nn.Embedding(self.frame_nqueries, d_model)
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
# valid_ratios: b #scale 2, (valid_w, valid_h) in 0-1, i.e. the fraction of each feature map that is not padding
# list[h w] #scale
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) # 1 hw / b 1 -> b hw (0-1), normalized y coordinate
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) # 1 hw / b 1 -> b hw (0-1), normalized x coordinate
ref = torch.stack((ref_x, ref_y), -1) # b hw 2
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1) # b hw_sigma 2, relative coordinate of each point (0-1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None] # b hw_sigma 1 2 * b 1 #scale 2
return reference_points # b hw_sigma #scale 2
def forward(self,
src, # bt hw_sigma c
spatial_shapes,
level_start_index,
valid_ratios,
pos=None,
padding_mask=None,
video_aux_dict=None):
output = src # bt hw_sigma c
batch_size_nf = output.shape[0]
frame_query_feats = self.frame_query_feats.weight.unsqueeze(1).repeat(1,batch_size_nf, 1)
frame_query_poses = self.frame_query_poses.weight.unsqueeze(1).repeat(1,batch_size_nf,1) # n bt c
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
frame_feats = []
for _, layer in enumerate(self.layers):
output, frame_query_feats = layer(src=output,
pos=pos,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
padding_mask=padding_mask,
video_aux_dict=video_aux_dict,
frame_query_feats=frame_query_feats,
frame_query_poses=frame_query_poses)
frame_feats.append(frame_query_feats)
return output, frame_feats, frame_query_poses
import copy
from einops import rearrange
from models.layers.utils import _get_clones
from models.layers.position_encoding import build_position_encoding
# video multiscale, text_dict
@META_ARCH_REGISTRY.register()
class Video_Deform2D_DividedTemporal_MultiscaleEncoder_localGlobal(nn.Module):
def __init__(
self,
configs,
multiscale_shapes, # {'res2': .temporal_stride, .spatial_stride, .dim}
):
super().__init__()
d_model = configs['d_model']
fpn_norm = configs['fpn_norm'] # norm used in the FPN
nlayers = configs['nlayers']
# 4, 8, 16, 32
self.multiscale_shapes = dict(sorted(copy.deepcopy(multiscale_shapes).items(), key=lambda x: x[1].spatial_stride))
self.encoded_scales = sorted(configs['encoded_scales'],
key=lambda x:self.multiscale_shapes[x].spatial_stride) # res3, res4, res5
# 4 -> 8 -> 16 -> 32
self.scale_dims = [val.dim for val in self.multiscale_shapes.values()] # read from the stride-sorted copy so the order matches the comment above
self.video_projs = META_ARCH_REGISTRY.get(configs['video_projs']['name'])(configs=configs['video_projs'],
multiscale_shapes=multiscale_shapes, out_dim=d_model)
self.pos_2d = build_position_encoding(position_embedding_name='2d')
deform_attn = configs['deform_attn']
self.transformer = MSDeformAttnTransformerEncoderOnly(
d_model=d_model,
dropout=deform_attn['dropout'],
nhead=deform_attn['nheads'],
dim_feedforward=deform_attn['dim_feedforward'],
activation=deform_attn['activation'],
num_encoder_layers=nlayers,
num_feature_levels=len(self.encoded_scales),
enc_n_points=deform_attn['enc_n_points'],
add_local = configs['add_local'],
add_global = configs['add_global'],
local_configs = configs['local_configs'],
global_configs = configs['global_configs'],
frame_nqueries=configs['frame_nqueries']
)
min_encode_stride = self.multiscale_shapes[self.encoded_scales[0]].spatial_stride # 8
min_stride = list(self.multiscale_shapes.values())[0].spatial_stride # 4
self.num_fpn_levels = int(np.log2(min_encode_stride) - np.log2(min_stride))
lateral_convs = []
output_convs = []
use_bias = fpn_norm == ""
for idx, in_channels in enumerate(self.scale_dims[:self.num_fpn_levels]):
lateral_norm = get_norm(fpn_norm, d_model)
output_norm = get_norm(fpn_norm, d_model)
lateral_conv = Conv2d(in_channels, d_model, kernel_size=1, bias=use_bias, norm=lateral_norm)
output_conv = Conv2d(d_model, d_model, kernel_size=3, padding=1, bias=use_bias, norm=output_norm, activation=F.relu)
self.add_module("adapter_{}".format(idx + 1), lateral_conv)
self.add_module("layer_{}".format(idx + 1), output_conv)
lateral_convs.append(lateral_conv)
output_convs.append(output_conv)
# Place convs into top-down order (from low to high resolution)
# to make the top-down computation in forward clearer.
self.lateral_convs = lateral_convs[::-1] # 8 4
self.output_convs = output_convs[::-1] # 8 4
def forward(self,
multiscales=None, # b c t h w
video_aux_dict=None, # dict{}
**kwargs):
batch_size, _, nf = multiscales[list(multiscales.keys())[0]].shape[:3]
video_aux_dict['nf'] = nf
multiscales = self.video_projs(multiscales)
assert set(list(multiscales.keys())).issubset(set(list(self.multiscale_shapes.keys())))
assert set(list(self.multiscale_shapes.keys())).issubset(set(list(multiscales.keys())))
srcs = []
poses = [] # 32, 16, 8
for idx, scale_name in enumerate(self.encoded_scales[::-1]):
x = multiscales[scale_name].permute(0, 2, 1, 3, 4).flatten(0,1).contiguous() # bt c h w
srcs.append(x)
poses.append(self.pos_2d(torch.zeros_like(x)[:, 0, :, :].bool(), hidden_dim=x.shape[1]))
memory, spatial_shapes, level_start_index, frame_feats, frame_poses = self.transformer(srcs=srcs,
pos_embeds=poses,
video_aux_dict=video_aux_dict)
bs = memory.shape[0]
spatial_index = 0
memory_features = [] # 32 16 8
for lvl in range(len(self.encoded_scales)):
h, w = spatial_shapes[lvl]
memory_lvl = memory[:, spatial_index : spatial_index + h * w, :].reshape(bs, h, w, -1).permute(0, 3, 1, 2).contiguous()
memory_features.append(memory_lvl)
spatial_index += h * w
for idx, f in enumerate(list(self.multiscale_shapes.keys())[:self.num_fpn_levels][::-1]):
x = multiscales[f].permute(0, 2, 1, 3, 4).flatten(0,1).contiguous() # bt c h w
cur_fpn = self.lateral_convs[idx](x)
y = cur_fpn + F.interpolate(memory_features[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False)
y = self.output_convs[idx](y)
memory_features.append(y)
assert len(memory_features) == len(list(self.multiscale_shapes.keys()))
ret = {}
for key, out_feat in zip(list(self.multiscale_shapes.keys()), memory_features[::-1]):
ret[key] = rearrange(out_feat, '(b t) c h w -> b c t h w', b=batch_size, t=nf)
return ret, frame_feats[::-1], frame_poses # 32, 16, 8
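The encoder above concatenates the tokens of all scales into one sequence (`hw_sigma`) and later cuts it back into per-scale maps using `level_start_index`. The following pure-Python sketch shows the same bookkeeping on plain lists. The function names `level_start_indices` and `split_by_level` are hypothetical and only illustrate the indexing, not the tensor code itself.

```python
def level_start_indices(spatial_shapes):
    """Offset of each level's first token in the flattened sequence."""
    starts = []
    offset = 0
    for h, w in spatial_shapes:
        starts.append(offset)
        offset += h * w
    return starts


def split_by_level(tokens, spatial_shapes):
    """Cut a flattened token list back into one chunk per level."""
    chunks = []
    for start, (h, w) in zip(level_start_indices(spatial_shapes), spatial_shapes):
        chunks.append(tokens[start:start + h * w])
    return chunks


# Coarsest scale first, as in the encoder (strides 32, 16, 8).
shapes = [(2, 2), (4, 4), (8, 8)]
print(level_start_indices(shapes))
print([len(c) for c in split_by_level(list(range(84)), shapes)])
```

Each chunk of length `h * w` can then be reshaped to an `h x w` map, which is what the `reshape(bs, h, w, -1)` call in `forward` does per level.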
================================================
FILE: models/encoder/neighborhood_qk.py
================================================
from typing import Optional
import torch
from torch import nn, Tensor
from torch.nn.functional import pad
from torch.nn.init import trunc_normal_
from natten.functional import na2d_av, na2d_qk_with_bias
from einops import rearrange
from natten import NeighborhoodAttention2D
from detectron2.modeling import META_ARCH_REGISTRY
class NeighborhoodAttention2D_qk(nn.Module):
"""
Neighborhood Attention 2D Module
"""
def __init__(
self,
dim: int,
num_heads: int,
kernel_size: int,
dilation: int = 1,
bias: bool = True,
qkv_bias: bool = True,
qk_scale: Optional[float] = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // self.num_heads
self.scale = qk_scale or self.head_dim**-0.5
assert (
kernel_size > 1 and kernel_size % 2 == 1
), f"Kernel size must be an odd number greater than 1, got {kernel_size}."
self.kernel_size = kernel_size
assert (
dilation is None or dilation >= 1
), f"Dilation must be greater than or equal to 1, got {dilation}."
self.dilation = dilation or 1
self.window_size = self.kernel_size * self.dilation
self.q_linear = nn.Linear(dim, dim, bias=qkv_bias)
self.kv_linear = nn.Linear(dim, dim * 2, bias=qkv_bias)
if bias:
self.rpb = nn.Parameter(
torch.zeros(num_heads, (2 * kernel_size - 1), (2 * kernel_size - 1))
)
trunc_normal_(self.rpb, std=0.02, mean=0.0, a=-2.0, b=2.0)
else:
self.register_parameter("rpb", None)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self,
x_q: Tensor,
x_k: Tensor) -> Tensor:
# x_q: bt h w c; x_k: bt h w c (the previous frame)
if x_q.dim() != 4:
raise ValueError(
f"NeighborhoodAttention2D expected a rank-4 input tensor; got {x_q.dim()=}."
)
B, H, W, C = x_q.shape
# Pad if the input is smaller than the minimum supported size
H_padded, W_padded = H, W
padding_h = padding_w = 0
if H < self.window_size or W < self.window_size:
padding_h = max(0, self.window_size - H_padded)
padding_w = max(0, self.window_size - W_padded)
x_q = pad(x_q, (0, 0, 0, padding_w, 0, padding_h))
x_k = pad(x_k, (0, 0, 0, padding_w, 0, padding_h))
_, H_padded, W_padded, _ = x_q.shape
# b h w c -> b h w head c_h
q = self.q_linear(x_q).reshape(B, H_padded, W_padded, self.num_heads, self.head_dim)
q = q.permute(0, 3, 1, 2, 4) # b head h w c_h
kv = (
self.kv_linear(x_k)
.reshape(B, H_padded, W_padded, 2, self.num_heads, self.head_dim)
.permute(3, 0, 4, 1, 2, 5)
) # 2 b head h w c_h
k, v = kv[0], kv[1]
q = q * self.scale
attn = na2d_qk_with_bias(q, k, self.rpb, self.kernel_size, self.dilation)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x_q = na2d_av(attn, v, self.kernel_size, self.dilation) # b head h w c_h
x_q = x_q.permute(0, 2, 3, 1, 4).reshape(B, H_padded, W_padded, C) # b h w head c_h
# Remove padding, if added any
if padding_h or padding_w:
x_q = x_q[:, :H, :W, :].contiguous()
return self.proj_drop(self.proj(x_q))
def extra_repr(self) -> str:
return (
f"head_dim={self.head_dim}, num_heads={self.num_heads}, "
+ f"kernel_size={self.kernel_size}, dilation={self.dilation}, "
+ f"has_bias={self.rpb is not None}"
)
from models.layers.utils import _get_clones
class NA_qk_Layer(nn.Module):
def __init__(self, d_model, configs):
super().__init__()
self.self_attn = NeighborhoodAttention2D_qk(dim=configs['d_model'],
num_heads=configs['num_heads'],
kernel_size=configs['kernel_size'],
dilation=configs['dilation'],
bias=False,
qkv_bias=False,)
self.num_steps = configs['num_steps'] if 'num_steps' in configs else 1
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(configs['dropout'])
def forward(self, tgt=None, scale_shapes=None, level_start_idxs=None, nf=None):
# bt hw_sigma c -> list[b t h w c], 3
video_feats = [tgt[:, start_idx:(start_idx + haosen[0]*haosen[1])].contiguous() for start_idx, haosen in zip(level_start_idxs, scale_shapes)]
video_feats = [rearrange(haosen, '(b t) (h w) c -> b t h w c', t=nf, h=scale_shapes[idx][0], w=scale_shapes[idx][1]).contiguous() for idx, haosen in enumerate(video_feats)]
video_key_feats = []
for haosen in video_feats:
scale_feats = torch.stack([torch.roll(haosen, shifts=k, dims=1) for k in range(1, self.num_steps+1)], dim=0) # s b t h w c
video_key_feats.append(scale_feats.flatten(0, 2)) #sbt h w c
# sbt h w c
video_feats = [haosen.unsqueeze(0).repeat(self.num_steps, 1,1,1,1,1).flatten(0, 2) for haosen in video_feats]
local_feats = [] # list[sbt h w c]
for idx, (q_feat, k_feat) in enumerate(zip(video_feats, video_key_feats)):
local_feats.append(self.self_attn(q_feat, k_feat))
local_feats = [rearrange(haosen, '(s bt) h w c -> s bt h w c',s=self.num_steps) for haosen in local_feats]
local_feats = [haosen.sum(dim=0) for haosen in local_feats] # bt h w c
local_feats = torch.cat([haosen.flatten(1, 2) for haosen in local_feats], dim=1) # bt hw_sigma c
tgt = tgt + self.dropout(local_feats)
tgt = self.norm(tgt)
return tgt
@META_ARCH_REGISTRY.register()
class NA_qk_Layer_v2(nn.Module):
def __init__(self, configs):
super().__init__()
self.self_attn = NeighborhoodAttention2D_qk(dim=configs['d_model'],
num_heads=configs['num_heads'],
kernel_size=configs['kernel_size'],
dilation=configs['dilation'],
bias=False,
qkv_bias=False,)
def forward(self,
query=None,
spatial_shapes=None,
level_start_index=None,
video_aux_dict=None,):
# bt hw_sigma c -> list[b t h w c], 3
video_feat = [query[:, start_idx:(start_idx + haosen[0]*haosen[1])].contiguous() for start_idx, haosen in zip(level_start_index, spatial_shapes)]
video_feat = [rearrange(haosen, '(b t) (h w) c -> b t h w c',t=video_aux_dict['nf'], h=spatial_shapes[idx][0], w=spatial_shapes[idx][1]).contiguous() for idx, haosen in enumerate(video_feat)]
video_key_feats = [torch.roll(haosen, shifts=1, dims=1).contiguous() for haosen in video_feat]
local_feats = [] # list[bt h w c]
for idx, (q_feat, k_feat) in enumerate(zip(video_feat, video_key_feats)):
local_feats.append(self.self_attn(q_feat.flatten(0, 1), k_feat.flatten(0, 1)))
local_feats = torch.cat([haosen.flatten(1, 2) for haosen in local_feats], dim=1) # bt hw_sigma c
return local_feats, None
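The cyclic key selection in `NA_qk_Layer` and `NA_qk_Layer_v2` relies on `torch.roll` along the time axis. The sketch below reproduces only the index arithmetic in plain Python, assuming the usual roll semantics (`out[t] = in[(t - shift) % T]`). `cyclic_key_frames` is a hypothetical helper for illustration and is not part of the repository.

```python
def cyclic_key_frames(num_frames, num_steps=1):
    """For each roll step s in 1..num_steps, list the key-frame index paired with each query frame.

    Mirrors torch.roll(x, shifts=s, dims=1): position t receives frame (t - s) % num_frames.
    """
    return [
        [(t - s) % num_frames for t in range(num_frames)]
        for s in range(1, num_steps + 1)
    ]


if __name__ == "__main__":
    # With 4 frames and 2 steps, query frame 0 is paired with frames 3 and 2 (wrap-around).
    for step, keys in enumerate(cyclic_key_frames(4, num_steps=2), start=1):
        print(f"shift={step}: query frames [0, 1, 2, 3] -> key frames {keys}")
```

With `num_steps=1` this matches the single `shifts=1` roll in `NA_qk_Layer_v2`: every frame attends to its predecessor, and the first frame wraps around to the last.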
================================================
FILE: models/encoder/ops/MultiScaleDeformableAttention.egg-info/PKG-INFO
================================================
Metadata-Version: 2.1
Name: MultiScaleDeformableAttention
Version: 1.0
Summary: PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention
Home-page: https://github.com/fundamentalvision/Deformable-DETR
Author: Weijie Su
================================================
FILE: models/encoder/ops/attention.py
================================================
from inspect import isfunction
import math
import torch
from torch.nn.init import xavier_uniform_, constant_
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from .functions import MSDeformAttnFunction
import warnings
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
def forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n-1) == 0) and n != 0
class DeformAttn(nn.Module):
def __init__(self,
d_model=256,
nheads=8,
npoints=4,
nlevels=4,
key_dim=None):
super().__init__()
query_dim = d_model
head_dim = d_model // nheads
if d_model % nheads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, nheads))
# head_dim should preferably be a power of 2, which is more efficient in the CUDA implementation
if not _is_power_of_2(head_dim):
warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.")
self.im2col_step = 64
self.d_model = nheads * head_dim
key_dim = default(key_dim, query_dim)
self.value_proj = nn.Linear(key_dim, nheads * head_dim)
self.sampling_offsets = nn.Linear(query_dim, nheads * nlevels * npoints * 2)
self.attention_weights = nn.Linear(query_dim, nheads * nlevels * npoints)
self.output_proj = nn.Linear(nheads * head_dim, query_dim)
self.n_heads = nheads
self.n_levels = nlevels
self.head_dim = head_dim
self.n_points = npoints
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, query, reference_points,
input_flatten, input_spatial_shapes, input_level_start_index,input_padding_mask=None):
"""
multi-scale deformable attention, self attention if query == input_flatten
Input:
- query:
T(b n c)
- reference_points: center or reference boxes, normalized, [0, 1], including padding area,
add additional (w, h) to form reference boxes
T(b n level 2) or T(b n level 4)
- input_flatten: multi-scale features
T(b (h_\sigma w_\sigma) c)
- input_spatial_shapes: spatial size (h, w) of each level
T(level 2)
- input_level_start_index: [0, level1_start, level2_start]
- input_padding_mask: True/False
T(b (h_\sigma w_\sigma))
Output:
- query results:
T(b, n c)
- sampling_locations: normalized
T(b, n, m*l*k, 2)
- attention_weights: after softmax
T(b, n, m*l*k)
"""
batch_size, Nq, _ = query.shape
_, Nk, _ = input_flatten.shape
assert Nk == (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum()
# B (h w) M * V
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[...,None], float(0))
value = value.view(batch_size, Nk, self.n_heads, self.head_dim)
sampling_offsets = self.sampling_offsets(query).view(batch_size, Nq, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query).view(batch_size, Nq, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, dim=-1).view(batch_size, Nq, self.n_heads, self.n_levels, self.n_points)
# b, n, head, level, point, 2
if reference_points.shape[-1] == 2:
# T(level 2), ordered as (w, h)
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] + \
sampling_offsets / offset_normalizer[None, None, None, :, None,:]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] + \
sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
else:
raise NotImplementedError
output = MSDeformAttnFunction.apply(value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
return output, sampling_locations, attention_weights
class ContextuallSelfAttention(nn.Module):
def __init__(self,
d_model,
n_points,
n_heads,
context_dim=None):
super().__init__()
context_dim = default(context_dim, d_model)
query_dim = key_dim = d_model
if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
head_dim = d_model // n_heads
# head_dim should preferably be a power of 2, which is more efficient in the CUDA implementation
if not _is_power_of_2(head_dim):
warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.")
self.im2col_step = 64
self.d_model = d_model
self.nheads = n_heads
self.head_dim = head_dim
self.npoints = n_points
self.nlevels = 1
self.value_proj = nn.Linear(key_dim, n_heads * head_dim)
self.sampling_offsets = nn.Linear(query_dim, n_heads * n_points * 2)
self.attention_weights = nn.Linear(query_dim, n_heads * n_points)
self.output_proj = nn.Linear(n_heads * head_dim, query_dim)
def forward(self,
context, context_mask,
query, reference_points, query_padding_mask = None,):
"""
contextual deformable attention
Input:
- context:
T(b n c)
- context_mask:
T(b n)
- query:
T(b (h w) c)
- reference_points: center or reference boxes, normalized, [0, 1], including padding area,
T(b (h w) 2/4)
- query_padding_mask:
T(b (h w))
Output:
- query results:
T(b, n c)
- sampling_locations: normalized
T(b, n, m*l*k, 2)
- attention_weights: after softmax
T(b, n, m*l*k)
"""
key = query
key_padding_mask = query_padding_mask
batch_size, Nq, _ = query.shape
Nk = Nq
input_spatial_shapes = torch.tensor(query.shape[-2:]).unsqueeze(0) # T(1, 2)
input_level_start_index = [0, ]
# B (h w) M * V
value = self.value_proj(key)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], float(0))
value = value.view(batch_size, Nk, self.nheads, self.head_dim)
# Use the attribute names defined in __init__ (nheads, nlevels, npoints)
sampling_offsets = self.sampling_offsets(query).view(batch_size, Nq, self.nheads, self.nlevels, self.npoints, 2)
attention_weights = self.attention_weights(query).view(batch_size, Nq, self.nheads, self.nlevels * self.npoints)
attention_weights = F.softmax(attention_weights, dim=-1).view(batch_size, Nq, self.nheads, self.nlevels, self.npoints)
# b, n, head, level, point, 2
if reference_points.shape[-1] == 2:
# T(level 2), ordered as (w, h)
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] + \
sampling_offsets / offset_normalizer[None, None, None, :, None,:]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] + \
sampling_offsets / self.npoints * reference_points[:, :, None, :, None, 2:] * 0.5
else:
raise NotImplementedError
output = MSDeformAttnFunction.apply(value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
return output, sampling_locations, attention_weights
class BasicTransformerBlock_v2(nn.Module):
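def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True):
super().__init__()
# attn1 is used in forward() but was never defined; it is assumed here to be a CrossAttention
# over `context` (self-attention when context is None), mirroring BasicTransformerBlock.
self.attn1 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)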
def forward(self, x, context=None):
x = self.norm1(self.attn1(x, context=context) + x)
x = self.norm2(self.ff(x) + x)
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c')
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x + x_in
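As a rough illustration of the 2-d reference-point branch in `DeformAttn.forward`, the sketch below computes one sampling location in plain Python. It assumes reference points and offsets are ordered `(x, y)`, which is why the normalizer is `(W, H)` rather than `(H, W)`. `sampling_location` is a hypothetical helper, not a function from this repository.

```python
def sampling_location(ref_xy, offset_xy, level_hw):
    """Return the normalized (x, y) sampling location for one point on one level.

    ref_xy: reference point in [0, 1], ordered (x, y).
    offset_xy: predicted offset in feature-map pixels, ordered (x, y).
    level_hw: (H, W) of the level; the normalizer is (W, H) to match (x, y).
    """
    height, width = level_hw
    return (ref_xy[0] + offset_xy[0] / width, ref_xy[1] + offset_xy[1] / height)


if __name__ == "__main__":
    # A 2-pixel shift right and 1-pixel shift up on an 8x16 (H x W) level.
    print(sampling_location((0.5, 0.5), (2.0, -1.0), (8, 16)))
```

Dividing by the level size lets the same pixel offset cover a larger normalized distance on coarser levels.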
================================================
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-311/functions/__init__.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from .ms_deform_attn_func import MSDeformAttnFunction
================================================
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-311/functions/ms_deform_attn_func.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import MultiScaleDeformableAttention as MSDA
class MSDeformAttnFunction(Function):
@staticmethod
def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
ctx.im2col_step = im2col_step
output = MSDA.ms_deform_attn_forward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = \
MSDA.ms_deform_attn_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
# for debug and test only,
# need to use cuda version instead
N_, S_, M_, D_ = value.shape
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
# N_*M_, D_, Lq_, P_
sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
mode='bilinear', padding_mode='zeros', align_corners=False)
sampling_value_list.append(sampling_value_l_)
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
return output.transpose(1, 2).contiguous()
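The line `sampling_grids = 2 * sampling_locations - 1` in `ms_deform_attn_core_pytorch` converts normalized `[0, 1]` locations into the `[-1, 1]` range that `F.grid_sample` expects. A minimal plain-Python sketch of that mapping follows; `to_grid_sample_coords` is a hypothetical name used only for illustration.

```python
def to_grid_sample_coords(location):
    """Map a normalized (x, y) location from [0, 1] to the [-1, 1] range."""
    return tuple(2.0 * c - 1.0 for c in location)


if __name__ == "__main__":
    # The image centre (0.5, 0.5) maps to the grid origin (0.0, 0.0).
    print(to_grid_sample_coords((0.5, 0.5)))
    # The top-left corner (0.0, 0.0) maps to (-1.0, -1.0).
    print(to_grid_sample_coords((0.0, 0.0)))
```

Locations that fall outside `[-1, 1]` after this mapping contribute zeros, since the call uses `padding_mode='zeros'`.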
================================================
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-311/modules/__init__.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from .ms_deform_attn import MSDeformAttn
================================================
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-311/modules/ms_deform_attn.py
================================================
# Modify for sample points visualization
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import warnings
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_
from ..functions import MSDeformAttnFunction
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n-1) == 0) and n != 0
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
"""
Multi-Scale Deformable Attention Module
:param d_model hidden dimension
:param n_levels number of feature levels
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
_d_per_head = d_model // n_heads
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_head):
warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.")
self.im2col_step = 64
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
self.value_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
"""
:param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
:return output (N, Length_{query}, C)
"""
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
# N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
else:
raise ValueError(
'Last dim of reference_points must be 2 or 4, but got {} instead.'.format(reference_points.shape[-1]))
output = MSDeformAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
return output, sampling_locations, attention_weights
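`input_flatten` concatenates the tokens of all feature levels, and `input_level_start_index` records where each level begins. The sketch below shows one way to derive those offsets and split a flattened sequence back into levels. It assumes one start offset per level, as the `(n_levels, )` shape in the docstring suggests. Both function names are hypothetical and exist only for this example.

```python
from itertools import accumulate


def level_start_index(spatial_shapes):
    """Return the start offset of each level in the flattened token sequence."""
    sizes = [h * w for h, w in spatial_shapes]
    # Cumulative sizes give the end of each level; shift by one to get the starts.
    return [0] + list(accumulate(sizes))[:-1]


def split_levels(tokens, spatial_shapes):
    """Split a flattened token list into one list per level."""
    starts = level_start_index(spatial_shapes)
    return [tokens[s:s + h * w] for s, (h, w) in zip(starts, spatial_shapes)]


if __name__ == "__main__":
    shapes = [(4, 4), (2, 2), (1, 1)]
    print(level_start_index(shapes))
    print([len(level) for level in split_levels(list(range(21)), shapes)])
```

The same slicing pattern appears in `NA_qk_Layer.forward`, where each level is cut out with `start_idx:(start_idx + h * w)` before being reshaped to `b t h w c`.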
================================================
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-38/functions/__init__.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from .ms_deform_attn_func import MSDeformAttnFunction
================================================
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-38/functions/ms_deform_attn_func.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import MultiScaleDeformableAttention as MSDA
class MSDeformAttnFunction(Function):
@staticmethod
def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
ctx.im2col_step = im2col_step
output = MSDA.ms_deform_attn_forward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = \
MSDA.ms_deform_attn_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
# for debug and test only,
# need to use cuda version instead
N_, S_, M_, D_ = value.shape
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
# N_*M_, D_, Lq_, P_
sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
mode='bilinear', padding_mode='zeros', align_corners=False)
sampling_value_list.append(sampling_value_l_)
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
return output.transpose(1, 2).contiguous()
================================================
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-38/modules/__init__.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from .ms_deform_attn import MSDeformAttn
================================================
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-38/modules/ms_deform_attn.py
================================================
# Modified for sample-point visualization
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import warnings
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_
from ..functions import MSDeformAttnFunction
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n-1) == 0) and n != 0
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
"""
Multi-Scale Deformable Attention Module
:param d_model hidden dimension
:param n_levels number of feature levels
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
_d_per_head = d_model // n_heads
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_head):
warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.")
self.im2col_step = 64
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
self.value_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
"""
:param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-2}*W_{L-2}]
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
:return output (N, Length_{query}, C)
"""
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
# N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
else:
raise ValueError(
'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
output = MSDeformAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
return output, sampling_locations, attention_weights
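# Illustrative sketch, not part of the original module: how the per-level start
# offsets and the normalised sampling locations used in `forward` can be computed
# for the 2-D reference-point case. The helper names `level_start_index` and
# `sampling_location` are hypothetical; plain Python lists and tuples stand in for tensors.

```python
from itertools import accumulate


def level_start_index(spatial_shapes):
    """Start offset of each level along the flattened sum(H_l * W_l) axis."""
    sizes = [h * w for h, w in spatial_shapes]
    return [0] + list(accumulate(sizes))[:-1]


def sampling_location(reference_point, offset, spatial_shape):
    """Normalised (x, y) location: reference + offset / (W, H)."""
    h, w = spatial_shape
    ref_x, ref_y = reference_point
    off_x, off_y = offset
    return (ref_x + off_x / w, ref_y + off_y / h)


shapes = [(32, 32), (16, 16), (8, 8)]
starts = level_start_index(shapes)
loc = sampling_location((0.5, 0.5), (2.0, -1.0), (16, 16))
print(starts, loc)
```

# Note that the offset is divided by (W, H), not (H, W), because locations are stored as (x, y).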
================================================
FILE: models/encoder/ops/build/temp.linux-x86_64-cpython-311/.ninja_log
================================================
# ninja log v5
0 5344 1685604027 /home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xhh/workspace/rvos_encoder/models/ops/src/cpu/ms_deform_attn_cpu.o 1eaabdd4515aceab
1 20910 1685604042 /home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xhh/workspace/rvos_encoder/models/ops/src/vision.o b8641c4a4f7766f9
0 21063 1685604042 /home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xhh/workspace/rvos_encoder/models/ops/src/cuda/ms_deform_attn_cuda.o d77fcd8ae1c377bb
================================================
FILE: models/encoder/ops/build/temp.linux-x86_64-cpython-311/build.ninja
================================================
ninja_required_version = 1.3
cxx = c++
nvcc = /usr/local/cuda/bin/nvcc
cflags = -pthread -B /home/xhh/anaconda3/envs/natten/compiler_compat -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /home/xhh/anaconda3/envs/natten/include -fPIC -O2 -isystem /home/xhh/anaconda3/envs/natten/include -fPIC -DWITH_CUDA -I/home/xhh/workspace/rvos_encoder/models/ops/src -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include/TH -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/xhh/anaconda3/envs/natten/include/python3.11 -c
post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=MultiScaleDeformableAttention -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17
cuda_cflags = -DWITH_CUDA -I/home/xhh/workspace/rvos_encoder/models/ops/src -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include/TH -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/xhh/anaconda3/envs/natten/include/python3.11 -c
cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=MultiScaleDeformableAttention -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++17
cuda_dlink_post_cflags =
ldflags =
rule compile
command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
depfile = $out.d
deps = gcc
rule cuda_compile
depfile = $out.d
deps = gcc
command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
build /home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xhh/workspace/rvos_encoder/models/ops/src/cpu/ms_deform_attn_cpu.o: compile /home/xhh/workspace/rvos_encoder/models/ops/src/cpu/ms_deform_attn_cpu.cpp
build /home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xhh/workspace/rvos_encoder/models/ops/src/cuda/ms_deform_attn_cuda.o: cuda_compile /home/xhh/workspace/rvos_encoder/models/ops/src/cuda/ms_deform_attn_cuda.cu
build /home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xhh/workspace/rvos_encoder/models/ops/src/vision.o: compile /home/xhh/workspace/rvos_encoder/models/ops/src/vision.cpp
================================================
FILE: models/encoder/ops/build/temp.linux-x86_64-cpython-38/home/xhh/workspace/ReferFormer/models/ops/src/vision.o
================================================
[File too large to display: 12.3 MB]
================================================
FILE: models/encoder/ops/functions/__init__.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from .ms_deform_attn_func import MSDeformAttnFunction
================================================
FILE: models/encoder/ops/functions/ms_deform_attn_func.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import MultiScaleDeformableAttention as MSDA
class MSDeformAttnFunction(Function):
    @staticmethod
    def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = \
            MSDA.ms_deform_attn_backward(
                value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
        # One gradient slot per forward input; shapes, start indices and im2col_step get None
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    # Pure-PyTorch reference implementation, for debugging and testing only;
    # use the CUDA version (MSDeformAttnFunction) for training and inference.
    N_, S_, M_, D_ = value.shape
    _, Lq_, M_, L_, P_, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    # Map locations from [0, 1] to the [-1, 1] range expected by F.grid_sample
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
                                          mode='bilinear', padding_mode='zeros', align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
    # Weighted sum over all L_*P_ sampled values, then merge the heads back into channels
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
    return output.transpose(1, 2).contiguous()
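# Illustrative sketch, not part of the original module: what a single bilinear
# lookup in `ms_deform_attn_core_pytorch` amounts to, written with plain Python
# lists. It assumes the same settings as the call above (bilinear mode, zero
# padding, align_corners=False). The helper name `bilinear_sample` is hypothetical.

```python
import math


def bilinear_sample(image, loc_x, loc_y):
    """Sample a 2-D list `image` at a normalised location in [0, 1] x [0, 1]."""
    height, width = len(image), len(image[0])
    # [0, 1] -> [-1, 1], as in `2 * sampling_locations - 1`
    grid_x, grid_y = 2.0 * loc_x - 1.0, 2.0 * loc_y - 1.0
    # align_corners=False: -1 and +1 are the outer edges of the border pixels
    px = ((grid_x + 1.0) * width - 1.0) / 2.0
    py = ((grid_y + 1.0) * height - 1.0) / 2.0
    x0, y0 = math.floor(px), math.floor(py)
    total = 0.0
    for yy, wy in ((y0, 1.0 - (py - y0)), (y0 + 1, py - y0)):
        for xx, wx in ((x0, 1.0 - (px - x0)), (x0 + 1, px - x0)):
            # padding_mode='zeros': neighbours outside the image contribute nothing
            if 0 <= xx < width and 0 <= yy < height:
                total += wx * wy * image[yy][xx]
    return total


image = [[1.0, 2.0],
         [3.0, 4.0]]
centre = bilinear_sample(image, 0.5, 0.5)    # midpoint of all four pixels
corner = bilinear_sample(image, 0.0, 0.0)    # half a pixel outside pixel (0, 0)
print(centre, corner)
```

# The attention output for one query and head is then the sum of such samples,
# each multiplied by its softmax-normalised attention weight.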
================================================
FILE: models/encoder/ops/make.sh
================================================
#!/usr/bin/env bash
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
python setup.py build install
================================================
FILE: models/encoder/ops/modules/__init__.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from .ms_deform_attn import MSDeformAttn
from . import frame_query_ss2d
================================================
FILE: models/encoder/ops/modules/frame_query_ss2d.py
================================================
from models.layers.position_encoding import build_position_encoding
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
import math
import warnings
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_
from ..functions import MSDeformAttnFunction
from mamba_ssm import Mamba
from einops import rearrange, reduce, repeat
from detectron2.modeling import META_ARCH_REGISTRY
# v1
class SS2D(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        # d_state="auto", # 20240109
        d_conv=3,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        # Depthwise convolution (groups == channels)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()
        # One projection per scan direction (K=4); only the stacked weights are kept
        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
        )
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K=4, dt_rank + 2 * d_state, inner)
        del self.x_proj
        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
        )
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K=4, inner)
        del self.dt_projs
        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True)  # (K * D, N) with K=4
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True)  # (K * D,) with K=4
        # self.selective_scan = selective_scan_fn
        self.forward_core = self.forward_corev0
        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True
        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    def forward_corev0(self, x: torch.Tensor):
        self.selective_scan = selective_scan_fn
        B, C, H, W = x.shape
        L = H * W
        K = 4
        # Scan directions: row-major, column-major, and the reverse of each
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)
        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
        xs = xs.float().view(B, -1, L)  # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1)  # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)
        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float
        # Undo the flips and transposes so that all four outputs are in row-major order
        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward(self, x: torch.Tensor, **kwargs):
        B, H, W, C = x.shape
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x))  # (b, d, h, w)
        y1, y2, y3, y4 = self.forward_core(x)  # each (b, d, h * w)
        assert y1.dtype == torch.float32
        y = y1 + y2 + y3 + y4
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)
        y = y * F.silu(z)
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out
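# Illustrative sketch, not part of the original module: the four scan orders that
# `SS2D.forward_corev0` builds with view/transpose/flip, expressed as index lists
# over a row-major flattened H x W grid. The helper names `scan_orders` and
# `restore` are hypothetical.

```python
def scan_orders(height, width):
    """Row-major, column-major, and the reverse of each, as flat indices."""
    row_major = list(range(height * width))
    col_major = [h * width + w for w in range(width) for h in range(height)]
    return [row_major, col_major, row_major[::-1], col_major[::-1]]


def restore(sequence, order):
    """Put a scanned sequence back into row-major order."""
    restored = [None] * len(order)
    for position, index in enumerate(order):
        restored[index] = sequence[position]
    return restored


orders = scan_orders(2, 3)
tokens = ["a", "b", "c", "d", "e", "f"]           # row-major 2 x 3 grid
scanned = [[tokens[i] for i in order] for order in orders]
round_trip = [restore(seq, order) for seq, order in zip(scanned, orders)]
print(orders[1], scanned[1])
```

# In the module, the four scan outputs are restored to row-major order in the
# same way before they are summed.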
class SS2D_FrameQuery(nn.Module):
    def __init__(self, configs,):
        super().__init__()
        d_model = configs['d_model']
        self.homo = SS2D(d_model=configs['d_model'],
                         d_state=configs['d_state'] if 'd_state' in configs else 16,
                         d_conv=configs['d_conv'] if 'd_conv' in configs else 3,
                         expand=configs['expand'] if 'expand' in configs else 2,
                         dt_rank=configs['dt_rank'] if 'dt_rank' in configs else 'auto',
                         dt_min=configs['dt_min'] if 'dt_min' in configs else 0.001,
                         dt_max=configs['dt_max'] if 'dt_max' in configs else 0.1,
                         dt_init=configs['dt_init'] if 'dt_init' in configs else 'random',
                         dt_scale=configs['dt_scale'] if 'dt_scale' in configs else 1.0,
                         dt_init_floor=configs['dt_init_floor'] if 'dt_init_floor' in configs else 1e-4,
                         dropout=configs['dropout'] if 'dropout' in configs else 0,
                         conv_bias=configs['conv_bias'] if 'conv_bias' in configs else True,
                         bias=configs['bias'] if 'bias' in configs else False,
                         )
        self.pos_1d = build_position_encoding(position_embedding_name='1d')  # position embedding over t (frames)

    def forward(self,
                frame_query_feats=None,  # n bt c
                frame_query_poses=None,  # n bt c; position embedding over nq (queries)
                nf=None,
                **kwargs
                ):
        batch_size = frame_query_feats.shape[1] // nf  # b
        # NOTE: in-place add, so the caller's frame_query_feats tensor is modified as well
        frame_query_feats += frame_query_poses
        frame_query_feats = rearrange(frame_query_feats, 'n (b t) c -> b t n c', b=batch_size, t=nf).contiguous()
        sin_poses = self.pos_1d(torch.zeros_like(frame_query_feats[..., 0].permute(0, 2, 1).flatten(0, 1)).bool(),
                                hidden_dim=frame_query_feats.shape[-1])  # bn c t
        sin_poses = rearrange(sin_poses, '(b n) c t -> b t n c', b=batch_size)
        frame_query_feats += sin_poses
        # The (t, n) grid of frame queries is treated as the 2-D "image" for SS2D
        frame_query_feats = self.homo(frame_query_feats)  # b t n c
        frame_query_feats = frame_query_feats.permute(2, 0, 1, 3).flatten(1, 2).contiguous()  # n bt c
        return frame_query_feats, None
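# Illustrative sketch, not part of the original module: the index bookkeeping
# behind `rearrange(..., 'n (b t) c -> b t n c')` and the inverse
# `permute(2, 0, 1, 3).flatten(1, 2)` used in `SS2D_FrameQuery.forward`. Nested
# Python lists stand in for tensors and each element stands for one c-dimensional
# vector. The helper names are hypothetical.

```python
def split_frame_axis(feats, batch_size, num_frames):
    """feats[n][b * T + t] -> grid[b][t][n]."""
    num_queries = len(feats)
    return [[[feats[n][b * num_frames + t] for n in range(num_queries)]
             for t in range(num_frames)]
            for b in range(batch_size)]


def merge_frame_axis(grid):
    """grid[b][t][n] -> feats[n][b * T + t]."""
    batch_size, num_frames, num_queries = len(grid), len(grid[0]), len(grid[0][0])
    return [[grid[b][t][n] for b in range(batch_size) for t in range(num_frames)]
            for n in range(num_queries)]


# 2 queries, batch of 2 videos, 3 frames each: the merged axis has length 6
feats = [[f"q{n}_b{bt // 3}_t{bt % 3}" for bt in range(6)] for n in range(2)]
grid = split_frame_axis(feats, batch_size=2, num_frames=3)
print(grid[1][2][0])
```

# The merged axis is batch-major: frame t of video b sits at index b * T + t.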
@META_ARCH_REGISTRY.register()
class FrameQuery_SS2DLayer(nn.Module):
    def __init__(self,
                 configs,
                 dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        d_model = configs['d_model']
        dropout = configs['dropout']  # overrides the `dropout` argument
        self.self_attn = SS2D_FrameQuery(configs)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        from models.layers.decoder_layers import FFNLayer
        self.ffn = FFNLayer(d_model=d_model,
                            dim_feedforward=configs['dim_feedforward'],
                            dropout=dropout,)

    def forward(self,
                frame_query_feats,  # n bt c
                frame_query_poses,  # n bt c; position embedding over nq (queries)
                nf=None,
                **kwargs):
        tgt2 = self.self_attn(frame_query_feats=frame_query_feats,  # n bt c
                              frame_query_poses=frame_query_poses,  # n bt c; position embedding over nq (queries)
                              nf=nf,)[0]
        # Residual connection, post-norm, then feed-forward
        frame_query_feats = frame_query_feats + self.dropout(tgt2)
        frame_query_feats = self.norm(frame_query_feats)
        frame_query_feats = self.ffn(frame_query_feats)
        return frame_query_feats
from models.layers.decoder_layers import CrossAttentionLayer, SelfAttentionLayer, FFNLayer


@META_ARCH_REGISTRY.register()
class TemporalQuery_CrossSelf(nn.Module):
    def __init__(self, configs) -> None:
        super().__init__()
        d_model = configs['d_model']
        attn_configs = configs['attn']
        self.cross_layers = CrossAttentionLayer(d_model=d_model,
                                                nhead=attn_configs['nheads'],
                                                dropout=0.0,
                                                normalize_before=attn_configs['normalize_before'])
        self.self_layers = SelfAttentionLayer(d_model=d_model,
                                              nhead=attn_configs['nheads'],
                                              dropout=0.0,
                                              normalize_before=attn_configs['normalize_before'])
        self.ffn_layers = FFNLayer(d_model=d_model,
                                   dim_feedforward=attn_configs['dim_feedforward'],
                                   dropout=0.0,
                                   normalize_before=attn_configs['normalize_before'])

    def forward(self,
                temporal_query_feats,
                temporal_query_poses,
                frame_query_feats, frame_query_poses,
                video_aux_dict=None, **kwargs):
        # temporal queries: nq b c; frame queries: nq bt c
        nq, batch_size, _ = temporal_query_feats.shape
        nf = frame_query_feats.shape[1] // batch_size
        nqf = frame_query_feats.shape[0]
        # Gather the frame queries of all frames of each video into one memory sequence
        frame_query_feats = rearrange(frame_query_feats, 'nq (b t) c -> (t nq) b c', b=batch_size, t=nf)
        frame_query_poses = rearrange(frame_query_poses, 'nq (b t) c -> (t nq) b c', b=batch_size, t=nf)
        temporal_query_feats = self.cross_layers(
            tgt=temporal_query_feats,  # n b c
            memory=frame_query_feats,  # (t nqf) b c
            pos=frame_query_poses,
            query_pos=temporal_query_poses,
        )
        temporal_query_feats = self.self_layers(
            temporal_query_feats,
            query_pos=temporal_query_poses,
        )
        temporal_query_feats = self.ffn_layers(
            temporal_query_feats
        )
        return temporal_query_feats
# v2: multi-layer
class SS2D_FrameQuery_v2(nn.Module):
    def __init__(self, configs,):
        super().__init__()
        d_model = configs['d_model']
        self.homo = SS2D(d_model=configs['d_model'],
                         d_state=configs['d_state'] if 'd_state' in configs else 16,
                         d_conv=configs['d_conv'] if 'd_conv' in configs else 3,
                         expand=configs['expand'] if 'expand' in configs else 2,
                         dt_rank=configs['dt_rank'] if 'dt_rank' in configs else 'auto',
                         dt_min=configs['dt_min'] if 'dt_min' in configs else 0.001,
                         dt_max=configs['dt_max'] if 'dt_max' in configs else 0.1,
                         dt_init=configs['dt_init'] if 'dt_init' in configs else 'random',
                         dt_scale=configs['dt_scale'] if 'dt_scale' in configs else 1.0,
                         dt_init_floor=configs['dt_init_floor'] if 'dt_init_floor' in configs else 1e-4,
                         dropout=configs['dropout'] if 'dropout' in configs else 0,
                         conv_bias=configs['conv_bias'] if 'conv_bias' in configs else True,
                         bias=configs['bias'] if 'bias' in configs else False,
                         )
        self.pos_1d = build_position_encoding(position_embedding_name='1d')  # position embedding over t (frames)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(configs['dropout'])

    def forward(self,
                frame_query_feats=None,  # n bt c
                frame_query_poses=None,  # n bt c; position embedding over nq (queries)
                nf=None,
                **kwargs
                ):
        batch_size = frame_query_feats.shape[1] // nf  # b
        # Out-of-place add: unlike v1, the input frame_query_feats is left untouched
        tgt2 = frame_query_feats + frame_query_poses
        tgt2 = rearrange(tgt2, 'n (b t) c -> b t n c', b=batch_size, t=nf).contiguous()
        sin_poses = self.pos_1d(torch.zeros_like(tgt2[..., 0].permute(0, 2, 1).flatten(0, 1)).bool(),
                                hidden_dim=tgt2.shape[-1])  # bn c t
        sin_poses = rearrange(sin_poses, '(b n) c t -> b t n c', b=batch_size)
        tgt2 += sin_poses
        tgt2 = self.homo(tgt2)  # b t n c
        tgt2 = tgt2.permute(2, 0, 1, 3).flatten(1, 2).contiguous()  # n bt c
        # Residual connection and post-norm
        frame_query_feats = frame_query_feats + self.dropout(tgt2)
        frame_query_feats = self.norm(frame_query_feats)
        return frame_query_feats, None
from models.layers.utils import _get_clones


@META_ARCH_REGISTRY.register()
class FrameQuery_SS2DLayer_v2(nn.Module):
    def __init__(self,
                 configs,
                 dropout=0.0):
        super().__init__()
        d_model = configs['d_model']
        n_layers = configs['nlayers'] if 'nlayers' in configs else 1
        self.nlayers = n_layers
        self.self_attn = _get_clones(SS2D_FrameQuery_v2(configs), n_layers)
        from models.layers.decoder_layers import FFNLayer
        self.ffn = FFNLayer(d_model=d_model,
                            dim_feedforward=configs['dim_feedforward'],
                            dropout=configs['dropout'],)

    def forward(self,
                frame_query_feats,  # n bt c
                frame_query_poses,  # n bt c; position embedding over nq (queries)
                nf=None,
                **kwargs):
        # Apply the stacked SS2D layers, then a single feed-forward layer
        for i in range(self.nlayers):
            frame_query_feats = self.self_attn[i](frame_query_feats=frame_query_feats,  # n bt c
                                                  frame_query_poses=frame_query_poses,  # n bt c; position embedding over nq (queries)
                                                  nf=nf,)[0]
        frame_query_feats = self.ffn(frame_query_feats)
        return frame_query_feats
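# Illustrative sketch, not part of the original module: `Hilbert_2DSelectiveScan`
# receives a `hilbert_curve` index tensor, but how the repository builds it is not
# shown in this file. One standard construction, assuming a square grid whose side
# is a power of two, is the iterative distance-to-coordinate mapping below. The
# function names are hypothetical.

```python
def hilbert_d2xy(side, distance):
    """Map a distance along the Hilbert curve to (x, y) on a side x side grid."""
    x = y = 0
    t = distance
    s = 1
    while s < side:
        rx = 1 & (t // 2)
        ry = 1 & (t ^ rx)
        if ry == 0:
            # Rotate/flip the quadrant so the sub-curves connect end to end
            if rx == 1:
                x = s - 1 - x
                y = s - 1 - y
            x, y = y, x
        x += s * rx
        y += s * ry
        t //= 4
        s *= 2
    return x, y


def hilbert_order(side):
    """Row-major flat indices of a side x side grid, visited in Hilbert order."""
    if side < 1 or side & (side - 1) != 0:
        raise ValueError("side must be a power of two")
    order = []
    for distance in range(side * side):
        x, y = hilbert_d2xy(side, distance)
        order.append(y * side + x)
    return order


order = hilbert_order(4)
print(order)
```

# An index list like this can be turned into a LongTensor and used with
# `index_select` on the flattened spatial axis. Consecutive entries are always
# neighbouring grid cells, which is the property that motivates a Hilbert scan
# over a plain row-major one.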
class Hilbert_2DSelectiveScan(nn.Module):
def __init__(
self,
d_model,
d_state=16,
# d_state="auto", # 20240109
d_conv=3,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
dropout=0.,
conv_bias=True,
bias=False,
device=None,
dtype=None,
scan_order=None,
**kwargs,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
# self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
self.conv2d = nn.Conv2d(
in_channels=self.d_inner,
out_channels=self.d_inner,
groups=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
padding=(d_conv - 1) // 2,
**factory_kwargs,
)
self.act = nn.SiLU()
self.x_proj = (
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
)
self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=2, dt_rank + 2 * d_state, inner)
del self.x_proj
self.dt_projs = (
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
)
self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=2, inner, rank)
self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=2, inner)
del self.dt_projs
self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=2, merge=True) # (K * D, N) with K=2
self.Ds = self.D_init(self.d_inner, copies=2, merge=True) # (K * D,) with K=2
# self.selective_scan = selective_scan_fn
self.forward_core = self.forward_corev0
self.out_norm = nn.LayerNorm(self.d_inner)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
self.dropout = nn.Dropout(dropout) if dropout > 0. else None
self.scan_order = scan_order
@staticmethod
def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
# Initialize special dt projection to preserve variance at initialization
dt_init_std = dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
dt_proj.bias._no_reinit = True
return dt_proj
@staticmethod
def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
# S4D real initialization
A = repeat(
torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
if copies > 1:
A_log = repeat(A_log, "d n -> r d n", r=copies)
if merge:
A_log = A_log.flatten(0, 1)
A_log = nn.Parameter(A_log)
A_log._no_weight_decay = True
return A_log
@staticmethod
def D_init(d_inner, copies=1, device=None, merge=True):
# D "skip" parameter
D = torch.ones(d_inner, device=device)
if copies > 1:
D = repeat(D, "n1 -> r n1", r=copies)
if merge:
D = D.flatten(0, 1)
D = nn.Parameter(D) # Keep in fp32
D._no_weight_decay = True
return D
def forward_corev0(self, x: torch.Tensor, hilbert_curve):
# hilbert_curve: LongTensor of Hilbert-curve indices into the hw-flattened positions
self.selective_scan = selective_scan_fn
B, C, H, W = x.shape
L = H * W
K = 2
if self.scan_order == 'zigzag':
x_hw = x.view(B, -1, L).contiguous() # b c hw
xs = torch.stack([x_hw, torch.flip(x_hw, dims=[-1])], dim=1) # (b, k, d, l)
elif self.scan_order == 'hilbert':
x_hw = x.flatten(2).contiguous() # b c hw
x_hil = x_hw.index_select(dim=-1, index=hilbert_curve)
xs = torch.stack([x_hil, torch.flip(x_hil, dims=[-1])], dim=1) # (b, k, d, l)
else:
raise ValueError(f"unsupported scan_order: {self.scan_order!r}")
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
# x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
# dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
xs = xs.float().view(B, -1, L) # (b, k * d, l)
dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
Ds = self.Ds.float().view(-1) # (k * d)
As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
out_y = self.selective_scan(
xs, dts,
As, Bs, Cs, Ds, z=None,
delta_bias=dt_projs_bias,
delta_softplus=True,
return_last_state=False,
).view(B, K, -1, L)
assert out_y.dtype == torch.float
if self.scan_order == 'zigzag':
hw_order = out_y[:, 0].contiguous().view(B, -1, H, W).contiguous()
rhw_order = torch.flip(out_y[:, 1].contiguous(), dims=[-1]).contiguous()
rhw_order = rhw_order.view(B, -1, H, W,).contiguous()
return hw_order + rhw_order
elif self.scan_order == 'hilbert':
hil_order = out_y[:, 0].contiguous() # b c hw
rhil_order = torch.flip(out_y[:, 1].contiguous(), dims=[-1]).contiguous() # b c hw
sum_out = torch.zeros_like(hil_order)
hilbert_curve = repeat(hilbert_curve, 'hw -> b c hw', b=hil_order.shape[0], c=hil_order.shape[1])
assert hil_order.shape == hilbert_curve.shape
sum_out.scatter_add_(dim=-1, index=hilbert_curve, src=hil_order)
sum_out.scatter_add_(dim=-1, index=hilbert_curve, src=rhil_order)
sum_out = sum_out.view(B, -1, H, W).contiguous()
return sum_out
# def forward_corev0(self, x: torch.Tensor, hilbert_curve):
# # hilbert_curve: LongTensor of Hilbert-curve indices into the flattened positions
# self.selective_scan = selective_scan_fn
# B, C, H, W, T = x.shape
# L = H * W * T
# K = 2
# if self.scan_order == 'zigzag':
# x_hw = x.view(B, -1, L).contiguous() # b c hwt
# xs = torch.stack([x_hw, torch.flip(x_hw, dims=[-1])], dim=1) # (b, k, d, l)
# elif self.scan_order == 'hilbert':
# x_hw = x.flatten(2).contiguous() # b c hwt
# x_hil = x_hw.index_select(dim=-1, index=hilbert_curve)
# xs = torch.stack([x_hil, torch.flip(x_hil, dims=[-1])], dim=1) # (b, k, d, l)
# else:
# raise ValueError()
# x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
# # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
# dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
# dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
# # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
# xs = xs.float().view(B, -1, L) # (b, k * d, l)
# dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
# Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
# Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
# Ds = self.Ds.float().view(-1) # (k * d)
# As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
# dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
# out_y = self.selective_scan(
# xs, dts,
# As, Bs, Cs, Ds, z=None,
# delta_bias=dt_projs_bias,
# delta_softplus=True,
# return_last_state=False,
# ).view(B, K, -1, L)
# assert out_y.dtype == torch.float
# if self.scan_order == 'zigzag':
# hw_order = out_y[:, 0].contiguous().view(B, -1, H, W).contiguous()
# rhw_order = torch.flip(out_y[:, 1].contiguous(), dims=[-1]).contiguous()
# rhw_order = rhw_order.view(B, -1, H, W,).contiguous()
# return hw_order + rhw_order
# elif self.scan_order == 'hilbert':
# hil_order = out_y[:, 0].contiguous() # b c hw
# rhil_order = torch.flip(out_y[:, 1].contiguous(), dims=[-1]).contiguous() # b c hw
# sum_out = torch.zeros_like(hil_order)
# hilbert_curve = repeat(hilbert_curve, 'hwt -> b c hwt', b=hil_order.shape[0], c=hil_order.shape[1])
# assert hil_order.shape == hilbert_curve.shape
# sum_out.scatter_add_(dim=-1, index=hilbert_curve, src=hil_order)
# sum_out.scatter_add_(dim=-1, index=hilbert_curve, src=rhil_order)
# sum_out = sum_out.view(B, -1, H, W).contiguous()
# return sum_out
def forward(self, x: torch.Tensor, hilbert_curve, **kwargs):
B, H, W, C = x.shape
xz = self.in_proj(x)
x, z = xz.chunk(2, dim=-1) # (b, h, w, d)
x = x.permute(0, 3, 1, 2).contiguous()
x = self.act(self.conv2d(x)) # (b, d, h, w)
y = self.forward_core(x, hilbert_curve=hilbert_curve) # B C h w
y = y.permute(0, 2, 3, 1).contiguous() # b h w c
y = self.out_norm(y)
y = y * F.silu(z)
out = self.out_proj(y)
if self.dropout is not None:
out = self.dropout(out)
return out
class SS2D_FrameQuery_hilbert(nn.Module):
def __init__(self, configs,):
super().__init__()
d_model = configs['d_model']
self.homo = Hilbert_2DSelectiveScan(d_model=configs['d_model'],
d_state=configs['d_state'] if 'd_state' in configs else 16,
d_conv=configs['d_conv'] if 'd_conv' in configs else 3,
expand=configs['expand'] if 'expand' in configs else 2,
dt_rank=configs['dt_rank'] if 'dt_rank' in configs else 'auto',
dt_min=configs['dt_min'] if 'dt_min' in configs else 0.001,
dt_max=configs['dt_max'] if 'dt_max' in configs else 0.1,
dt_init=configs['dt_init'] if 'dt_init' in configs else 'random',
dt_scale=configs['dt_scale'] if 'dt_scale' in configs else 1.0,
dt_init_floor=configs['dt_init_floor'] if 'dt_init_floor' in configs else 1e-4,
dropout=configs['dropout'] if 'dropout' in configs else 0,
conv_bias=configs['conv_bias'] if 'conv_bias' in configs else True,
bias=configs['bias'] if 'bias' in configs else False,
scan_order=configs['scan_order']
)
self.pos_1d = build_position_encoding(position_embedding_name='1d') # position embedding along the temporal (t) axis
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(configs['dropout'] if 'dropout' in configs else 0) # same default as the scan module above
def forward(self,
frame_query_feats=None, # n bt c
frame_query_poses=None, # n bt c; position embedding over the query (nq) axis
hilbert_curve=None,
nf=None,
**kwargs
):
batch_size = frame_query_feats.shape[1] // nf # b
tgt2 = frame_query_feats + frame_query_poses
tgt2 = rearrange(tgt2, 'n (b t) c -> b t n c',b=batch_size,t=nf).contiguous()
sin_poses = self.pos_1d(torch.zeros_like(tgt2[..., 0].permute(0, 2, 1).flatten(0, 1)).bool(),
hidden_dim=tgt2.shape[-1]) # bn c t
sin_poses = rearrange(sin_poses, '(b n) c t -> b t n c', b=batch_size)
tgt2 += sin_poses
tgt2 = self.homo(tgt2, hilbert_curve=hilbert_curve) # b t n c
tgt2 = tgt2.permute(2, 0, 1, 3).flatten(1, 2).contiguous() # n bt c
frame_query_feats = frame_query_feats + self.dropout(tgt2)
frame_query_feats = self.norm(frame_query_feats)
return frame_query_feats, None
from models.layers.utils import _get_clones
@META_ARCH_REGISTRY.register()
class FrameQuery_SS2DLayer_hilbert(nn.Module):
def __init__(self,
configs,
dropout=0.0):
super().__init__()
d_model = configs['d_model']
n_layers = configs['nlayers'] if 'nlayers' in configs else 1
self.nlayers = n_layers
self.self_attn = _get_clones(SS2D_FrameQuery_hilbert(configs), n_layers)
from models.layers.decoder_layers import FFNLayer
self.ffn = FFNLayer(d_model=d_model,
dim_feedforward=configs['dim_feedforward'],
dropout=configs['dropout'],)
def forward(self,
frame_query_feats, # n bt c
frame_query_poses, # n bt c; position embedding over the query (nq) axis
video_aux_dict=None,
**kwargs):
for i in range(self.nlayers):
frame_query_feats = self.self_attn[i](frame_query_feats=frame_query_feats, # n bt c
frame_query_poses=frame_query_poses, # n bt c; position embedding over the query (nq) axis
nf=video_aux_dict['nf'],
hilbert_curve=video_aux_dict['hilbert_curve'])[0]
frame_query_feats = self.ffn(frame_query_feats)
return frame_query_feats
================================================
FILE: models/encoder/ops/modules/ms_deform_attn.py
================================================
# Modified to also return sampling locations and attention weights for sample-point visualization
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import warnings
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_
from ..functions import MSDeformAttnFunction
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n-1) == 0) and n != 0
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
"""
Multi-Scale Deformable Attention Module
:param d_model hidden dimension
:param n_levels number of feature levels
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
_d_per_head = d_model // n_heads
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_head):
warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.")
self.im2col_step = 512
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
self.value_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
r"""
:param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-2}*W_{L-2}]
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
:return output (N, Length_{query}, C)
:return sampling_locations (N, Length_{query}, n_heads, n_levels, n_points, 2), returned for visualization
:return attention_weights (N, Length_{query}, n_heads, n_levels, n_points), returned for visualization
"""
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
# N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
else:
raise ValueError(
'Last dim of reference_points must be 2 or 4, but got {} instead.'.format(reference_points.shape[-1]))
output = MSDeformAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
return output, sampling_locations, attention_weights
================================================
FILE: models/encoder/ops/setup.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
import os
import glob
import torch
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension
from setuptools import find_packages
from setuptools import setup
requirements = ["torch", "torchvision"]
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "src")
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
sources = main_file + source_cpu
extension = CppExtension
extra_compile_args = {"cxx": []}
define_macros = []
if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
else:
raise NotImplementedError('CUDA is not available')
sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir]
ext_modules = [
extension(
"MultiScaleDeformableAttention",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
return ext_modules
setup(
name="MultiScaleDeformableAttention",
version="1.0",
author="Weijie Su",
url="https://github.com/fundamentalvision/Deformable-DETR",
description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
packages=find_packages(exclude=("configs", "tests",)),
ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
================================================
FILE: models/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp
================================================
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ERROR("Not implemented on CPU");
}
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ERROR("Not implemented on CPU");
}
================================================
FILE: models/encoder/ops/src/cpu/ms_deform_attn_cpu.h
================================================
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
================================================
FILE: models/encoder/ops/src/cuda/ms_deform_attn_cuda.cu
================================================
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include "cuda/ms_deform_im2col_cuda.cuh"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads*channels});
return output;
}
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
return {
grad_value, grad_sampling_loc, grad_attn_weight
};
}
================================================
FILE: models/encoder/ops/src/cuda/ms_deform_attn_cuda.h
================================================
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
================================================
FILE: models/encoder/ops/src/cuda/ms_deform_im2col_cuda.cuh
================================================
/*!
**************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
* Copyright (c) 2018 Microsoft
**************************************************************************
*/
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads)
{
return (N + num_threads - 1) / num_threads;
}
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad,
const scalar_t &attn_weight,
scalar_t* &grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value+ptr1, w1*top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value+ptr2, w2*top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value+ptr4, w4*top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
*grad_attn_weight = top_grad * val;
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
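// Variant of ms_deform_attn_col2im_bilinear that accumulates the sampling-location
// and attention-weight gradients with atomicAdd instead of plain stores.
template <typename scalar_t>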
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad,
const scalar_t &attn_weight,
scalar_t* &grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value+ptr1, w1*top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value+ptr2, w2*top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value+ptr4, w4*top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
atomicAdd(grad_attn_weight, top_grad * val);
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
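// Forward kernel: one thread per output element (batch, query, head, channel);
// sums the attention-weighted bilinear samples over all levels and points.
template <typename scalar_t>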
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
scalar_t *data_col_ptr = data_col + index;
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
scalar_t col = 0;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
*data_col_ptr = col;
}
}
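// Backward kernel with a compile-time block size: thread 0 serially sums the
// per-thread partial gradients held in static shared memory.
template <typename scalar_t, unsigned int blockSize>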
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
if (tid == 0)
{
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
int sid=2;
for (unsigned int tid = 1; tid < blockSize; ++tid)
{
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
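// Backward kernel with a compile-time block size: partial gradients in static
// shared memory are combined with a tree reduction.
template <typename scalar_t, unsigned int blockSize>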
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockSize/2; s>0; s>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
}
__syncthreads();
}
if (tid == 0)
{
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
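// Backward kernel using dynamic shared memory (3 * blockDim.x scalars);
// thread 0 serially sums the per-thread partial gradients.
template <typename scalar_t>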
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
if (tid == 0)
{
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
int sid=2;
for (unsigned int tid = 1; tid < blockDim.x; ++tid)
{
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
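// Backward kernel using dynamic shared memory (3 * blockDim.x scalars);
// tree reduction that also folds in the leftover element when the active width is odd.
template <typename scalar_t>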
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0)
{
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
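// Same reduction as ms_deformable_col2im_gpu_kernel_shm_reduce_v2, but each block
// adds its result with atomicAdd so that several blocks can contribute to one output.
template <typename scalar_t>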
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0)
{
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
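// Fallback backward kernel: no shared-memory reduction, every thread
// accumulates its gradients directly with atomicAdd.
template <typename scalar_t>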
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear_gm(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
grad_sampling_loc, grad_attn_weight);
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
// Host-side launcher for the forward kernel.
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
const scalar_t* data_value,
const int64_t* data_spatial_shapes,
const int64_t* data_level_start_index,
const scalar_t* data_sampling_loc,
const scalar_t* data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* data_col)
{
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
ms_deformable_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
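// Host-side launcher for the backward pass; picks a kernel based on the channel count.
template <typename scalar_t>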
void ms_deformable_col2im_cuda(cudaStream_t stream,
const scalar_t* grad_col,
const scalar_t* data_value,
const int64_t * data_spatial_shapes,
const int64_t * data_level_start_index,
const scalar_t * data_sampling_loc,
const scalar_t * data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
if (channels > 1024)
{
if ((channels & 1023) == 0)
{
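ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(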
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
else
{
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
}
else{
switch(channels)
{
case 1:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 2:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 4:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 8:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 16:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 32:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 64:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 128:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 256:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 512:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 1024:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
default:
if (channels < 64)
{
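ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(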
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
else
{
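ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(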
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
}
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
================================================
FILE: models/encoder/ops/src/ms_deform_attn.h
================================================
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include "cpu/ms_deform_attn_cpu.h"
#ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h"
#endif

at::Tensor
ms_deform_attn_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    if (value.type().is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}

std::vector<at::Tensor>
ms_deform_attn_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    if (value.type().is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}
================================================
FILE: models/encoder/ops/src/vision.cpp
================================================
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "ms_deform_attn.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
  m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
================================================
FILE: models/encoder/ops/test.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import torch
import torch.nn as nn
from torch.autograd import gradcheck
from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H*W).item() for H, W in shapes])
torch.manual_seed(3)


@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    # Normalize so the weights sum to 1 over all levels and points
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2
    output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
    output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
    fwdok = torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()

    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_forward_equal_with_pytorch_float():
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    # Normalize so the weights sum to 1 over all levels and points
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2
    output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
    output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
    # Looser tolerances than the double-precision check
    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()

    print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
    value = torch.rand(N, S, M, channels).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2
    func = MSDeformAttnFunction.apply

    value.requires_grad = grad_value
    sampling_locations.requires_grad = grad_sampling_loc
    attention_weights.requires_grad = grad_attn_weight

    gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))

    print(f'* {gradok} check_gradient_numerical(D={channels})')


if __name__ == '__main__':
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()

    for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
        check_gradient_numerical(channels, True, True, True)
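
The test above flattens all feature levels into one sequence of length `S` and records where each level begins in `level_start_index`. A minimal pure-Python sketch of that bookkeeping, assuming the same two levels `(6, 4)` and `(3, 2)`. The helper name `flatten_layout` is hypothetical and not part of the repository:

```python
from itertools import accumulate


def flatten_layout(shapes):
    """Return (level_start_index, total_length) for a list of (H, W) levels.

    Hypothetical helper: mirrors, in plain Python, the tensor expressions
    `shapes.prod(1).cumsum(0)[:-1]` and `sum(H * W ...)` used in test.py.
    """
    sizes = [h * w for h, w in shapes]              # tokens per level
    ends = list(accumulate(sizes))                  # running totals
    starts = ([0] + ends[:-1]) if sizes else []     # level l starts where level l-1 ends
    total = ends[-1] if ends else 0
    return starts, total


starts, total = flatten_layout([(6, 4), (3, 2)])
print(starts, total)
```

With these shapes, the first level occupies 24 positions and the second level the remaining 6, so a `value` tensor of shape `(N, S, M, D)` has `S = 30`.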
================================================
FILE: models/layers/anyc_trans.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat, rearrange, reduce
from typing import Any, Optional
from torch import Tensor
from .utils import _get_activation_fn


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        # ReLU after every layer except the last
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class Linear_NormAct(nn.Linear):
    def __init__(self, *args, **kwargs):
        # Pop the extra options before forwarding the rest to nn.Linear
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)
        # Read from the attribute so positional arguments also work
        out_features = self.out_features
        if norm is None:
            self.norm = None
        elif norm == 'ln':
            self.norm = nn.LayerNorm(out_features)
        else:
            raise ValueError(f"Unsupported norm: {norm!r}")
        self.activation = _get_activation_fn(activation)

    def forward(self, x):
        x = F.linear(x, self.weight, self.bias)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


class Conv2d_NormAct(torch.nn.Conv2d):
    def __init__(self, *args, **kwargs):
        # Pop the extra options before forwarding the rest to nn.Conv2d
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)
        # Read from the attribute so positional arguments also work
        out_dim = self.out_channels
        if norm is None:
            self.norm = None
        elif norm == 'bn2d':
            # b c h w
            self.norm = nn.BatchNorm2d(out_dim)
        elif 'gn' in norm:
            # b c ..
            # Spec format 'gn_<groups>', e.g. 'gn_32'
            n_groups = int(norm.split('_')[-1])
            self.norm = nn.GroupNorm(n_groups, out_dim)
        else:
            raise ValueError(f"Unsupported norm: {norm!r}")
        self.activation = _get_activation_fn(activation)

    def forward(self, x):
        # b c h w
        x = self._conv_forward(x, self.weight, self.bias)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


class Conv3d_NormAct(torch.nn.Conv3d):
    def __init__(self, *args, **kwargs):
        # Pop the extra options before forwarding the rest to nn.Conv3d
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)
        # Read from the attribute so positional arguments also work
        out_dim = self.out_channels
        if norm is None:
            self.norm = None
        elif 'gn' in norm:
            # Spec format 'gn_<groups>', e.g. 'gn_32'
            n_groups = int(norm.split('_')[-1])
            self.norm = nn.GroupNorm(n_groups, out_dim)
        else:
            raise ValueError(f"Unsupported norm: {norm!r}")
        self.activation = _get_activation_fn(activation)

    def forward(self, x):
        x = self._conv_forward(x, self.weight, self.bias)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
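
The layer sizes that `MLP.__init__` builds, and the `'gn_<groups>'` norm spec parsed by the conv wrappers, can be checked without PyTorch. The sketch below reproduces only that arithmetic. Both helper names are hypothetical and do not exist in the repository:

```python
def mlp_layer_dims(input_dim, hidden_dim, output_dim, num_layers):
    """(in_features, out_features) per Linear layer, using the same zip as MLP.__init__."""
    h = [hidden_dim] * (num_layers - 1)
    return list(zip([input_dim] + h, h + [output_dim]))


def parse_group_norm(norm):
    """Extract the group count from a spec such as 'gn_32' (hypothetical helper)."""
    return int(norm.split('_')[-1])


print(mlp_layer_dims(256, 512, 10, 3))
print(parse_group_norm('gn_32'))
```

With `num_layers=1` the hidden list is empty, so the MLP reduces to a single `Linear(input_dim, output_dim)` and `hidden_dim` is unused.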
================================================
FILE: models/layers/decoder_layers.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat, rearrange, reduce
from typing import Any, Optional
from torch import Tensor
from .utils import _get_activation_fn


class SelfAttentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None,
                     ):
        # Post-norm: attention, residual add, then LayerNorm
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None,
                    ):
        # Pre-norm: LayerNorm, attention, then residual add
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None,):
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask,
                                    tgt_key_padding_mask, query_pos)
        return self.forward_post(tgt, tgt_mask,
                                 tgt_key_padding_mask, query_pos)


class CrossAttentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None,
                     ):
        # Post-norm: attention, residual add, then LayerNorm
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)  # n b d

        return tgt

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None,
                    ):
        # Pre-norm: LayerNorm, attention, then residual add
        tgt2 = self.norm(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None,
                ):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)


class FFNLayer(nn.Module):

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        # Post-norm: FFN, residual add, then LayerNorm
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        # Pre-norm: LayerNorm, FFN, then residual add
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)
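
All three layers above switch between `forward_post` and `forward_pre` through `normalize_before`. The difference is only where the normalization sits relative to the residual add. A toy pure-Python sketch of the two orderings, under stated assumptions: dropout is omitted (as with the default `dropout=0.0`), `layer_norm` has no learned affine parameters, and the sublayer is a stand-in that doubles its input. None of these helpers exist in the repository:

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a list of floats to zero mean and unit variance (no affine)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def post_norm_block(x, sublayer):
    # forward_post ordering: residual add first, then normalize the sum
    y = sublayer(x)
    return layer_norm([a + b for a, b in zip(x, y)])


def pre_norm_block(x, sublayer):
    # forward_pre ordering: normalize the input; the residual sum stays un-normalized
    y = sublayer(layer_norm(x))
    return [a + b for a, b in zip(x, y)]


def double(v):
    """Stand-in for the attention / FFN sublayer."""
    return [2.0 * e for e in v]


x = [1.0, 3.0]
print(post_norm_block(x, double))
print(pre_norm_block(x, double))
```

The post-norm output is always normalized, while the pre-norm output keeps the scale of the residual stream.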
================================================
FILE: models/layers/gilbert/demo/index.html
================================================
Gilbert Curve