Repository: lxtGH/CAE
Branch: master
Commit: d72597143e48
Files: 236
Total size: 1.2 MB
Directory structure:
gitextract_wj4p2u9n/
├── .gitignore
├── .gitignore.swp
├── README.md
├── dall_e/
│ ├── __init__.py
│ ├── decoder.py
│ ├── encoder.py
│ └── utils.py
├── downstream_tasks/
│ ├── detection/
│ │ ├── README.md
│ │ ├── evaluation/
│ │ │ └── object_detection/
│ │ │ ├── configs/
│ │ │ │ ├── _base_/
│ │ │ │ │ ├── datasets/
│ │ │ │ │ │ └── coco_instance.py
│ │ │ │ │ ├── default_runtime.py
│ │ │ │ │ ├── models/
│ │ │ │ │ │ ├── cascade_mask_rcnn_r50_fpn.py
│ │ │ │ │ │ ├── cascade_mask_rcnn_swin_fpn.py
│ │ │ │ │ │ ├── cascade_mask_rcnn_vit_fpn.py
│ │ │ │ │ │ ├── mask_rcnn_r50_fpn.py
│ │ │ │ │ │ └── mask_rcnn_vit_fpn.py
│ │ │ │ │ └── schedules/
│ │ │ │ │ └── schedule_1x.py
│ │ │ │ └── mask_rcnn/
│ │ │ │ ├── vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py
│ │ │ │ └── vit_large_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00002_lrdr0.85_dp0.2.py
│ │ │ ├── mmcv_custom/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── checkpoint.py
│ │ │ │ ├── layer_decay_optimizer_constructor.py
│ │ │ │ ├── prepare_rpe.py
│ │ │ │ ├── register_backbone.py
│ │ │ │ └── runner/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── checkpoint.py
│ │ │ │ └── epoch_based_runner.py
│ │ │ ├── test.py
│ │ │ └── train.py
│ │ ├── loader.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── head.py
│ │ │ ├── swin_transformer.py
│ │ │ └── vision_transformer.py
│ │ ├── scripts/
│ │ │ ├── run_eval.sh
│ │ │ ├── run_train_maskrcnn_vit_base.sh
│ │ │ └── run_train_maskrcnn_vit_large.sh
│ │ └── utils.py
│ └── semantic_segmentation/
│ ├── README.md
│ ├── backbone/
│ │ ├── beit.py
│ │ ├── beit_fapn.py
│ │ ├── cae.py
│ │ ├── fapn.py
│ │ └── mae.py
│ ├── configs_local/
│ │ ├── _base_/
│ │ │ ├── datasets/
│ │ │ │ ├── ade20k.py
│ │ │ │ ├── ade20k_640x640.py
│ │ │ │ ├── chase_db1.py
│ │ │ │ ├── cityscapes.py
│ │ │ │ ├── cityscapes_769x769.py
│ │ │ │ ├── coco-stuff10k.py
│ │ │ │ ├── drive.py
│ │ │ │ ├── hrf.py
│ │ │ │ ├── pascal_context.py
│ │ │ │ ├── pascal_voc12.py
│ │ │ │ ├── pascal_voc12_aug.py
│ │ │ │ └── stare.py
│ │ │ ├── default_runtime.py
│ │ │ ├── models/
│ │ │ │ ├── ann_r50-d8.py
│ │ │ │ ├── apcnet_r50-d8.py
│ │ │ │ ├── ccnet_r50-d8.py
│ │ │ │ ├── cgnet.py
│ │ │ │ ├── danet_r50-d8.py
│ │ │ │ ├── deeplabv3_r50-d8.py
│ │ │ │ ├── deeplabv3_unet_s5-d16.py
│ │ │ │ ├── deeplabv3plus_r50-d8.py
│ │ │ │ ├── dmnet_r50-d8.py
│ │ │ │ ├── dnl_r50-d8.py
│ │ │ │ ├── emanet_r50-d8.py
│ │ │ │ ├── encnet_r50-d8.py
│ │ │ │ ├── fast_scnn.py
│ │ │ │ ├── fcn_hr18.py
│ │ │ │ ├── fcn_r50-d8.py
│ │ │ │ ├── fcn_unet_s5-d16.py
│ │ │ │ ├── fpn_r50.py
│ │ │ │ ├── gcnet_r50-d8.py
│ │ │ │ ├── lraspp_m-v3-d8.py
│ │ │ │ ├── nonlocal_r50-d8.py
│ │ │ │ ├── ocrnet_hr18.py
│ │ │ │ ├── ocrnet_r50-d8.py
│ │ │ │ ├── pointrend_r50.py
│ │ │ │ ├── psanet_r50-d8.py
│ │ │ │ ├── pspnet_r50-d8.py
│ │ │ │ ├── pspnet_unet_s5-d16.py
│ │ │ │ ├── upernet_cae.py
│ │ │ │ └── upernet_r50.py
│ │ │ └── schedules/
│ │ │ ├── schedule_160k.py
│ │ │ ├── schedule_20k.py
│ │ │ ├── schedule_320k.py
│ │ │ ├── schedule_40k.py
│ │ │ └── schedule_80k.py
│ │ ├── beit/
│ │ │ └── upernet_beit_base_12_512_slide_160k_ade20k_pt_4e-4.py
│ │ ├── cae/
│ │ │ └── upernet/
│ │ │ ├── upernet_cae_base_12_512_slide_160k_ade20k_pt_1e-4.py
│ │ │ ├── upernet_cae_base_12_512_slide_160k_ade20k_pt_2e-4.py
│ │ │ ├── upernet_cae_base_12_512_slide_160k_ade20k_pt_3e-4.py
│ │ │ └── upernet_cae_large_24_512_slide_160k_ade20k_pt_decay095_4e-5_dp015.py
│ │ └── mae/
│ │ └── upernet_mae_large_12_512_slide_160k_ade20k_pt_4e-4.py
│ ├── mmcv_custom/
│ │ ├── __init__.py
│ │ ├── apex_runner/
│ │ │ ├── __init__.py
│ │ │ ├── apex_iter_based_runner.py
│ │ │ ├── checkpoint.py
│ │ │ └── optimizer.py
│ │ ├── checkpoint.py
│ │ ├── checkpoint_beit.py
│ │ ├── layer_decay_optimizer_constructor.py
│ │ ├── resize_transform.py
│ │ └── train_api.py
│ ├── mmseg/
│ │ ├── __init__.py
│ │ ├── apis/
│ │ │ ├── __init__.py
│ │ │ ├── inference.py
│ │ │ ├── test.py
│ │ │ └── train.py
│ │ ├── core/
│ │ │ ├── __init__.py
│ │ │ ├── evaluation/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── class_names.py
│ │ │ │ ├── eval_hooks.py
│ │ │ │ └── metrics.py
│ │ │ ├── seg/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── builder.py
│ │ │ │ └── sampler/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_pixel_sampler.py
│ │ │ │ └── ohem_pixel_sampler.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ └── misc.py
│ │ ├── datasets/
│ │ │ ├── __init__.py
│ │ │ ├── ade.py
│ │ │ ├── builder.py
│ │ │ ├── chase_db1.py
│ │ │ ├── cityscapes.py
│ │ │ ├── coco_stuff.py
│ │ │ ├── custom.py
│ │ │ ├── dataset_wrappers.py
│ │ │ ├── drive.py
│ │ │ ├── hrf.py
│ │ │ ├── pascal_context.py
│ │ │ ├── pipelines/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── compose.py
│ │ │ │ ├── formating.py
│ │ │ │ ├── loading.py
│ │ │ │ ├── test_time_aug.py
│ │ │ │ └── transforms.py
│ │ │ ├── stare.py
│ │ │ └── voc.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── backbones/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── cgnet.py
│ │ │ │ ├── fast_scnn.py
│ │ │ │ ├── hrnet.py
│ │ │ │ ├── mobilenet_v2.py
│ │ │ │ ├── mobilenet_v3.py
│ │ │ │ ├── resnest.py
│ │ │ │ ├── resnet.py
│ │ │ │ ├── resnext.py
│ │ │ │ └── unet.py
│ │ │ ├── builder.py
│ │ │ ├── decode_heads/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── ann_head.py
│ │ │ │ ├── apc_head.py
│ │ │ │ ├── aspp_head.py
│ │ │ │ ├── cascade_decode_head.py
│ │ │ │ ├── cc_head.py
│ │ │ │ ├── da_head.py
│ │ │ │ ├── decode_head.py
│ │ │ │ ├── dm_head.py
│ │ │ │ ├── dnl_head.py
│ │ │ │ ├── ema_head.py
│ │ │ │ ├── enc_head.py
│ │ │ │ ├── fcn_head.py
│ │ │ │ ├── fpn_head.py
│ │ │ │ ├── gc_head.py
│ │ │ │ ├── lraspp_head.py
│ │ │ │ ├── nl_head.py
│ │ │ │ ├── ocr_head.py
│ │ │ │ ├── point_head.py
│ │ │ │ ├── psa_head.py
│ │ │ │ ├── psp_head.py
│ │ │ │ ├── sep_aspp_head.py
│ │ │ │ ├── sep_fcn_head.py
│ │ │ │ └── uper_head.py
│ │ │ ├── losses/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── accuracy.py
│ │ │ │ ├── cross_entropy_loss.py
│ │ │ │ ├── lovasz_loss.py
│ │ │ │ └── utils.py
│ │ │ ├── necks/
│ │ │ │ ├── __init__.py
│ │ │ │ └── fpn.py
│ │ │ ├── segmentors/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── cascade_encoder_decoder.py
│ │ │ │ └── encoder_decoder.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── inverted_residual.py
│ │ │ ├── make_divisible.py
│ │ │ ├── res_layer.py
│ │ │ ├── se_layer.py
│ │ │ ├── self_attention_block.py
│ │ │ └── up_conv_block.py
│ │ ├── ops/
│ │ │ ├── __init__.py
│ │ │ ├── encoding.py
│ │ │ └── wrappers.py
│ │ ├── utils/
│ │ │ ├── __init__.py
│ │ │ ├── collect_env.py
│ │ │ └── logger.py
│ │ └── version.py
│ └── tools/
│ ├── dist_test.sh
│ ├── dist_train.sh
│ ├── test.py
│ └── train.py
├── furnace/
│ ├── dataset_folder.py
│ ├── datasets.py
│ ├── engine_for_finetuning.py
│ ├── engine_for_pretraining.py
│ ├── masking_generator.py
│ ├── optim_factory.py
│ ├── transforms.py
│ └── utils.py
├── linear_util/
│ ├── crop.py
│ ├── datasets.py
│ ├── engine_finetune.py
│ ├── lars.py
│ ├── lr_decay.py
│ ├── lr_sched.py
│ ├── misc.py
│ └── pos_embed.py
├── models/
│ ├── modeling_cae.py
│ ├── modeling_cae_helper.py
│ ├── modeling_discrete_vae.py
│ └── modeling_finetune.py
├── requirements.txt
├── scripts/
│ ├── cae_base_800e.sh
│ ├── cae_base_finetune.sh
│ ├── cae_large_1600e.sh
│ └── cae_large_finetune.sh
├── tokenizer-weights/
│ └── README
└── tools/
├── run_attentive.py
├── run_class_finetuning.py
├── run_linear.py
└── run_pretraining.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.DS_Store
================================================
FILE: README.md
================================================
# CAE: Context AutoEncoder for Self-Supervised Representation Learning
This is a PyTorch implementation of [CAE: Context AutoEncoder for Self-Supervised Representation Learning](https://arxiv.org/abs/2202.03026).
## Highlights
- State-of-the-art MIM performance. Results in the paper are successfully reproduced.
## Installation
Clone the repo and install required packages.
```bash
pip install -r requirements.txt
# install apex
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
## Data Preparation
First, download ImageNet-1k from http://image-net.org/.
The directory structure is the standard layout of torchvision's datasets.ImageFolder. The training and validation data are expected to be in the train/ and val/ folders, respectively:
```
/path/to/imagenet/
train/
class1/
img1.jpeg
class2/
img2.jpeg
val/
class1/
img3.jpeg
class2/
img4.jpeg
```
Second, download the pretrained tokenizer.
```bash
TOKENIZER_PATH=/path/to/save/dall_e_tokenizer_weight
mkdir -p $TOKENIZER_PATH
wget -O $TOKENIZER_PATH/encoder.pkl https://cdn.openai.com/dall-e/encoder.pkl
wget -O $TOKENIZER_PATH/decoder.pkl https://cdn.openai.com/dall-e/decoder.pkl
```
## Pretraining
Here is an example that pretrains CAE-base on ImageNet-1K with 32 GPUs. Please see [scripts/cae_base_800e.sh](scripts/cae_base_800e.sh) for the complete script.
```bash
OMP_NUM_THREADS=1 $PYTHON -m torch.distributed.launch \
--nproc_per_node=8 \
tools/run_pretraining.py \
--data_path ${DATA_PATH} \
--output_dir ${OUTPUT_DIR} \
--model cae_base_patch16_224_8k_vocab --discrete_vae_weight_path ${TOKENIZER_PATH} \
--batch_size 64 --lr 1.5e-3 --warmup_epochs 20 --epochs 800 \
--clip_grad 3.0 --layer_scale_init_value 0.1 \
--imagenet_default_mean_and_std \
--color_jitter 0 \
--drop_path 0.1 \
--sincos_pos_emb \
--mask_generator block \
--num_mask_patches 98 \
--decoder_layer_scale_init_value 0.1 \
--no_auto_resume \
--save_ckpt_freq 100 \
--exp_name $my_name \
--regressor_depth 4 \
--decoder_depth 4 \
--align_loss_weight 2
```
- `--num_mask_patches`: number of input patches to be masked.
- `--batch_size`: batch size per GPU.
- Effective batch size = `number of GPUs` * `--batch_size`. So in the above example, the effective batch size is `64*32 = 2048`.
- `--lr`: learning rate.
- `--warmup_epochs`: learning rate warmup epochs. Warm up [10, 20, 40] epochs for [300, 800, 1600] pretrain epochs respectively.
- `--epochs`: total pretraining epochs.
- `--clip_grad`: clip gradient norm.
- `--drop_path`: stochastic depth rate.
- `--imagenet_default_mean_and_std`: enable this for ImageNet-1k pretraining, i.e., `(0.485, 0.456, 0.406)` for mean and `(0.229, 0.224, 0.225)` for std. For other pretraining data, use `(0.5, 0.5, 0.5)` for mean and `(0.5, 0.5, 0.5)` for std by default.
- `--layer_scale_init_value`: 0.1 for base, 1e-5 for large, set 0 to disable layerscale. We set `--decoder_layer_scale_init_value` the same as this.
- `--sincos_pos_emb`: adopt sin-cos positional embedding during pretraining.
- `--regressor_depth`: depth (number of layers) of the regressor.
- `--decoder_depth`: depth (number of layers) of the decoder.
- `--align_loss_weight`: weight for alignment loss. 2 by default.
Warmup epochs for 300/800/1600 epochs pretraining are 10/20/40.
For CAE-large, please refer to [scripts/cae_large_1600e.sh](scripts/cae_large_1600e.sh).
## Results
Here we provide the results of CAE-base/CAE-large on the following evaluation tasks:
- Linear probing
- Attentive probing
- Fine-tuning
- Semantic segmentation
- Object detection and instance segmentation
Pretrained weights and logs are available ([Google Drive](https://drive.google.com/drive/folders/1wwhg7nj2GQuU9uthVuQLkEEXEjx90G7g?usp=sharing), [Baidu Cloud [Code: 4kil]](https://pan.baidu.com/s/15eZGoI72iLupLrOHqmOM9w)). *: from CAE paper.
| Model | Pretraining data | #Epoch | Linear | Attentive | Fine-tuning | ADE Seg | COCO Det | COCO InstSeg |
| ---------- | ---------------- | ------ | ------ | --------- | ----------- | ------- | -------- | ------------ |
| MAE-base* | ImageNet-1K | 1600 | 67.8 | 74.2 | 83.6 | 48.1 | 48.4 | 42.6 |
| MAE-large* | ImageNet-1K | 1600 | 76.0 | 78.8 | 86.0 | 53.6 | 54.0 | 47.1 |
| CAE-base | ImageNet-1K | 300 | 64.5 | 74.0 | 83.6 | 48.1 | 48.3 | 42.7 |
| CAE-base | ImageNet-1K | 800 | 68.9 | 75.9 | 83.8 | 49.7 | 49.9 | 43.9 |
| CAE-base | ImageNet-1K | 1600 | 70.3 | 77.2 | 83.9 | 50.3 | 50.3 | 44.2 |
| CAE-large | ImageNet-1K | 1600 | 77.8 | 81.2 | 86.2 | 54.9 | 54.5 | 47.5 |
### Linear Probing
- Please refer to [scripts/cae_base_800e.sh](scripts/cae_base_800e.sh) (32 GPUs).
- For CAE-large, just replace `--model cae_base_patch16_224` with `--model cae_large_patch16_224`.
### Attentive Probing
- Please refer to [scripts/cae_base_800e.sh](scripts/cae_base_800e.sh) (32 GPUs).
- For CAE-large, just replace `--model cae_base_patch16_224` with `--model cae_large_patch16_224`.
### Fine-tuning
- Please refer to [scripts/cae_base_finetune.sh](scripts/cae_base_finetune.sh) (32 GPUs).
- For CAE-large, please refer to [scripts/cae_large_finetune.sh](scripts/cae_large_finetune.sh) (32 GPUs).
### Segmentation & Detection
- Please refer to [downstream_tasks](./downstream_tasks) dir to get started.
## Acknowledgement
This repository is built on [BEiT](https://github.com/microsoft/unilm/tree/master/beit) and [MMSelfSup](https://github.com/open-mmlab/mmselfsup); thanks for their open-source code! Thanks also to the CAE authors for their excellent work!
## Citation
```bibtex
@article{ContextAutoencoder2022,
title={Context Autoencoder for Self-Supervised Representation Learning},
author={Chen, Xiaokang and Ding, Mingyu and Wang, Xiaodi and Xin, Ying and Mo, Shentong and Wang, Yunhao and Han, Shumin and Luo, Ping and Zeng, Gang and Wang, Jingdong},
journal={arXiv preprint arXiv:2202.03026},
year={2022}
}
```
================================================
FILE: dall_e/__init__.py
================================================
import io, requests
import torch
import torch.nn as nn
from dall_e.encoder import Encoder
from dall_e.decoder import Decoder
from dall_e.utils import map_pixels, unmap_pixels
def load_model(path: str, device: torch.device = None, timeout: float = 60.0) -> nn.Module:
    """Load a ``torch.load``-serialized model from a local path or an HTTP(S) URL.

    Args:
        path: Local filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        device: Forwarded to ``torch.load`` as ``map_location``. ``None``
            keeps the storage locations recorded in the file.
        timeout: Seconds ``requests`` waits for the server to connect or to
            send more data before giving up. ``None`` waits forever, which was
            the previous behaviour. This does not cap total download time.

    Returns:
        Whatever object ``torch.load`` deserializes. It is expected to be an
        ``nn.Module`` for the DALL-E weights, but this is not checked here.

    Raises:
        requests.HTTPError: If the URL download returns an error status.
    """
    # SECURITY: torch.load unpickles arbitrary Python objects. Only load files
    # or URLs from a trusted source, such as the official DALL-E CDN.
    if path.startswith(('http://', 'https://')):
        # Without a timeout, requests can block indefinitely on a stalled server.
        resp = requests.get(path, timeout=timeout)
        resp.raise_for_status()
        with io.BytesIO(resp.content) as buf:
            return torch.load(buf, map_location=device)
    with open(path, 'rb') as f:
        return torch.load(f, map_location=device)
================================================
FILE: dall_e/decoder.py
================================================
import attr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from functools import partial
from dall_e.utils import Conv2d
@attr.s(eq=False, repr=False)
class DecoderBlock(nn.Module):
    """Residual block of the DALL-E decoder.

    Computes ``id_path(x) + post_gain * res_path(x)``, where:

    * ``id_path`` is ``nn.Identity`` when ``n_in == n_out``; otherwise it is
      a 1x1 ``Conv2d`` that matches the channel count.
    * ``res_path`` is four (ReLU, Conv2d) pairs with kernel sizes 1, 3, 3, 3,
      using ``n_out // 4`` hidden channels.
    * ``post_gain = 1 / n_layers ** 2`` down-weights the residual branch.

    Submodule names (``relu_k`` / ``conv_k``) are the keys used by saved
    state dicts and must stay unchanged.
    """
    # Input channel count (>= 1).
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    # Output channel count; must be a positive multiple of 4.
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
    # Total number of residual blocks in the network; sets post_gain.
    n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)
    device: torch.device = attr.ib(default=None)
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()
        self.n_hid = self.n_out // 4
        self.post_gain = 1 / (self.n_layers ** 2)
        conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)

        # Build the shortcut before the residual branch to keep the original
        # parameter-initialisation order.
        if self.n_in == self.n_out:
            self.id_path = nn.Identity()
        else:
            self.id_path = conv(self.n_in, self.n_out, 1)

        hid = self.n_hid
        # (in_channels, out_channels, kernel_size) for conv_1 .. conv_4.
        conv_specs = (
            (self.n_in, hid, 1),
            (hid, hid, 3),
            (hid, hid, 3),
            (hid, self.n_out, 3),
        )
        named_layers = []
        for idx, (c_in, c_out, k) in enumerate(conv_specs, start=1):
            named_layers.append((f'relu_{idx}', nn.ReLU()))
            named_layers.append((f'conv_{idx}', conv(c_in, c_out, k)))
        self.res_path = nn.Sequential(OrderedDict(named_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.id_path(x)
        residual = self.res_path(x)
        return shortcut + self.post_gain * residual
@attr.s(eq=False, repr=False)
class Decoder(nn.Module):
    """DALL-E image decoder built as a single ``nn.Sequential`` (``self.blocks``).

    The stages, with h = ``n_hid`` and the stage names used as state-dict keys:

    * ``input``: 1x1 conv from ``vocab_size`` to ``n_init`` channels. It is
      built with ``use_float16=False``.
    * ``group_1`` .. ``group_4``: ``n_blk_per_group`` DecoderBlocks each, with
      widths 8h, 4h, 2h and h. Groups 1-3 end with a 2x nearest-neighbour
      ``upsample``.
    * ``output``: ReLU followed by a 1x1 conv to ``2 * output_channels``
      channels.
    """
    group_count: int = 4
    n_init: int = attr.ib(default=128, validator=lambda i, a, x: x >= 8)
    n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
    n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
    output_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
    vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)
    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)
    # Not read anywhere in this class.
    use_mixed_precision: bool = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        super().__init__()
        hid = self.n_hid
        per_group = self.n_blk_per_group
        conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        block = partial(DecoderBlock, n_layers=self.group_count * per_group,
                        device=self.device, requires_grad=self.requires_grad)

        # (input width of the first block, width of every block, upsample after?)
        group_specs = (
            (self.n_init, 8 * hid, True),
            (8 * hid, 4 * hid, True),
            (4 * hid, 2 * hid, True),
            (2 * hid, 1 * hid, False),
        )

        # Modules are created in the same order as before (input, groups,
        # output), so random initialisation consumes the RNG identically.
        stages = [('input', conv(self.vocab_size, self.n_init, 1, use_float16=False))]
        for group_idx, (first_in, width, upsample) in enumerate(group_specs, start=1):
            layers = [
                (f'block_{k + 1}', block(first_in if k == 0 else width, width))
                for k in range(per_group)
            ]
            if upsample:
                layers.append(('upsample', nn.Upsample(scale_factor=2, mode='nearest')))
            stages.append((f'group_{group_idx}', nn.Sequential(OrderedDict(layers))))
        stages.append(('output', nn.Sequential(OrderedDict([
            ('relu', nn.ReLU()),
            ('conv', conv(1 * hid, 2 * self.output_channels, 1)),
        ]))))
        self.blocks = nn.Sequential(OrderedDict(stages))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the decoder on a 4-D float32 tensor with ``vocab_size`` channels in dim 1."""
        shape = x.shape
        if len(shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        channels = shape[1]
        if channels != self.vocab_size:
            raise ValueError(f'input has {channels} channels but model built for {self.vocab_size}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')
        return self.blocks(x)
================================================
FILE: dall_e/encoder.py
================================================
import attr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from functools import partial
from dall_e.utils import Conv2d
@attr.s(eq=False, repr=False)
class EncoderBlock(nn.Module):
    """Residual block of the DALL-E encoder.

    Computes ``id_path(x) + post_gain * res_path(x)``, where:

    * ``id_path`` is ``nn.Identity`` when ``n_in == n_out``; otherwise it is
      a 1x1 ``Conv2d``.
    * ``res_path`` is four (ReLU, Conv2d) pairs with kernel sizes 3, 3, 3, 1,
      using ``n_out // 4`` hidden channels.
    * ``post_gain = 1 / n_layers ** 2`` down-weights the residual branch.

    Submodule names (``relu_k`` / ``conv_k``) are the keys used by saved
    state dicts and must stay unchanged.
    """
    # Input channel count (>= 1).
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    # Output channel count; must be a positive multiple of 4.
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
    # Total number of residual blocks in the network; sets post_gain.
    n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)
    device: torch.device = attr.ib(default=None)
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()
        self.n_hid = self.n_out // 4
        self.post_gain = 1 / (self.n_layers ** 2)
        conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)

        # The shortcut is built first, as in the original initialisation order.
        if self.n_in == self.n_out:
            self.id_path = nn.Identity()
        else:
            self.id_path = conv(self.n_in, self.n_out, 1)

        hid = self.n_hid
        # (in_channels, out_channels, kernel_size) for conv_1 .. conv_4.
        conv_specs = (
            (self.n_in, hid, 3),
            (hid, hid, 3),
            (hid, hid, 3),
            (hid, self.n_out, 1),
        )
        named_layers = []
        for idx, (c_in, c_out, k) in enumerate(conv_specs, start=1):
            named_layers += [(f'relu_{idx}', nn.ReLU()), (f'conv_{idx}', conv(c_in, c_out, k))]
        self.res_path = nn.Sequential(OrderedDict(named_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.id_path(x)
        residual = self.res_path(x)
        return shortcut + self.post_gain * residual
@attr.s(eq=False, repr=False)
class Encoder(nn.Module):
    """DALL-E image encoder built as a single ``nn.Sequential`` (``self.blocks``).

    The stages, with h = ``n_hid`` and the stage names used as state-dict keys:

    * ``input``: 7x7 conv from ``input_channels`` to h channels.
    * ``group_1`` .. ``group_4``: ``n_blk_per_group`` EncoderBlocks each, with
      widths h, 2h, 4h and 8h. Groups 1-3 end with a 2x2 max-pool (``pool``).
    * ``output``: ReLU followed by a 1x1 conv to ``vocab_size`` channels. It
      is built with ``use_float16=False``.
    """
    group_count: int = 4
    n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
    n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
    input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
    vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)
    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)
    # Not read anywhere in this class.
    use_mixed_precision: bool = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        super().__init__()
        hid = self.n_hid
        per_group = self.n_blk_per_group
        conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        block = partial(EncoderBlock, n_layers=self.group_count * per_group,
                        device=self.device, requires_grad=self.requires_grad)

        # (input width of the first block, width of every block, max-pool after?)
        group_specs = (
            (1 * hid, 1 * hid, True),
            (1 * hid, 2 * hid, True),
            (2 * hid, 4 * hid, True),
            (4 * hid, 8 * hid, False),
        )

        # Construction order (input, groups, output) matches the original so
        # random initialisation consumes the RNG identically.
        stages = [('input', conv(self.input_channels, 1 * hid, 7))]
        for group_idx, (first_in, width, pool) in enumerate(group_specs, start=1):
            layers = [
                (f'block_{k + 1}', block(first_in if k == 0 else width, width))
                for k in range(per_group)
            ]
            if pool:
                layers.append(('pool', nn.MaxPool2d(kernel_size=2)))
            stages.append((f'group_{group_idx}', nn.Sequential(OrderedDict(layers))))
        stages.append(('output', nn.Sequential(OrderedDict([
            ('relu', nn.ReLU()),
            ('conv', conv(8 * hid, self.vocab_size, 1, use_float16=False)),
        ]))))
        self.blocks = nn.Sequential(OrderedDict(stages))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a 4-D float32 tensor with ``input_channels`` channels in dim 1."""
        shape = x.shape
        if len(shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        channels = shape[1]
        if channels != self.input_channels:
            raise ValueError(f'input has {channels} channels but model built for {self.input_channels}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')
        return self.blocks(x)
================================================
FILE: dall_e/utils.py
================================================
import attr
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Margin used by map_pixels / unmap_pixels. map_pixels squeezes values in
# [0, 1] into [eps, 1 - eps]; unmap_pixels inverts that and clamps to [0, 1].
logit_laplace_eps: float = 0.1
@attr.s(eq=False)
class Conv2d(nn.Module):
    """Square-kernel ``F.conv2d`` wrapper with explicitly stored ``w`` / ``b`` parameters.

    The weight has shape ``(n_out, n_in, kw, kw)`` and is drawn from a normal
    distribution with std ``1 / sqrt(n_in * kw**2)``. The bias starts at zero.
    Padding of ``(kw - 1) // 2`` keeps the spatial size unchanged, because
    ``kw`` is validated to be odd.
    """
    # Input channel count (>= 1).
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    # Output channel count (>= 1).
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
    # Kernel width; must be odd.
    kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)
    # When the parameters live on a CUDA device, run the convolution in float16.
    use_float16: bool = attr.ib(default=True)
    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()
        w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,
                        device=self.device, requires_grad=self.requires_grad)
        # In-place init of a leaf tensor that requires grad raises a RuntimeError
        # outside no_grad. Without this guard, Conv2d(..., requires_grad=True)
        # would fail at construction.
        with torch.no_grad():
            w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))
        b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device,
                        requires_grad=self.requires_grad)
        # NOTE(review): nn.Parameter defaults to requires_grad=True, so w / b
        # require grad even when requires_grad=False. This is kept unchanged
        # for backward compatibility.
        self.w, self.b = nn.Parameter(w), nn.Parameter(b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_float16 and 'cuda' in self.w.device.type:
            # Half-precision path: cast the input if needed and use half
            # copies of the parameters.
            if x.dtype != torch.float16:
                x = x.half()
            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()
            w, b = self.w, self.b
        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
def map_pixels(x: torch.Tensor) -> torch.Tensor:
    """Affinely squeeze pixel values: ``x -> (1 - 2*eps) * x + eps``.

    ``eps`` is the module-level ``logit_laplace_eps``. Inputs in ``[0, 1]``
    land in ``[eps, 1 - eps]``. Tensors of any rank are accepted.

    Raises:
        ValueError: If ``x`` is not of dtype ``torch.float``.
    """
    if x.dtype == torch.float:
        scale = 1 - 2 * logit_laplace_eps
        return scale * x + logit_laplace_eps
    raise ValueError('expected input to have type float')
def unmap_pixels(x: torch.Tensor) -> torch.Tensor:
    """Invert ``map_pixels`` and clamp the result to ``[0, 1]``.

    Raises:
        ValueError: If ``x`` is not 4-D, or is not of dtype ``torch.float``.
    """
    rank = len(x.shape)
    if rank != 4:
        raise ValueError('expected input to be 4d')
    if x.dtype != torch.float:
        raise ValueError('expected input to have type float')
    stretched = (x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps)
    return torch.clamp(stretched, 0, 1)
================================================
FILE: downstream_tasks/detection/README.md
================================================
# COCO Detection and Instance segmentation with CAE
# Installation
Please install [PyTorch](https://pytorch.org/). This codebase has been developed with Python 3.6, PyTorch 1.7.1, CUDA 11.0 and torchvision 0.8.2. To get the full dependencies, please run:
```bash
pip3 install -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.1/index.html mmcv-full==1.3.9
pip3 install pytest-runner scipy tensorboardX faiss-gpu==1.6.1 tqdm lmdb scikit-learn pyarrow==2.0.0 timm DALL-E munkres six einops
# install apex
pip3 install git+https://github.com/NVIDIA/apex \
--no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext"
# install mmdetection for object detection & instance segmentation
git clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
cd Swin-Transformer-Object-Detection
pip3 install -r requirements/build.txt
pip3 install -v -e .
cd ..
```
## Fine-tuning with Mask R-CNN
#### We use 16 GPUs for these experiments, $NNODES = 2.
- To train ViT-B/16 with Mask R-CNN as the task layer, run:
```bash
python -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=$NNODES \
--node_rank=$RANK \
--master_addr=$ADDRESS \
--master_port=$PORT \
evaluation/object_detection/train.py evaluation/object_detection/configs/mask_rcnn/vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py \
--launcher pytorch \
--work-dir $OUTPUT_DIR \
--no-validate \
--deterministic \
--cfg-options model.backbone.use_checkpoint=True \
model.pretrained=$PRETRAINED \
${@:6}
```
- To train ViT-L/16 with Mask R-CNN as the task layer, run:
```bash
python -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=$NNODES \
--node_rank=$RANK \
--master_addr=$ADDRESS \
--master_port=$PORT \
evaluation/object_detection/train.py evaluation/object_detection/configs/mask_rcnn/vit_large_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00002_lrdr0.85_dp0.2.py \
--launcher pytorch \
--work-dir $OUTPUT_DIR \
--no-validate \
--deterministic \
--cfg-options model.backbone.use_checkpoint=True \
model.pretrained=$PRETRAINED \
${@:6}
```
- To evaluate Mask R-CNN, run:
```bash
python -m torch.distributed.launch --nproc_per_node=8 \
evaluation/object_detection/test.py \
$CONFIG \
$MODEL \
--launcher pytorch \
--eval bbox segm \
--cfg-options model.backbone.use_checkpoint=True \
${@:6}
```
## Results (pretrained models are trained on ImageNet-1K without labels)
| Backbone | #Pretrained Epoch | Object Det | Instance Seg |
| -------- | ----------------- | ---------- | ------------ |
| ViT-B | 300 | 48.3 | 42.7 |
| ViT-B | 800 | 49.9 | 43.9 |
| ViT-B | 1600 | 50.3 | 44.2 |
| ViT-L | 1600 | 54.5 | 47.5 |
## Acknowledgement
This repository is built on the [iBOT repository](https://github.com/bytedance/ibot). Thanks for their open-source code!
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/datasets/coco_instance.py
================================================
# COCO 2017 instance-segmentation dataset settings (mmdetection-style config).
# Every module-level name below becomes a config key.
dataset_type = 'CocoDataset'
# Placeholder: set this to the local COCO root. It must end with '/' because
# the paths below are built by string concatenation.
data_root = '/path/to/coco/'
# ImageNet mean/std on the 0-255 pixel scale (e.g. 0.485 * 255 = 123.675).
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# Training: load image + boxes/masks, resize to fit (1333, 800) keeping the
# aspect ratio, random flip with probability 0.5, normalize, pad to a multiple
# of 32, then collect the image and its targets.
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
# Evaluation: single scale (1333, 800) with flipping disabled; only the image
# is collected.
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
# Both the val and test splits point at val2017.
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# Report box AP and mask AP.
evaluation = dict(metric=['bbox', 'segm'])
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/default_runtime.py
================================================
# Runtime defaults shared by the detection configs (mmdetection-style).
# Checkpoint interval; presumably counted in epochs with an epoch-based
# runner. TODO confirm against the runner in use.
checkpoint_config = dict(interval=1)
# yapf:disable
# Log every 50 iterations to text; TensorBoard logging is left commented out.
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
# Nothing is loaded or resumed unless overridden (e.g. via --cfg-options).
load_from = None
resume_from = None
# The workflow contains only a 'train' phase.
workflow = [('train', 1)]
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py
================================================
# model settings
# (The original first line was the truncated comment "ettings". As a bare
# undefined name it raised NameError whenever this config file was executed.)
#
# Cascade Mask R-CNN with a ResNet-50 + FPN backbone: 3 cascade stages whose
# IoU thresholds rise 0.5 -> 0.6 -> 0.7 and whose box-regression target stds
# shrink stage by stage, for 80 COCO classes.
model = dict(
    type='CascadeRCNN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    roi_head=dict(
        type='CascadeRoIHead',
        num_stages=3,
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        # One box head per cascade stage; target_stds tighten stage by stage.
        bbox_head=[
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ],
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        # Per-stage R-CNN training settings; IoU thresholds 0.5 / 0.6 / 0.7.
        rcnn=[
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.5,
                    neg_iou_thr=0.5,
                    min_pos_iou=0.5,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.6,
                    min_pos_iou=0.6,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.7,
                    min_pos_iou=0.7,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False)
        ]),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/cascade_mask_rcnn_swin_fpn.py
================================================
# model settings
# Cascade Mask R-CNN with a Swin Transformer backbone and an FPN neck.
# NOTE(review): the backbone sizes below (embed_dim=96, depths=[2, 2, 6, 2])
# look like Swin-T defaults that derived configs presumably override --
# confirm against the configs inheriting this file.
model = dict(
    type='CascadeRCNN',
    pretrained=None,
    backbone=dict(
        type='SwinTransformer',
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        use_checkpoint=False),
    neck=dict(
        type='FPN',
        # One entry per backbone stage listed in out_indices; the width
        # doubles from embed_dim at every stage.
        in_channels=[96, 192, 384, 768],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    roi_head=dict(
        type='CascadeRoIHead',
        num_stages=3,
        # Loss weight of each cascade stage, first stage first.
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        # One class-agnostic bbox head per cascade stage. The three heads are
        # identical except for target_stds, which shrink stage by stage.
        bbox_head=[
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ],
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg = dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            # NOTE(review): nms_across_levels / nms_post look like legacy
            # mmdet keys (max_per_img is also set) -- confirm the installed
            # mmdet version still reads them.
            nms_across_levels=False,
            nms_pre=2000,
            nms_post=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        # One assigner/sampler config per cascade stage; the IoU thresholds
        # rise 0.5 -> 0.6 -> 0.7 to match the increasingly strict stages.
        rcnn=[
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.5,
                    neg_iou_thr=0.5,
                    min_pos_iou=0.5,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.6,
                    min_pos_iou=0.6,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.7,
                    min_pos_iou=0.7,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False)
        ]),
    test_cfg = dict(
        rpn=dict(
            nms_across_levels=False,
            nms_pre=1000,
            nms_post=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/cascade_mask_rcnn_vit_fpn.py
================================================
# model settings
# Cascade Mask R-CNN with a plain Vision Transformer backbone and an FPN neck.
# NOTE(review): the backbone sizes (embed_dim=384, depth=12, num_heads=6) are
# presumably defaults that derived configs override -- confirm.
model = dict(
    type='CascadeRCNN',
    pretrained=None,
    backbone=dict(
        type='VisionTransformer',
        # NOTE(review): presumably the input resolution used to size the
        # position embeddings; confirm in the VisionTransformer backbone.
        img_size=[672, 1092],
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.1,
        # Presumably 0-based indices of the blocks whose outputs feed the
        # neck (depth=12, so 11 is the last block).
        out_indices=(3, 5, 7, 11),
        use_checkpoint=False),
    neck=dict(
        type='FPN',
        # Every tapped ViT block has embed_dim channels. The RoI extractors
        # below expect strides 4-32, so the backbone presumably resamples
        # these maps into a pyramid -- confirm in the backbone code.
        in_channels=[384, 384, 384, 384],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    roi_head=dict(
        type='CascadeRoIHead',
        num_stages=3,
        # Loss weight of each cascade stage, first stage first.
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        # One class-agnostic bbox head per cascade stage. The three heads are
        # identical except for target_stds, which shrink stage by stage.
        bbox_head=[
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ],
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg = dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            # NOTE(review): nms_across_levels / nms_post look like legacy
            # mmdet keys (max_per_img is also set) -- confirm the installed
            # mmdet version still reads them.
            nms_across_levels=False,
            nms_pre=2000,
            nms_post=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        # One assigner/sampler config per cascade stage; the IoU thresholds
        # rise 0.5 -> 0.6 -> 0.7.
        rcnn=[
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.5,
                    neg_iou_thr=0.5,
                    min_pos_iou=0.5,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.6,
                    min_pos_iou=0.6,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.7,
                    min_pos_iou=0.7,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False)
        ]),
    test_cfg = dict(
        rpn=dict(
            nms_across_levels=False,
            nms_pre=1000,
            nms_post=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/mask_rcnn_r50_fpn.py
================================================
# model settings
# Single-stage-RoI Mask R-CNN with a ResNet-50 backbone (initialised from
# ``torchvision://resnet50``) and an FPN neck.
model = dict(
    type='MaskRCNN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        # Output channels of the four ResNet-50 stages in out_indices.
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/mask_rcnn_vit_fpn.py
================================================
# model settings
# Mask R-CNN with a plain Vision Transformer backbone and an FPN neck; the
# base of the vit_base / vit_large Mask R-CNN configs, which override the
# backbone size, neck width and bbox head.
model = dict(
    type='MaskRCNN',
    pretrained=None,
    backbone=dict(
        type='VisionTransformer',
        # NOTE(review): presumably the input resolution used to size the
        # position embeddings; confirm in the VisionTransformer backbone.
        img_size=[672, 1092],
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.1,
        # Presumably 0-based indices of the blocks whose outputs feed the
        # neck (depth=12, so 11 is the last block).
        out_indices=(3, 5, 7, 11),
        use_checkpoint=False),
    neck=dict(
        type='FPN',
        # One embed_dim-wide feature map per entry of out_indices.
        in_channels=[384, 384, 384, 384],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            # NOTE(review): nms_across_levels / nms_post look like legacy
            # mmdet keys (max_per_img is also set) -- confirm the installed
            # mmdet version still reads them. Keeps 2000 proposals per image
            # (the ResNet-50 variant keeps 1000).
            nms_across_levels=False,
            nms_pre=2000,
            nms_post=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_across_levels=False,
            nms_pre=1000,
            nms_post=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/schedules/schedule_1x.py
================================================
# optimizer
# Plain SGD; derived configs may replace it entirely (e.g. with
# ``_delete_=True``).
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
# Step decay after a 500-iteration linear warmup. ``step`` lists the decay
# points -- presumably epochs, since the runner below is epoch based.
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
# 12 epochs in total: the "1x" schedule this file is named after.
runner = dict(type='EpochBasedRunner', max_epochs=12)
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/configs/mask_rcnn/vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Mostly copy-paste from timm, mmdet, and swin code bases
https://github.com/rwightman/pytorch-image-models/tree/master/timm
https://github.com/open-mmlab/mmdetection
https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
"""
# Start from the ViT Mask R-CNN model, COCO instance dataset, 1x schedule and
# default runtime; the dicts below are merged on top of those bases (note the
# explicit ``_delete_=True`` where a base entry must be replaced instead).
_base_ = [
    '../_base_/models/mask_rcnn_vit_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
    backbone=dict(
        # ViT-Base sized encoder: 768-dim, 12 blocks, 12 heads.
        embed_dim=768,
        depth=12,
        num_heads=12,
        init_values=0.1,
        mlp_ratio=4.,
        drop_path_rate=0.2,  # the base config uses 0.1 for the smaller ViT; raised to 0.2 here (tuning choice)
        # Fixed sin-cos position embeddings instead of learned absolute ones,
        # and no relative position bias.
        use_abs_pos_emb=False,
        use_sincos_pos_emb=True,
        use_rel_pos_bias=False,
    ),
    # One embed_dim-wide (768) input per tapped backbone block.
    neck=dict(in_channels=[768, 768, 768, 768]),
    roi_head=dict(
        # "4conv1f" head: 4 shared convs + 1 shared fc. Boxes are regressed
        # in decoded form (reg_decoded_bbox=True) so a GIoU loss can be used.
        bbox_head=dict(
            type='ConvFCBBoxHead',
            num_shared_convs=4,
            num_shared_fcs=1,
            in_channels=256,
            conv_out_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            reg_decoded_bbox=True,
            norm_cfg=dict(type='SyncBN', requires_grad=True),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='GIoULoss', loss_weight=10.0))))
# Per-channel mean/std on the 0-255 scale (the common ImageNet statistics).
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='AutoAugment',
         policies=[
             # Policy 1: multi-scale resize only.
             [
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                                 (736, 1333), (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True)
             ],
             # Policy 2: resize, random crop, then multi-scale resize again.
             [
                 dict(type='Resize',
                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True),
                 dict(type='RandomCrop',
                      crop_type='absolute_range',
                      crop_size=(384, 600),
                      allow_negative_crop=True),
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      override=True,
                      keep_ratio=True)
             ]
         ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
# 2 images per GPU; only the training pipeline is overridden here.
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(pipeline=train_pipeline))
# Replace (``_delete_=True``) the base SGD optimizer with AdamW plus layer-wise
# lr decay; num_layers matches the backbone depth above.
optimizer = dict(_delete_=True, type='AdamW', lr=0.0003, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.75))
# Decay at [9, 11] instead of the base schedule's [8, 11].
lr_config = dict(step=[9, 11])
# NOTE(review): EpochBasedRunnerAmp is presumably the custom runner from
# mmcv_custom/runner -- confirm it is registered before training.
runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
# do not use mmdet version fp16
fp16 = None
# Mixed precision is enabled through this optimizer hook (use_fp16=True)
# instead of mmdet's own fp16 setting above.
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/configs/mask_rcnn/vit_large_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00002_lrdr0.85_dp0.2.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Mostly copy-paste from timm, mmdet, and swin code bases
https://github.com/rwightman/pytorch-image-models/tree/master/timm
https://github.com/open-mmlab/mmdetection
https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
"""
# Start from the ViT Mask R-CNN model, COCO instance dataset, 1x schedule and
# default runtime; the dicts below are merged on top of those bases (note the
# explicit ``_delete_=True`` where a base entry must be replaced instead).
_base_ = [
    '../_base_/models/mask_rcnn_vit_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# NOTE(review): presumably forwarded to the distributed model wrapper by the
# training script; False assumes every parameter receives a gradient.
find_unused_parameters = False
model = dict(
    backbone=dict(
        # ViT-Large sized encoder: 1024-dim, 24 blocks, 16 heads.
        embed_dim=1024,
        depth=24,
        num_heads=16,
        # NOTE(review): the file name says "init0.1" but init_values is 1e-5
        # here -- confirm which value produced the reported results.
        init_values=0.00001,
        mlp_ratio=4.,
        drop_path_rate=0.2,  # the base config uses 0.1 for the smaller ViT; raised to 0.2 here (tuning choice)
        # Fixed sin-cos position embeddings instead of learned absolute ones,
        # and no relative position bias.
        use_abs_pos_emb=False,
        use_sincos_pos_emb=True,
        use_rel_pos_bias=False,
        # Four of the 24 blocks (presumably 0-based) feed the neck.
        out_indices=[7, 11, 15, 23],
    ),
    # One embed_dim-wide (1024) input per tapped backbone block.
    neck=dict(in_channels=[1024, 1024, 1024, 1024]),
    roi_head=dict(
        # "4conv1f" head: 4 shared convs + 1 shared fc. Boxes are regressed
        # in decoded form (reg_decoded_bbox=True) so a GIoU loss can be used.
        bbox_head=dict(
            type='ConvFCBBoxHead',
            num_shared_convs=4,
            num_shared_fcs=1,
            in_channels=256,
            conv_out_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            reg_decoded_bbox=True,
            norm_cfg=dict(type='SyncBN', requires_grad=True),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='GIoULoss', loss_weight=10.0))))
# Per-channel mean/std on the 0-255 scale (the common ImageNet statistics).
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='AutoAugment',
         policies=[
             # Policy 1: multi-scale resize only.
             [
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                                 (736, 1333), (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True)
             ],
             # Policy 2: resize, random crop, then multi-scale resize again.
             [
                 dict(type='Resize',
                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True),
                 dict(type='RandomCrop',
                      crop_type='absolute_range',
                      crop_size=(384, 600),
                      allow_negative_crop=True),
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      override=True,
                      keep_ratio=True)
             ]
         ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
# 2 images per GPU; only the training pipeline is overridden here.
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(pipeline=train_pipeline))
# Replace (``_delete_=True``) the base SGD optimizer with AdamW plus layer-wise
# lr decay; num_layers matches the backbone depth (24) above.
optimizer = dict(_delete_=True, type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.85))
# Decay at [9, 11] instead of the base schedule's [8, 11].
lr_config = dict(step=[9, 11])
# NOTE(review): EpochBasedRunnerAmp is presumably the custom runner from
# mmcv_custom/runner -- confirm it is registered before training.
runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
# do not use mmdet version fp16
fp16 = None
# Mixed precision is enabled through this optimizer hook (use_fp16=True)
# instead of mmdet's own fp16 setting above.
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/__init__.py
================================================
# -*- coding: utf-8 -*-
from .checkpoint import load_checkpoint
from .layer_decay_optimizer_constructor import LayerDecayOptimizerConstructor
from .register_backbone import VisionTransformer
# Public API of mmcv_custom. VisionTransformer is imported above but not
# listed here; NOTE(review): that import presumably exists only for its side
# effect of registering the backbone (register_backbone) -- keep it even
# though nothing references the name.
__all__ = ['load_checkpoint', 'LayerDecayOptimizerConstructor']
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/checkpoint.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Copy-paste from mmcv library:
https://github.com/open-mmlab/mmcv/
"""
import io
import os
import os.path as osp
import pkgutil
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
import torch
import torchvision
from torch.optim import Optimizer
from torch.nn import functional as F
import mmcv
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.utils import mkdir_or_exist
from mmcv.runner import get_dist_info
from scipy import interpolate
import numpy as np
import math
# Environment variable that, when set, gives the mmcv cache directory directly.
ENV_MMCV_HOME = 'MMCV_HOME'
# Environment variable consulted next; the cache lives in an 'mmcv' subdir.
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
# Cache root used when neither environment variable is set.
DEFAULT_CACHE_DIR = '~/.cache'
def _get_mmcv_home():
    """Return the mmcv cache directory, creating it if needed.

    ``$MMCV_HOME`` wins if set; otherwise ``$XDG_CACHE_HOME/mmcv`` is used,
    with ``~/.cache`` standing in for an unset ``$XDG_CACHE_HOME``. A leading
    ``~`` in the chosen path is expanded.
    """
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_dir = os.path.join(cache_root, 'mmcv')
    home_dir = os.path.expanduser(os.getenv(ENV_MMCV_HOME, fallback_dir))
    mkdir_or_exist(home_dir)
    return home_dir
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.

    Note:
        Mismatches are only reported on rank 0 (as given by
        ``get_dist_info``). With ``strict=True`` only rank 0 raises
        ``RuntimeError``; every other rank returns silently.
        ``num_batches_tracked`` buffers are never reported as missing.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    # Work on a shallow copy so the caller's dict is untouched, and re-attach
    # ``_metadata`` to the copy.
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        # The positional True is the per-module strict flag; the function's
        # own ``strict`` argument only decides whether to raise further down.
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    # Only rank 0 reports mismatches (and, when strict, raises); see Note.
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)
def load_url_dist(url, model_dir=None, map_location="cpu"):
    """In distributed setting, this function only download checkpoint at local
    rank 0.

    Local rank 0 fetches the checkpoint first. When ``world_size > 1`` all
    ranks then meet at a barrier, after which the remaining ranks load the
    same URL (presumably hitting the download cache that rank 0 populated).

    Args:
        url (str): URL of the checkpoint.
        model_dir (str, optional): Download/cache directory forwarded to
            ``model_zoo.load_url``. Default: None.
        map_location (str | callable | None): Forwarded to
            ``model_zoo.load_url``. Default: ``"cpu"``.

    Returns:
        The object deserialized from the downloaded checkpoint.
    """
    # Imported here because the module-level imports of this file do not
    # provide ``model_zoo``; without it every URL-based load raised NameError.
    from torch.utils import model_zoo

    rank, world_size = get_dist_info()
    # LOCAL_RANK (when set) takes precedence so that exactly one process per
    # node performs the download.
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(
            url, model_dir=model_dir, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
    if rank > 0:
        checkpoint = model_zoo.load_url(
            url, model_dir=model_dir, map_location=map_location)
    return checkpoint
def load_pavimodel_dist(model_path, map_location=None):
    """In distributed setting, this function only download checkpoint at local
    rank 0.

    Local rank 0 downloads the model from pavi's modelcloud into a temporary
    directory and loads it. When ``world_size > 1`` the ranks then meet at a
    barrier, and the remaining ranks perform their own download and load.

    Args:
        model_path (str): Model path inside modelcloud (the part of the
            filename after ``pavi://``).
        map_location: Forwarded to :func:`torch.load`. Default: None.

    Returns:
        The object loaded from the downloaded checkpoint.

    Raises:
        ImportError: If ``pavi`` (or its ``modelcloud`` module) is missing.
    """
    try:
        # The code below uses ``modelcloud``; importing the misspelled
        # ``modelscloud`` always failed and left that name undefined.
        from pavi import modelcloud
    except ImportError as err:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.') from err

    def _download_and_load():
        # Fetch the model into a throwaway directory and deserialize it
        # before the directory is removed.
        model = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            downloaded_file = osp.join(tmp_dir, model.name)
            model.download(downloaded_file)
            return torch.load(downloaded_file, map_location=map_location)

    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = _download_and_load()
    if world_size > 1:
        torch.distributed.barrier()
    if rank > 0:
        checkpoint = _download_and_load()
    return checkpoint
def load_fileclient_dist(filename, backend, map_location):
    """Load a checkpoint through an mmcv ``FileClient`` backend.

    In a distributed run, local rank 0 reads the file first. When
    ``world_size > 1`` all ranks meet at a barrier, after which the other
    ranks read it themselves.

    Args:
        filename (str): Path/URI understood by the backend.
        backend (str): FileClient backend name; only ``'ceph'`` is accepted.
        map_location: Forwarded to :func:`torch.load`.

    Returns:
        The object deserialized from the fetched bytes.

    Raises:
        ValueError: If ``backend`` is not supported.
    """
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    supported = ('ceph',)
    if backend not in supported:
        raise ValueError(f'Load from Backend {backend} is not supported.')

    def _fetch():
        client = FileClient(backend=backend)
        payload = io.BytesIO(client.get(filename))
        return torch.load(payload, map_location=map_location)

    if rank == 0:
        checkpoint = _fetch()
    if world_size > 1:
        torch.distributed.barrier()
    if rank > 0:
        checkpoint = _fetch()
    return checkpoint
def get_torchvision_models():
    """Collect the ``model_urls`` tables of all torchvision model modules.

    Walks ``torchvision.models`` (skipping sub-packages), imports each
    module, and merges its ``model_urls`` dict if it defines one. Later
    modules override earlier ones on key collisions.

    Returns:
        dict: Mapping of model name to checkpoint URL.
    """
    urls = {}
    for module_info in pkgutil.walk_packages(torchvision.models.__path__):
        _, module_name, is_package = module_info
        if is_package:
            continue
        zoo_module = import_module(f'torchvision.models.{module_name}')
        if not hasattr(zoo_module, 'model_urls'):
            continue
        urls.update(getattr(zoo_module, 'model_urls'))
    return urls
def get_external_models():
    """Return the open-mmlab model-name -> URL table.

    Starts from ``model_zoo/open_mmlab.json`` shipped with mmcv. When
    ``open_mmlab.json`` exists in the mmcv cache directory, its entries are
    overlaid on top (user entries win).
    """
    home_dir = _get_mmcv_home()
    urls = load_file(osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))
    assert isinstance(urls, dict)

    user_json = osp.join(home_dir, 'open_mmlab.json')
    if not osp.exists(user_json):
        return urls

    user_urls = load_file(user_json)
    assert isinstance(user_urls, dict)
    urls.update(user_urls)
    return urls
def get_mmcls_models():
    """Return the mmcls model-name -> URL table from mmcv's
    ``model_zoo/mmcls.json``."""
    return load_file(osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json'))
def get_deprecated_model_names():
    """Return the deprecated-name -> replacement-name mapping for
    ``open-mmlab://`` models, read from mmcv's ``model_zoo/deprecated.json``."""
    json_path = osp.join(mmcv.__path__[0], 'model_zoo/deprecated.json')
    renames = load_file(json_path)
    assert isinstance(renames, dict)
    return renames
def _process_mmcls_checkpoint(checkpoint):
    """Turn an mmcls checkpoint into a backbone-only checkpoint.

    Keeps only the ``state_dict`` entries whose key starts with
    ``'backbone.'``, with that prefix removed. All other weights and all
    other top-level fields of ``checkpoint`` are dropped.

    Args:
        checkpoint (dict): Must contain a ``'state_dict'`` mapping.

    Returns:
        dict: ``{'state_dict': OrderedDict}`` with the renamed entries, in
        their original order.
    """
    prefix = 'backbone.'
    backbone_weights = OrderedDict(
        (key[len(prefix):], value)
        for key, value in checkpoint['state_dict'].items()
        if key.startswith(prefix))
    return dict(state_dict=backbone_weights)
def _load_checkpoint(filename, map_location=None):
    """Load checkpoint from somewhere (modelzoo, file, url).

    The source is chosen from the prefix of ``filename``:
    ``modelzoo://`` (deprecated alias of ``torchvision://``),
    ``torchvision://``, ``open-mmlab://``, ``mmcls://``, ``http://`` /
    ``https://``, ``pavi://`` and ``s3://`` (ceph). Anything else is treated
    as a local file path.

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`. It is used for
            local files, ``pavi://`` and ``s3://``; URL-based sources use
            ``load_url_dist``'s own default. Default: None.

    Returns:
        dict | OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.

    Raises:
        IOError: If a local path (including an ``open-mmlab://`` entry that
            resolves to a path under the mmcv home) is not an existing file.
    """

    def _load_local(path):
        # Shared by the plain-path and the local open-mmlab branches.
        if not osp.isfile(path):
            raise IOError(f'{path} is not a checkpoint file')
        return torch.load(path, map_location=map_location)

    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        return load_url_dist(get_torchvision_models()[filename[11:]])

    if filename.startswith('torchvision://'):
        return load_url_dist(get_torchvision_models()[filename[14:]])

    if filename.startswith('open-mmlab://'):
        known_urls = get_external_models()
        name = filename[13:]
        renames = get_deprecated_model_names()
        if name in renames:
            warnings.warn(f'open-mmlab://{name} is deprecated in favor '
                          f'of open-mmlab://{renames[name]}')
            name = renames[name]
        target = known_urls[name]
        # Entries are either remote URLs or paths relative to the mmcv home.
        if target.startswith(('http://', 'https://')):
            return load_url_dist(target)
        return _load_local(osp.join(_get_mmcv_home(), target))

    if filename.startswith('mmcls://'):
        raw = load_url_dist(get_mmcls_models()[filename[8:]])
        return _process_mmcls_checkpoint(raw)

    if filename.startswith(('http://', 'https://')):
        return load_url_dist(filename)

    if filename.startswith('pavi://'):
        return load_pavimodel_dist(filename[7:], map_location=map_location)

    if filename.startswith('s3://'):
        return load_fileclient_dist(
            filename, backend='ceph', map_location=map_location)

    return _load_local(filename)
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    """Build a per-iteration schedule: linear warmup followed by cosine decay.

    Args:
        base_value (float): Value reached at the end of the warmup and used at
            the start of the cosine phase.
        final_value (float): Value the cosine phase decays towards.
        epochs (int): Number of epochs.
        niter_per_ep (int): Number of iterations per epoch.
        warmup_epochs (int): Warmup length in epochs. Default: 0.
        start_warmup_value (float): First value of the linear warmup.
            Default: 0.
        warmup_steps (int): If > 0, warmup length in iterations; overrides
            ``warmup_epochs``. Default: -1.

    Returns:
        np.ndarray: ``epochs * niter_per_ep`` values, one per iteration.
    """
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps
    print("Set warmup steps = %d" % warmup_iters)
    # Gate on the resolved iteration count rather than ``warmup_epochs``:
    # with ``warmup_steps > 0`` and ``warmup_epochs == 0`` the warmup used to
    # be skipped, leaving the schedule short and tripping the assert below.
    if warmup_iters > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = np.array(
        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
def load_checkpoint(model,
filename,
map_location='cpu',
strict=False,
logger=None):
"""Load checkpoint from a file or URI.
The checkpoint is read with ``_load_checkpoint``. Weights are taken from its
``state_dict``, ``model`` or ``module`` entry, or from the whole dict if none
of those exist. They are adapted to ``model`` and then loaded with
``load_state_dict``. The adaptations, in code order:
- strip a leading ``module.`` prefix;
- MoBY: keep only ``encoder.*`` keys, with the prefix removed;
- when the last key starts with ``encoder_to_decoder`` or ``decoder``:
  strip ``encoder.`` from encoder keys and drop ``decoder.*`` keys;
- rename ``norm.*`` keys to ``fc_norm.*``;
- Swin: reshape ``absolute_pos_embed`` from (N, L, C) to (N, C, H, W);
- rename ``rel_pos_bias.relative_position_bias_table`` to
  ``relative_position_bias_table``;
- drop ``relative_position_index`` entries and resample relative position
  bias tables to ``model.patch_embed.patch_shape``;
- bicubically resize ``pos_embed`` to the model's patch grid;
- bicubically resize any bias table whose length still differs.
Args:
model (Module): Module to load checkpoint.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
dict or OrderedDict: The loaded checkpoint. It may have been modified
in place when the weights were taken from it without copying.
"""
checkpoint = _load_checkpoint(filename, map_location)
# OrderedDict is a subclass of dict
if not isinstance(checkpoint, dict):
raise RuntimeError(
f'No state_dict found in checkpoint file {filename}')
# get state_dict from checkpoint
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
elif 'module' in checkpoint:
state_dict = checkpoint['module']
else:
state_dict = checkpoint
# strip prefix of state_dict
# NOTE(review): the [0] / [-1] key indexing below raises IndexError on an
# empty state_dict.
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
# for MoBY, load model of online branch
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
all_keys = list(state_dict.keys())
# Encoder/decoder pretraining checkpoint: detected by its last key
# belonging to the decoder side.
if all_keys[-1].startswith('encoder_to_decoder') or all_keys[-1].startswith('decoder'):
# NOTE: remove all decoder keys
all_keys = [key for key in all_keys if key.startswith('encoder.')]
for key in all_keys:
new_key = key.replace('encoder.','')
state_dict[new_key] = state_dict[key]
state_dict.pop(key)
for key in list(state_dict.keys()):
if key.startswith('decoder.'):
state_dict.pop(key)
# NOTE: replace norm with fc_norm
for key in list(state_dict.keys()):
if key.startswith('norm.'):
new_key = key.replace('norm.','fc_norm.')
state_dict[new_key] = state_dict[key]
state_dict.pop(key)
# reshape absolute position embedding for Swin
if state_dict.get('absolute_pos_embed') is not None:
absolute_pos_embed = state_dict['absolute_pos_embed']
N1, L, C1 = absolute_pos_embed.size()
N2, C2, H, W = model.absolute_pos_embed.size()
if N1 != N2 or C1 != C2 or L != H*W:
# NOTE(review): ``logger`` defaults to None, in which case this raises
# AttributeError instead of warning.
logger.warning("Error in loading absolute_pos_embed, pass")
else:
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
rank, _ = get_dist_info()
if "rel_pos_bias.relative_position_bias_table" in state_dict:
# NOTE(review): nesting is not recoverable from this copy. Only the
# first statement below is certainly under ``rank == 0``. If the rename
# and pop are nested there too, non-zero ranks keep the old key and never
# load this table. Confirm the rename runs on every rank.
if rank == 0:
rel_pos_bias = state_dict["rel_pos_bias.relative_position_bias_table"]
state_dict["relative_position_bias_table"] = rel_pos_bias
state_dict.pop("rel_pos_bias.relative_position_bias_table")
all_keys = list(state_dict.keys())
# Resample every relative position bias table the model also has.
# model.state_dict() is rebuilt for each key; it could be hoisted.
for key in all_keys:
if "relative_position_index" in key:
state_dict.pop(key)
if "relative_position_bias_table" in key and key in model.state_dict():
rel_pos_bias = state_dict[key]
src_num_pos, num_attn_heads = rel_pos_bias.size()
dst_num_pos, _ = model.state_dict()[key].size()
dst_patch_shape = model.patch_embed.patch_shape
# Non-square target grid: each axis gets its own sampling ratio.
if dst_patch_shape[0] != dst_patch_shape[1]:
num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
src_size = int((src_num_pos - num_extra_tokens) ** 0.5) # 27
dst_size_0 = dst_patch_shape[0] * 2 - 1 # 42
dst_size_1 = dst_patch_shape[1] * 2 - 1 # 68
# NOTE(review): only dst_size_0 is compared; a table matching axis 0
# but not axis 1 is left unresized.
if src_size != dst_size_0:
if rank == 0:
print("Position interpolate for %s from %dx%d to %dx%d" % (
key, src_size, src_size, dst_size_0, dst_size_1))
# NOTE(review): with num_extra_tokens == 0 these slices treat the whole
# table as extra tokens and leave rel_pos_bias empty.
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
def geometric_progression(a, r, n):
return a * (1.0 - r ** n) / (1.0 - r)
# Bisect for ratio q so that src_size // 2 geometrically growing steps
# cover about dst_size_0 // 2 target offsets (axis 0).
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size_0 // 2:
right = q
else:
left = q
dis_0 = []
cur = 1
for i in range(src_size // 2):
dis_0.append(cur)
cur += q ** (i + 1)
r_ids_0 = [-_ for _ in reversed(dis_0)]
# Same bisection for axis 1.
top, bottom = 1.01, 1.5
while bottom - top > 1e-6:
q = (top + bottom) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size_1 // 2:
bottom = q
else:
top = q
dis_1 = []
cur = 1
for i in range(src_size // 2):
dis_1.append(cur)
cur += q ** (i + 1)
r_ids_1 = [-_ for _ in reversed(dis_1)]
# if q > 1.13492:
# q = 1.13492
# Source sample coordinates (symmetric, 0 in the middle) and the
# integer target offsets they are resampled onto.
x = r_ids_0 + [0] + dis_0
y = r_ids_1 + [0] + dis_1
t_0 = dst_size_0 // 2.0
t_1 = dst_size_1 // 2.0
dx = np.arange(-t_0, t_0 + 0.1, 1.0)
dy = np.arange(-t_1, t_1 + 0.1, 1.0)
if rank == 0:
print("x = {}".format(x))
print("dx = {}".format(dx))
all_rel_pos_bias = []
# Cubic 2-D interpolation, one attention head at a time.
# NOTE(review): scipy.interpolate.interp2d is deprecated upstream and
# removed in recent SciPy releases; check the pinned SciPy version.
for i in range(num_attn_heads):
z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
f = interpolate.interp2d(x, y, z, kind='cubic')
all_rel_pos_bias.append(
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
state_dict[key] = new_rel_pos_bias
# Square target grid: same scheme with a single ratio for both axes.
else:
num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
src_size = int((src_num_pos - num_extra_tokens) ** 0.5) # 27
dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) #
if src_size != dst_size:
if rank == 0:
print("Position interpolate for %s from %dx%d to %dx%d" % (
key, src_size, src_size, dst_size, dst_size))
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
def geometric_progression(a, r, n):
return a * (1.0 - r ** n) / (1.0 - r)
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size // 2:
right = q
else:
left = q
# if q > 1.13492:
# q = 1.13492
dis = []
cur = 1
for i in range(src_size // 2):
dis.append(cur)
cur += q ** (i + 1)
r_ids = [-_ for _ in reversed(dis)]
x = r_ids + [0] + dis
y = r_ids + [0] + dis
t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)
if rank == 0:
print("x = {}".format(x))
print("dx = {}".format(dx))
all_rel_pos_bias = []
for i in range(num_attn_heads):
z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
f = interpolate.interp2d(x, y, z, kind='cubic')
all_rel_pos_bias.append(
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
state_dict[key] = new_rel_pos_bias
if 'pos_embed' in state_dict:
pos_embed_checkpoint = state_dict['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
#new_size = int(num_patches ** 0.5)
new_size_w = model.patch_embed.num_patches_w
new_size_h = model.patch_embed.num_patches_h
# class_token and dist_token are kept unchanged
if orig_size != new_size_h or orig_size != new_size_w:
if rank == 0:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size_w, new_size_h))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
# NOTE(review): the output size is given as (num_patches_w, num_patches_h);
# confirm this matches the token order the model flattens its grid in.
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size_w, new_size_h), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict['pos_embed'] = new_pos_embed
# interpolate position bias table if needed
# NOTE(review): S1 / S2 below treat tables as square with no extra tokens;
# tables that still differ after the resampling above are reshaped on that
# assumption.
relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k and k in model.state_dict()]
for table_key in relative_position_bias_table_keys:
table_pretrained = state_dict[table_key]
table_current = model.state_dict()[table_key]
L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
# NOTE(review): crashes with AttributeError when ``logger`` is None.
logger.warning(f"Error in loading {table_key}, pass")
else:
if L1 != L2:
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
table_pretrained_resized = F.interpolate(
table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
size=(S2, S2), mode='bicubic')
state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
# load state_dict
load_state_dict(model, state_dict, strict, logger)
return checkpoint
def weights_to_cpu(state_dict):
    """Return a CPU copy of a model state_dict.

    Args:
        state_dict (OrderedDict): Model weights, possibly living on GPU.

    Returns:
        OrderedDict: The same keys, in the same order, each mapped to
        ``value.cpu()``.
    """
    return OrderedDict(
        (name, tensor.cpu()) for name, tensor in state_dict.items())
def _save_to_state_dict(module, destination, prefix, keep_vars):
"""Saves module state to `destination` dictionary.
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
Args:
module (nn.Module): The module to generate state_dict.
destination (dict): A dict where state will be stored.
prefix (str): The prefix for parameters and buffers used in this
module.
"""
for name, param in module._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.detach()
for name, buf in module._buffers.items():
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
if buf is not None:
destination[prefix + name] = buf if keep_vars else buf.detach()
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Collect the full state (parameters and buffers) of ``module``.

    A variant of :meth:`torch.nn.Module.state_dict` that unwraps parallel
    wrappers at every level of the recursion, so structures such as
    ``nn.Module(nn.Module(DDP))`` produce keys without the wrapper's
    ``module.`` segment.

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Dict to fill; a new one with a
            ``_metadata`` attribute is created when None.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    if is_module_wrapper(module):
        module = module.module

    # The rest follows torch.nn.Module.state_dict().
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    local_metadata = dict(version=module._version)
    destination._metadata[prefix[:-1]] = local_metadata

    _save_to_state_dict(module, destination, prefix, keep_vars)

    for child_name, child_module in module._modules.items():
        if child_module is None:
            continue
        get_state_dict(
            child_module, destination, prefix + child_name + '.',
            keep_vars=keep_vars)

    for state_hook in module._state_dict_hooks.values():
        replacement = state_hook(module, destination, prefix, local_metadata)
        if replacement is not None:
            destination = replacement
    return destination
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename. A ``pavi://`` prefix uploads the
            checkpoint through pavi modelcloud; anything else is written to
            the local filesystem, creating the parent directory if needed.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. A
            dict of optimizers is saved as a dict of their state dicts.
        meta (dict, optional): Metadata to be saved in checkpoint.

    Raises:
        TypeError: If ``meta`` is neither None nor a dict.
        ImportError: If ``filename`` uses ``pavi://`` and pavi is missing.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    if filename.startswith('pavi://'):
        try:
            # Imported as ``modelcloud`` because that is the name used below;
            # the previous ``modelscloud`` import left it undefined.
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        # Separate name so the nn.Module ``model`` is not shadowed.
        try:
            remote_model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            remote_model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            remote_model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/layer_decay_optimizer_constructor.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Mostly copy-paste from BEiT library:
https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/layer_decay_optimizer_constructor.py
"""
import json
from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
from mmcv.runner import get_dist_info
def get_num_layer_for_vit(var_name, num_max_layer):
    """Map a parameter name to its layer id for layer-wise lr decay.

    Embedding parameters (``cls_token``, ``mask_token``, ``pos_embed`` and
    ``patch_embed.*`` under ``backbone.``) are layer 0. ``backbone.blocks.<k>.*``
    is layer ``k + 1``. Everything else is the top layer,
    ``num_max_layer - 1``.
    """
    embedding_names = ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed")
    if var_name in embedding_names or var_name.startswith("backbone.patch_embed"):
        return 0
    if not var_name.startswith("backbone.blocks"):
        return num_max_layer - 1
    block_index = int(var_name.split('.')[2])
    return block_index + 1
@OPTIMIZER_BUILDERS.register_module()
class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
    """Optimizer constructor with layer-wise learning-rate decay for ViTs.

    ``paramwise_cfg`` must provide ``num_layers`` and ``layer_decay_rate``.
    Each trainable parameter is placed in a group keyed by its layer id
    (from ``get_num_layer_for_vit``) and by whether it is weight-decayed.
    Group lr is ``base_lr * layer_decay_rate ** (num_layers - layer_id - 1)``,
    where ``num_layers = paramwise_cfg['num_layers'] + 2``.
    """

    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module (unused here).
            is_dcn_module (int|float|None): Unused here; kept for
                compatibility with ``DefaultOptimizerConstructor``.
        """
        print(self.paramwise_cfg)
        num_layers = self.paramwise_cfg.get('num_layers') + 2
        layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate')
        print("Build LayerDecayOptimizerConstructor %f - %d" % (layer_decay_rate, num_layers))
        base_weight_decay = self.base_wd

        groups = {}
        for param_name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights

            # NOTE(review): get_num_layer_for_vit expects 'backbone.'-prefixed
            # names, so the bare 'pos_embed' / 'cls_token' test below probably
            # never matches; confirm whether those should skip weight decay.
            skip_decay = (len(param.shape) == 1
                          or param_name.endswith(".bias")
                          or param_name in ('pos_embed', 'cls_token'))
            layer_id = get_num_layer_for_vit(param_name, num_layers)
            group_key = "layer_%d_%s" % (layer_id, "no_decay" if skip_decay else "decay")

            group = groups.get(group_key)
            if group is None:
                lr_scale = layer_decay_rate ** (num_layers - layer_id - 1)
                group = {
                    "weight_decay": 0. if skip_decay else base_weight_decay,
                    "params": [],
                    "param_names": [],
                    "lr_scale": lr_scale,
                    "group_name": group_key,
                    "lr": lr_scale * self.base_lr,
                }
                groups[group_key] = group
            group["params"].append(param)
            group["param_names"].append(param_name)

        rank, _ = get_dist_info()
        if rank == 0:
            summary = {
                group_key: {
                    "param_names": group["param_names"],
                    "lr_scale": group["lr_scale"],
                    "lr": group["lr"],
                    "weight_decay": group["weight_decay"],
                }
                for group_key, group in groups.items()
            }
            print("Param groups = %s" % json.dumps(summary, indent=2))

        params.extend(groups.values())
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/prepare_rpe.py
================================================
import torch
import numpy as np
from scipy import interpolate
from mmcv.runner import get_dist_info
import torch.nn as nn
def rpe_index(window_size):
    """Build the relative-position index for a token grid plus a cls token.

    Args:
        window_size (tuple[int, int]): Grid size ``(Wh, Ww)``.

    Returns:
        Tensor: ``(Wh*Ww + 1, Wh*Ww + 1)`` integer indices into a bias table
        of ``(2*Wh-1)*(2*Ww-1) + 3`` rows. Grid-to-grid pairs use the first
        ``(2*Wh-1)*(2*Ww-1)`` rows. The last three rows are used for
        cls-to-token, token-to-cls and cls-to-cls, in that order.
    """
    win_h, win_w = window_size[0], window_size[1]
    row_span = 2 * win_w - 1
    num_relative_distance = (2 * win_h - 1) * row_span + 3

    grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
    flat = grid.flatten(1)  # 2, Wh*Ww
    rel = (flat.unsqueeze(2) - flat.unsqueeze(1)).permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    # shift offsets to be non-negative, then linearise (dh, dw) row-major
    rel[:, :, 0] += win_h - 1
    rel[:, :, 1] += win_w - 1
    rel[:, :, 0] *= row_span

    num_tokens = win_h * win_w + 1
    index = torch.zeros(size=(num_tokens, num_tokens), dtype=rel.dtype)
    index[1:, 1:] = rel.sum(-1)
    index[0, :] = num_relative_distance - 3
    index[:, 0] = num_relative_distance - 2
    index[0, 0] = num_relative_distance - 1
    return index
def prepare_rpe(rel_pos_bias, src_patch_shape, dst_patch_shape):
    """Expand a relative position bias table into a dense per-head bias.

    If the grid changes, the grid part of the table is bicubically resized
    from ``(2*src_h-1, 2*src_w-1)`` to ``(2*dst_h-1, 2*dst_w-1)``; the last 3
    (cls-related) rows are kept as-is. The table is then gathered with
    ``rpe_index(dst_patch_shape)``.

    Args:
        rel_pos_bias (Tensor): ``((2*src_h-1)*(2*src_w-1) + 3, num_heads)``
            table.
        src_patch_shape (tuple[int, int]): Grid the table was built for.
        dst_patch_shape (tuple[int, int]): Grid to build the bias for.

    Returns:
        Tensor: ``(num_heads, dst_h*dst_w + 1, dst_h*dst_w + 1)``.
    """
    # cls->token, token->cls and cls->cls entries (see rpe_index)
    num_extra_tokens = 3
    if dst_patch_shape[0] != src_patch_shape[0] or dst_patch_shape[1] != src_patch_shape[1]:
        src_size_0 = src_patch_shape[0] * 2 - 1
        src_size_1 = src_patch_shape[1] * 2 - 1
        dst_size_0 = dst_patch_shape[0] * 2 - 1
        dst_size_1 = dst_patch_shape[1] * 2 - 1
        # Slice with an explicit split point (not ``[-n:]``) so n == 0 is safe.
        num_grid_rows = rel_pos_bias.shape[0] - num_extra_tokens
        extra_tokens = rel_pos_bias[num_grid_rows:, :]
        grid_bias = rel_pos_bias[:num_grid_rows, :]
        num_heads = grid_bias.shape[-1]
        grid_bias = grid_bias.reshape(1, src_size_0, src_size_1, num_heads).permute(0, 3, 1, 2)
        # Explicit target size: the old scale_factor used dst_size_1 / dst_size_1
        # (always 1), so the second axis was never resized; float scale
        # factors can also be off by one after flooring.
        grid_bias = nn.functional.interpolate(
            grid_bias, size=(dst_size_0, dst_size_1), mode='bicubic', align_corners=False)
        grid_bias = grid_bias.permute(0, 2, 3, 1).reshape(-1, num_heads)
        new_rel_pos_bias = torch.cat((grid_bias, extra_tokens), dim=0)
    else:
        new_rel_pos_bias = rel_pos_bias
    # get rpe_index
    relative_position_index = rpe_index(dst_patch_shape)
    num_tokens = dst_patch_shape[0] * dst_patch_shape[1] + 1
    new_rel_pos_bias = new_rel_pos_bias[relative_position_index.view(-1)].view(
        num_tokens, num_tokens, -1)  # Wh*Ww+1, Wh*Ww+1, nH
    return new_rel_pos_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww+1, Wh*Ww+1
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/register_backbone.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from mmcv_custom import load_checkpoint
from mmdet.utils import get_root_logger
from mmdet.models.builder import BACKBONES
from models import VisionTransformer
from .prepare_rpe import prepare_rpe
import time
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each patch.

    Non-square inputs are supported: the patch grid is
    ``(img_size[0] // patch_size, img_size[1] // patch_size)``.
    """

    def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        grid = (img_size[0] // patch_size, img_size[1] // patch_size)
        self.num_patches_w, self.num_patches_h = grid
        self.patch_shape = grid
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid[0] * grid[1]
        # kernel == stride == patch_size: one linear projection per patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, mask=None):
        """Project ``x`` of shape (B, C, H, W) to (B, embed_dim, H', W').

        ``mask`` is accepted for interface compatibility and ignored.
        """
        _batch, _chans, _height, _width = x.shape  # rejects non-4-D input
        return self.proj(x)
@BACKBONES.register_module()
class VisionTransformer(VisionTransformer):
    """ViT detection backbone built on the project ``VisionTransformer``.

    Adds a rectangular ``PatchEmbed``, absolute or sin-cos position
    embeddings sized for that patch grid, optional pyramid adapters
    (``fpn1``-``fpn4``) for patch sizes 16 and 8, stage freezing, and
    pretrained-weight loading through ``load_checkpoint``.
    """

    def __init__(self,
                 img_size,
                 patch_size,
                 embed_dim,
                 in_chans=3,
                 with_fpn=True,
                 frozen_stages=-1,
                 out_indices=[3, 5, 7, 11],
                 out_with_norm=False,
                 use_checkpoint=False,
                 **kwargs):
        """
        Args:
            img_size (list[int]): Input size; a one-element list is expanded
                to a square size.
            patch_size (int): Patch size. FPN adapters exist only for 16 and 8.
            embed_dim (int): Token embedding dimension.
            in_chans (int): Number of input channels.
            with_fpn (bool): Build the ``fpn1``-``fpn4`` adapters. Forced to
                False when ``patch_size`` has no adapters.
            frozen_stages (int): ``-1`` freezes nothing. ``k >= 0`` freezes the
                patch embedding and the first ``k`` blocks (see
                ``_freeze_stages``).
            out_indices (list[int]): Blocks whose outputs are returned.
            out_with_norm (bool): If False, ``self.norm`` becomes identity.
            use_checkpoint (bool): Run blocks through
                ``torch.utils.checkpoint``.
            **kwargs: Forwarded to the parent ``VisionTransformer``.
        """
        super(VisionTransformer, self).__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            **kwargs)
        # support non-square image as input
        if len(img_size) == 1:
            img_size = img_size * 2
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        if self.use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        elif self.use_sincos_pos_emb:
            self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)
        else:
            self.pos_embed = None

        self.patch_size = patch_size
        self.with_fpn = with_fpn
        self.frozen_stages = frozen_stages
        self.out_indices = out_indices
        self.use_checkpoint = use_checkpoint

        if not out_with_norm:
            self.norm = nn.Identity()

        if with_fpn and patch_size == 16:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                nn.SyncBatchNorm(embed_dim),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn3 = nn.Identity()
            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif with_fpn and patch_size == 8:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Identity()
            self.fpn3 = nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
            )
            self.fpn4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=4, stride=4),
            )
        else:
            # No fpn* modules are built here, so forward() must not use them
            # (previously with_fpn=True plus another patch size crashed there).
            self.with_fpn = False
            logger = get_root_logger()
            logger.info('Build model without FPN.')

    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., decode=False):
        """Return a frozen 2-D sin-cos position embedding.

        The result is an ``nn.Parameter`` of shape ``(1, 1 + h * w, embed_dim)``
        with ``requires_grad=False``. ``(h, w) = self.patch_embed.patch_shape``.
        Row 0 is all zeros (cls token). ``decode`` is accepted but unused.
        """
        h, w = self.patch_embed.patch_shape
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]

        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        pos_embed.requires_grad = False
        return pos_embed

    def train(self, mode=True):
        """Convert the model into training mode while keep layers freezed.

        Returns ``self``, matching :meth:`torch.nn.Module.train`.
        """
        super(VisionTransformer, self).train(mode)
        self._freeze_stages()
        if self.pos_embed is not None:
            if self.pos_embed.requires_grad:
                print("=================pos_embed update ================")
            else:
                print("=================pos_embed static ================")
        return self

    def _freeze_stages(self):
        """Put frozen parts in eval mode and disable their gradients.

        With ``frozen_stages >= 0``: patch embedding, ``cls_token``, sin-cos
        ``pos_embed`` and ``pos_drop``. Then blocks ``0 .. frozen_stages - 1``,
        plus ``self.norm`` when every block is frozen.
        """
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            self.cls_token.requires_grad = False
            if self.pos_embed is not None and self.use_sincos_pos_emb == True:
                self.pos_embed.requires_grad = False
            self.pos_drop.eval()

        for i in range(1, self.frozen_stages + 1):
            if i == len(self.blocks):
                norm_layer = self.norm
                norm_layer.eval()
                for param in norm_layer.parameters():
                    param.requires_grad = False

            m = self.blocks[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights. A path
                that is not an existing file is logged and random init is
                kept. Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            self.apply(self._init_weights)
            logger = get_root_logger()
            if os.path.isfile(pretrained):
                load_checkpoint(self, pretrained, strict=False, logger=logger)
            else:
                logger.info(f"checkpoint path {pretrained} is invalid, we skip it and initialize net randomly")
        elif pretrained is None:
            self.apply(self._init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def interpolate_pos_encoding(self, x, w, h):
        """Return ``pos_embed`` resized to the patch grid of a ``w`` x ``h`` input.

        The first (cls) entry is kept. The remaining entries are reshaped to
        ``(num_patches_w, num_patches_h)`` and bicubically rescaled.
        ``pos_embed`` is returned unchanged if the grid already matches.
        """
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        if npatch == N and w0 == self.patch_embed.num_patches_w and h0 == self.patch_embed.num_patches_h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, self.patch_embed.num_patches_w, self.patch_embed.num_patches_h, dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / self.patch_embed.num_patches_w, h0 / self.patch_embed.num_patches_h),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(self, x):
        """Return a tuple with one feature map per entry of ``out_indices``.

        Each map is ``reshape(B, -1, H // patch_size, W // patch_size)`` of the
        block's patch tokens (cls token dropped) after ``self.norm``. With
        ``with_fpn``, the i-th map is then passed through ``fpn{i+1}``.
        """
        B, _, H, W = x.shape
        Hp, Wp = H // self.patch_size, W // self.patch_size
        x = self.prepare_tokens(x)
        features = []
        if self.relative_position_bias_table is None:
            x_rpe = None
        else:
            # NOTE(review): the grid is passed as (Wp, Hp) when H <= W; confirm
            # this matches the (Hp, Wp) token layout reshaped below.
            dst_rpe_shape = (Wp, Hp) if H <= W else (Hp, Wp)
            x_rpe = prepare_rpe(self.relative_position_bias_table, self.patch_embed.patch_shape, dst_rpe_shape)
        for i, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, x_rpe)
            else:
                x = blk(x, x_rpe)
            if i in self.out_indices:
                xp = self.norm(x[:, 1:, :]).permute(0, 2, 1).reshape(B, -1, Hp, Wp)
                features.append(xp.contiguous())

        if self.with_fpn:
            ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
            for i in range(len(features)):
                features[i] = ops[i](features[i])
        return tuple(features)
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/runner/__init__.py
================================================
# Copyright (c) Open-MMLab. All rights reserved.
from .checkpoint import save_checkpoint
from .epoch_based_runner import EpochBasedRunnerAmp
__all__ = [
'EpochBasedRunnerAmp', 'save_checkpoint'
]
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/runner/checkpoint.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Copy-paste from mmcv library:
https://github.com/open-mmlab/mmcv/
"""
import os.path as osp
import time
import torch
import mmcv
try:
import apex
except:
print('apex is not installed')
from tempfile import TemporaryDirectory
from torch.optim import Optimizer
from mmcv.parallel import is_module_wrapper
from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 4 fields: ``meta``, ``state_dict``,
    ``optimizer`` and ``amp``. By default ``meta`` will contain version
    and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename. A ``pavi://`` prefix uploads the
            checkpoint through pavi modelcloud; anything else is written to
            the local filesystem, creating the parent directory if needed.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. A
            dict of optimizers is saved as a dict of their state dicts.
        meta (dict, optional): Metadata to be saved in checkpoint.

    Raises:
        TypeError: If ``meta`` is neither None nor a dict.
        ImportError: If ``filename`` uses ``pavi://`` and pavi is missing.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    # save amp state dict in the checkpoint
    # NOTE(review): ``apex`` is only bound if the module-level import
    # succeeded; without apex this line raises NameError.
    checkpoint['amp'] = apex.amp.state_dict()

    if filename.startswith('pavi://'):
        try:
            # Imported as ``modelcloud`` because that is the name used below;
            # the previous ``modelscloud`` import left it undefined.
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        # Separate name so the nn.Module ``model`` is not shadowed.
        try:
            remote_model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            remote_model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            remote_model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/runner/epoch_based_runner.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Copy-paste from mmcv library:
https://github.com/open-mmlab/mmcv/
"""
import os.path as osp
import platform
import shutil
import torch
import mmcv
try:
import apex
except:
print('apex is not installed')
from torch.optim import Optimizer
from mmcv.runner import RUNNERS, EpochBasedRunner
from .checkpoint import save_checkpoint
@RUNNERS.register_module()
class EpochBasedRunnerAmp(EpochBasedRunner):
    """Epoch-based Runner with AMP support.

    This runner trains models epoch by epoch. Checkpoints are written with
    the ``save_checkpoint`` helper from the local ``.checkpoint`` module, and
    ``resume`` additionally restores an ``'amp'`` entry through
    ``apex.amp.load_state_dict`` when the checkpoint contains one.
    """

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None. A dict passed here is updated
                in place.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint (a copy is
                made instead on Windows). Defaults to True.

        Raises:
            TypeError: If ``meta`` is neither a dict nor None.
        """
        if meta is None:
            meta = {}
        elif not isinstance(meta, dict):
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)
        # Record the progress counters *after* merging ``self.meta``: if the
        # runner meta happens to carry 'epoch'/'iter' keys, they must not
        # overwrite the current values (upstream mmcv uses the same order).
        meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                # relative link target: ``filename`` lives next to dst_file
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Restore epoch/iter counters, optimizer and AMP state from a checkpoint.

        Args:
            checkpoint: Checkpoint to load; passed to ``self.load_checkpoint``.
            resume_optimizer (bool): Whether to load the optimizer state(s)
                when the checkpoint contains an ``'optimizer'`` entry.
            map_location: ``'default'`` maps storages onto the current CUDA
                device when CUDA is available (and uses the loader's default
                otherwise); any other value is forwarded unchanged.

        Raises:
            TypeError: If ``self.optimizer`` is neither an ``Optimizer`` nor a
                dict of optimizers while optimizer state is being resumed.
            ImportError: If the checkpoint holds an AMP state but apex is not
                installed.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        if 'amp' in checkpoint:
            # The module-level apex import is optional; import it here so a
            # missing install surfaces as a clear ImportError instead of a
            # NameError on the unbound ``apex`` name.
            try:
                import apex
            except ImportError as err:
                raise ImportError(
                    'The checkpoint contains an apex AMP state dict, but apex '
                    'is not installed.') from err
            apex.amp.load_state_dict(checkpoint['amp'])
            self.logger.info('load amp state dict')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/test.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Mostly copy-paste from mmdetection library:
https://github.com/open-mmlab/mmdetection/blob/master/tools/test.py
"""
import argparse
import os
import warnings
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)
from mmdet.apis import multi_gpu_test, single_gpu_test
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
from mmdet.models import build_detector
def parse_args():
    """Parse command-line arguments for detector testing/evaluation.

    Side effect: exports ``--local_rank`` as the ``LOCAL_RANK`` environment
    variable when that variable is not already set.

    Returns:
        argparse.Namespace: The parsed arguments. When only the deprecated
        ``--options`` is given, its value is copied to ``eval_options``.

    Raises:
        ValueError: If both ``--options`` and ``--eval-options`` are given.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--out', help='output result file in pickle format')
    # NOTE: adjacent string literals are concatenated verbatim, so every
    # piece of a multi-line help text must carry its own separating space.
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--show-score-thr',
        type=float,
        default=0.3,
        help='score threshold (default: 0.3)')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecated), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot be both '
            'specified, --options is deprecated in favor of --eval-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options
    return args
def _clear_pretrained(cfg):
    """Disable ``pretrained`` loading for the model and any RFP neck backbones.

    At test time the weights come from the checkpoint loaded in ``main``.
    """
    cfg.model.pretrained = None
    if cfg.model.get('neck'):
        if isinstance(cfg.model.neck, list):
            for neck_cfg in cfg.model.neck:
                if neck_cfg.get('rfp_backbone'):
                    if neck_cfg.rfp_backbone.get('pretrained'):
                        neck_cfg.rfp_backbone.pretrained = None
        elif cfg.model.neck.get('rfp_backbone'):
            if cfg.model.neck.rfp_backbone.get('pretrained'):
                cfg.model.neck.rfp_backbone.pretrained = None


def _prepare_test_data_cfg(cfg):
    """Switch ``cfg.data.test`` (one dataset dict or a list of them) to test mode.

    Pops ``samples_per_gpu`` from the test dataset config(s). When batching
    more than one image per GPU, ``ImageToTensor`` in the test pipeline(s) is
    replaced via ``replace_ImageToTensor``.

    Returns:
        int: ``samples_per_gpu`` (the maximum over a list), defaulting to 1.
    """
    samples_per_gpu = 1
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
    return samples_per_gpu


def main():
    """Test (and optionally evaluate) a detector checkpoint.

    Builds the test dataset and dataloader from the config, builds the
    detector and loads ``args.checkpoint`` into it. It then runs
    ``single_gpu_test`` (``--launcher none``) or ``multi_gpu_test``. On rank 0
    the results are dumped (``--out``), formatted (``--format-only``) and/or
    evaluated (``--eval``).

    Raises:
        AssertionError: If no output operation was requested.
        ValueError: If ``--eval`` and ``--format-only`` are both given, or if
            ``--out`` does not end with ``.pkl``/``.pickle``.
    """
    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        # Name the flag as it is actually spelled on the command line.
        raise ValueError('--eval and --format-only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    _clear_pretrained(cfg)
    # in case the test dataset is concatenated
    samples_per_gpu = _prepare_test_data_cfg(cfg)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)
    # old versions did not save class info in checkpoints, this walkaround is
    # for backward compatibility
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
                                  args.show_score_thr)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    rank, _ = get_dist_info()
    if rank == 0:
        if args.out:
            print(f'\nwriting results to {args.out}')
            mmcv.dump(outputs, args.out)
        kwargs = {} if args.eval_options is None else args.eval_options
        if args.format_only:
            dataset.format_results(outputs, **kwargs)
        if args.eval:
            eval_kwargs = cfg.get('evaluation', {}).copy()
            # hard-code way to remove EvalHook args
            for key in [
                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                    'rule'
            ]:
                eval_kwargs.pop(key, None)
            eval_kwargs.update(dict(metric=args.eval, **kwargs))
            print(dataset.evaluate(outputs, **eval_kwargs))


if __name__ == '__main__':
    main()
================================================
FILE: downstream_tasks/detection/evaluation/object_detection/train.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Mostly copy-paste from mmdetection library:
https://github.com/open-mmlab/mmdetection/blob/master/tools/train.py
"""
import argparse
import copy
import os
import os.path as osp
import time
import warnings
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash
from mmdet import __version__
from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger
def parse_args():
    """Parse command-line arguments for detector training.

    Side effect: exports ``--local_rank`` as the ``LOCAL_RANK`` environment
    variable when that variable is not already set.

    Returns:
        argparse.Namespace: The parsed arguments. When only the deprecated
        ``--options`` is given, its value is copied to ``cfg_options``.

    Raises:
        ValueError: If both ``--options`` and ``--cfg-options`` are given.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    # NOTE: adjacent string literals are concatenated verbatim, so sentence
    # boundaries inside a multi-line help text need explicit punctuation.
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecated), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both '
            'specified, --options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args
def main():
    """Train a detector described by a config file.

    The config from ``args.config`` is patched with the command-line
    overrides (``--cfg-options``, ``--work-dir``, ``--resume-from`` and GPU
    selection). The distributed environment is initialised unless
    ``--launcher`` is ``'none'``. The work dir is created, the resolved config
    is dumped into it and a timestamped log file is opened. Environment info,
    config text, seed and experiment name are gathered into ``meta``, which is
    passed to ``train_detector`` together with the model and datasets.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules listed by name in the config
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    config_name = osp.basename(args.config)
    # work_dir priority: CLI flag > value in the config > config filename
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        cfg.work_dir = osp.join('./work_dirs', osp.splitext(config_name)[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is None:
        cfg.gpu_ids = range(1 if args.gpus is None else args.gpus)
    else:
        cfg.gpu_ids = args.gpu_ids

    # the distributed env must exist before the logger is created
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)
        # in distributed mode, one GPU id per process in the world
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    cfg.dump(osp.join(cfg.work_dir, config_name))

    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    logger = get_root_logger(
        log_file=osp.join(cfg.work_dir, f'{timestamp}.log'),
        log_level=cfg.log_level)

    env_info = '\n'.join(f'{key}: {value}' for key, value in collect_env().items())
    dash_line = '-' * 60 + '\n'
    logger.info(f'Environment info:\n{dash_line}{env_info}\n{dash_line}')
    # run-level information that is forwarded to train_detector
    meta = dict(env_info=env_info, config=cfg.pretty_text)

    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = config_name

    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        # a second workflow stage runs on the val split with the train pipeline
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))

    if cfg.checkpoint_config is not None:
        # store mmdet version (+ short git hash) and class names in
        # checkpoint meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__ + get_git_hash()[:7],
            CLASSES=datasets[0].CLASSES)
    # expose class names on the model for visualization convenience
    model.CLASSES = datasets[0].CLASSES

    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=not args.no_validate,
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()
================================================
FILE: downstream_tasks/detection/loader.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import math
import numpy as np
from torchvision.datasets import ImageFolder
class ImageFolderInstance(ImageFolder):
    """``ImageFolder`` whose samples are ``(image, target, index)`` triples."""

    def __getitem__(self, index):
        sample, label = super().__getitem__(index)
        # append the dataset index so callers can identify the sample
        return sample, label, index
class ImageFolderMask(ImageFolder):
    """``ImageFolder`` that also returns one boolean patch mask per image view.

    ``__getitem__`` returns the parent's output tuple extended with a list of
    ``(H // patch_size, W // patch_size)`` boolean masks, one for each entry
    of ``output[0]`` that has a usable ``shape``.

    Args:
        patch_size (int): Patch size used to convert pixel sizes to patch
            counts.
        pred_ratio (float | list[float]): Target fraction of masked patches.
            A single-element list is unwrapped to a scalar.
        pred_ratio_var (float | list[float]): Half-width of the uniform jitter
            around ``pred_ratio``. It is broadcast to a list when
            ``pred_ratio`` is a list.
        pred_aspect_ratio (iterable of float): ``(min, max)`` aspect ratios
            for block masking; sampled uniformly in log space.
        pred_shape (str): ``'block'`` (BEiT-style blocks) or ``'rand'``.
        pred_start_epoch (int): Before this epoch (see ``set_epoch``), the
            masking ratio is 0.
    """

    def __init__(self, *args, patch_size, pred_ratio, pred_ratio_var, pred_aspect_ratio,
                 pred_shape='block', pred_start_epoch=0, **kwargs):
        super(ImageFolderMask, self).__init__(*args, **kwargs)
        self.psz = patch_size
        self.pred_ratio = pred_ratio[0] if isinstance(pred_ratio, list) and \
            len(pred_ratio) == 1 else pred_ratio
        self.pred_ratio_var = pred_ratio_var[0] if isinstance(pred_ratio_var, list) and \
            len(pred_ratio_var) == 1 else pred_ratio_var
        if isinstance(self.pred_ratio, list) and not isinstance(self.pred_ratio_var, list):
            self.pred_ratio_var = [self.pred_ratio_var] * len(self.pred_ratio)
        self.log_aspect_ratio = tuple(math.log(x) for x in pred_aspect_ratio)
        self.pred_shape = pred_shape
        self.pred_start_epoch = pred_start_epoch

    def get_pred_ratio(self):
        """Sample the fraction of patches to mask for one view.

        Returns 0 while ``self.epoch`` (if set) is below ``pred_start_epoch``.
        With a list of ratios, one jittered value is drawn per entry and one
        of them is picked at random.
        """
        if hasattr(self, 'epoch') and self.epoch < self.pred_start_epoch:
            return 0

        if isinstance(self.pred_ratio, list):
            pred_ratio = []
            for prm, prv in zip(self.pred_ratio, self.pred_ratio_var):
                assert prm >= prv
                pr = random.uniform(prm - prv, prm + prv) if prv > 0 else prm
                pred_ratio.append(pr)
            pred_ratio = random.choice(pred_ratio)
        else:
            assert self.pred_ratio >= self.pred_ratio_var
            pred_ratio = random.uniform(self.pred_ratio - self.pred_ratio_var, self.pred_ratio +
                                        self.pred_ratio_var) if self.pred_ratio_var > 0 else self.pred_ratio

        return pred_ratio

    def set_epoch(self, epoch):
        """Record the current epoch (consulted by ``get_pred_ratio``)."""
        self.epoch = epoch

    def _block_mask(self, H, W, high):
        """Mask random rectangles of patches until ~``high`` patches are masked."""
        # following BEiT (https://arxiv.org/abs/2106.08254), see at
        # https://github.com/microsoft/unilm/blob/b94ec76c36f02fb2b0bf0dcb0b8554a2185173cd/beit/masking_generator.py#L55
        mask = np.zeros((H, W), dtype=bool)
        mask_count = 0
        while mask_count < high:
            max_mask_patches = high - mask_count

            delta = 0
            for _ in range(10):
                low = (min(H, W) // 3) ** 2
                target_area = random.uniform(low, max_mask_patches)
                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))
                if w < W and h < H:
                    top = random.randint(0, H - h)
                    left = random.randint(0, W - w)

                    # basic slicing -> a view into ``mask``
                    region = mask[top: top + h, left: left + w]
                    newly_masked = h * w - int(region.sum())
                    if 0 < newly_masked <= max_mask_patches:
                        # Equivalent to flipping every still-unmasked patch of
                        # the rectangle and counting the flips one by one.
                        region[...] = True
                        delta += newly_masked

                if delta > 0:
                    break

            if delta == 0:
                # no attempt could add patches: give up on this mask
                break
            mask_count += delta
        return mask

    @staticmethod
    def _rand_mask(H, W, high):
        """Mask ``int(high)`` patches chosen uniformly at random."""
        mask = np.hstack([
            np.zeros(H * W - int(high)),
            np.ones(int(high)),
        ]).astype(bool)
        np.random.shuffle(mask)
        return mask.reshape(H, W)

    def __getitem__(self, index):
        output = super(ImageFolderMask, self).__getitem__(index)

        masks = []
        for img in output[0]:
            try:
                # indexing shape[1]/shape[2] assumes (C, H, W) image tensors
                H, W = img.shape[1] // self.psz, img.shape[2] // self.psz
            except (AttributeError, IndexError, TypeError):
                # skip non-image entries (no usable ``shape``)
                continue

            high = self.get_pred_ratio() * H * W

            if self.pred_shape == 'block':
                mask = self._block_mask(H, W, high)
            elif self.pred_shape == 'rand':
                mask = self._rand_mask(H, W, high)
            else:
                # An ``assert`` would be stripped under ``python -O`` and leave
                # ``mask`` unbound (or stale from a previous view).
                raise NotImplementedError(
                    'unsupported pred_shape: {}'.format(self.pred_shape))

            masks.append(mask)

        return output + (masks,)
================================================
FILE: downstream_tasks/detection/models/__init__.py
================================================
from .vision_transformer import VisionTransformer, vit_tiny, vit_small, vit_base, vit_large
from .swin_transformer import SwinTransformer, swin_tiny, swin_small, swin_base, swin_large
================================================
FILE: downstream_tasks/detection/models/head.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import utils
from utils import trunc_normal_
class CSyncBatchNorm(nn.SyncBatchNorm):
    """Centering SyncBatchNorm.

    The returned output is computed with the *running* statistics (an
    eval-mode batch-norm pass). Unless ``with_var`` is set, the running
    variance is reset to ones first, so only the running mean (the "center")
    is subtracted. In training mode, a second pass then updates the running
    statistics with the current batch. In eval mode the statistics are left
    untouched, and the module's train/eval mode is always restored.

    Args:
        with_var (bool): Keep the running variance instead of pinning it to 1.
    """

    def __init__(self,
                 *args,
                 with_var=False,
                 **kwargs):
        super(CSyncBatchNorm, self).__init__(*args, **kwargs)
        self.with_var = with_var

    def forward(self, x):
        was_training = self.training
        try:
            # center norm: normalise with the running statistics
            self.training = False
            if not self.with_var:
                self.running_var = torch.ones_like(self.running_var)
            normed_x = super(CSyncBatchNorm, self).forward(x)
            # update center (only while training; eval must not mutate stats)
            if was_training:
                self.training = True
                _ = super(CSyncBatchNorm, self).forward(x)
        finally:
            self.training = was_training
        return normed_x
class PSyncBatchNorm(nn.SyncBatchNorm):
    """SyncBatchNorm whose statistics are synchronised within groups of ranks.

    Ranks ``0 .. world_size - 1`` are split into consecutive "bunches" of
    ``min(bunch_size, world_size)`` ranks. One process group is created per
    bunch, and this layer is bound to the bunch containing the current rank.
    """

    def __init__(self,
                 *args,
                 bunch_size,
                 **kwargs):
        world_size = utils.get_world_size()
        procs_per_bunch = min(bunch_size, world_size)
        assert world_size % procs_per_bunch == 0
        n_bunch = world_size // procs_per_bunch

        ranks = list(range(world_size))
        print(f'---ALL RANKS----\n{ranks}')
        rank_groups = [
            ranks[idx * procs_per_bunch:(idx + 1) * procs_per_bunch]
            for idx in range(n_bunch)
        ]
        print(f'---RANK GROUPS----\n{rank_groups}')
        # NOTE(review): a group is created for *every* bunch on every rank,
        # presumably because new_group must be entered by all processes.
        process_groups = [torch.distributed.new_group(pids) for pids in rank_groups]
        own_group = process_groups[utils.get_rank() // procs_per_bunch]
        print(f'---CURRENT GROUP----\n{own_group}')
        super().__init__(*args, process_group=own_group, **kwargs)
class CustomSequential(nn.Sequential):
    """``nn.Sequential`` whose BatchNorm layers normalise the *last* axis.

    For inputs with more than two dimensions, the last axis is moved to
    position 1 (where BatchNorm expects channels) before a BatchNorm layer,
    and moved back afterwards. Every other layer, and any layer applied to a
    2-D input, is called unchanged.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

    def forward(self, input):
        out = input
        for layer in self:
            ndim = len(out.shape)
            if ndim <= 2 or not isinstance(layer, self.bn_types):
                out = layer(out)
                continue
            # (N, d1, ..., C) -> (N, C, d1, ...) and back again
            to_channels_first = (0, ndim - 1, *range(1, ndim - 1))
            to_channels_last = (0, *range(2, ndim), 1)
            out = layer(out.permute(*to_channels_first)).permute(*to_channels_last)
        return out
class DINOHead(nn.Module):
    """Projection head: MLP, optional weight-normalised last layer, optional norm.

    Args:
        in_dim (int): Input feature dimension.
        out_dim (int): Output dimension.
        norm (str | None): Norm placed after each hidden Linear; one of
            'bn', 'syncbn', 'csyncbn', 'psyncbn', 'ln' or None.
        act (str): Hidden activation, 'relu' or 'gelu'.
        last_norm (str | None): Norm applied to the final output, built with
            ``affine=False`` plus ``**kwargs``.
        nlayers (int): Number of Linear layers in the MLP (clamped to >= 1).
        hidden_dim (int): Width of the hidden layers.
        bottleneck_dim (int): If > 0, the MLP ends at this width. Its output
            is L2-normalised and fed to a bias-free, weight-normalised Linear
            producing ``out_dim``. Otherwise the MLP outputs ``out_dim``
            directly.
        norm_last_layer (bool): Freeze the weight-norm magnitude
            (``weight_g``, initialised to 1) of the last layer.
        **kwargs: Extra keyword arguments for the ``last_norm`` constructor.

    Raises:
        ValueError: For an unknown ``norm``/``last_norm``/``act`` name.
    """

    def __init__(self, in_dim, out_dim, norm=None, act='gelu', last_norm=None,
                 nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True, **kwargs):
        super().__init__()
        # Build (and thereby validate) the first hidden norm/activation and
        # the output norm up front, in the original order.
        first_norm = self._build_norm(norm, hidden_dim)
        last_norm_layer = self._build_norm(last_norm, out_dim, affine=False, **kwargs)
        first_act = self._build_act(act)
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            if bottleneck_dim > 0:
                self.mlp = nn.Linear(in_dim, bottleneck_dim)
            else:
                self.mlp = nn.Linear(in_dim, out_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if first_norm is not None:
                layers.append(first_norm)
            layers.append(first_act)
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                # A fresh norm per hidden layer: reusing one module would share
                # affine parameters and running statistics across layers.
                hidden_norm = self._build_norm(norm, hidden_dim)
                if hidden_norm is not None:
                    layers.append(hidden_norm)
                layers.append(self._build_act(act))
            if bottleneck_dim > 0:
                layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            else:
                layers.append(nn.Linear(hidden_dim, out_dim))
            self.mlp = CustomSequential(*layers)
        # initialise the MLP Linears before the weight-normed last layer exists
        self.apply(self._init_weights)

        if bottleneck_dim > 0:
            self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
            self.last_layer.weight_g.data.fill_(1)
            if norm_last_layer:
                self.last_layer.weight_g.requires_grad = False
        else:
            self.last_layer = None

        self.last_norm = last_norm_layer

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero biases for Linear layers."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        if self.last_layer is not None:
            x = nn.functional.normalize(x, dim=-1, p=2)
            x = self.last_layer(x)
        if self.last_norm is not None:
            x = self.last_norm(x)
        return x

    def _build_norm(self, norm, hidden_dim, **kwargs):
        """Instantiate the norm named by ``norm`` (or return None for None)."""
        if norm == 'bn':
            norm = nn.BatchNorm1d(hidden_dim, **kwargs)
        elif norm == 'syncbn':
            norm = nn.SyncBatchNorm(hidden_dim, **kwargs)
        elif norm == 'csyncbn':
            norm = CSyncBatchNorm(hidden_dim, **kwargs)
        elif norm == 'psyncbn':
            norm = PSyncBatchNorm(hidden_dim, **kwargs)
        elif norm == 'ln':
            norm = nn.LayerNorm(hidden_dim, **kwargs)
        elif norm is not None:
            # explicit raise: an assert would be stripped under ``python -O``
            raise ValueError("unknown norm type {}".format(norm))
        return norm

    def _build_act(self, act):
        """Instantiate the activation named by ``act``."""
        if act == 'relu':
            act = nn.ReLU()
        elif act == 'gelu':
            act = nn.GELU()
        else:
            raise ValueError("unknown act type {}".format(act))
        return act
class iBOTHead(DINOHead):
    """``DINOHead`` with a second output branch for patch tokens.

    For 2-D inputs it behaves exactly like ``DINOHead``. For token sequences
    ``x[:, 0]`` goes through the regular head and ``x[:, 1:]`` through the
    patch branch (``last_layer2``/``mlp2`` and ``last_norm2``). The two
    outputs are returned as a tuple. With ``shared_head`` the patch branch
    reuses the modules of the regular head.

    Args:
        patch_out_dim (int): Output dimension of the patch branch.
        shared_head (bool): Reuse the regular head's last projection and norm
            for patch tokens.
        Remaining arguments are forwarded to ``DINOHead``.
    """

    def __init__(self, *args, patch_out_dim=8192, norm=None, act='gelu', last_norm=None,
                 nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True,
                 shared_head=False, **kwargs):
        super().__init__(*args,
                         norm=norm,
                         act=act,
                         last_norm=last_norm,
                         nlayers=nlayers,
                         hidden_dim=hidden_dim,
                         bottleneck_dim=bottleneck_dim,
                         norm_last_layer=norm_last_layer,
                         **kwargs)

        has_bottleneck = bottleneck_dim > 0
        if shared_head:
            # patch tokens reuse the [CLS] branch's final projection and norm
            if has_bottleneck:
                self.last_layer2 = self.last_layer
            else:
                self.mlp2 = self.mlp[-1]
                self.last_layer2 = None
            self.last_norm2 = self.last_norm
            return

        if has_bottleneck:
            self.last_layer2 = nn.utils.weight_norm(nn.Linear(bottleneck_dim, patch_out_dim, bias=False))
            self.last_layer2.weight_g.data.fill_(1)
            if norm_last_layer:
                self.last_layer2.weight_g.requires_grad = False
        else:
            self.mlp2 = nn.Linear(hidden_dim, patch_out_dim)
            self.last_layer2 = None
        self.last_norm2 = self._build_norm(last_norm, patch_out_dim, affine=False, **kwargs)

    def forward(self, x):
        if x.dim() == 2:
            return super().forward(x)

        if self.last_layer is None:
            # no bottleneck: the MLP's final Linear is applied per branch
            hidden = self.mlp[:-1](x)
            cls_out = self.mlp[-1](hidden[:, 0])
            patch_out = self.mlp2(hidden[:, 1:])
        else:
            feats = nn.functional.normalize(self.mlp(x), dim=-1, p=2)
            cls_out = self.last_layer(feats[:, 0])
            patch_out = self.last_layer2(feats[:, 1:])

        if self.last_norm is not None:
            cls_out = self.last_norm(cls_out)
            patch_out = self.last_norm2(patch_out)

        return cls_out, patch_out
================================================
FILE: downstream_tasks/detection/models/swin_transformer.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Mostly copy-paste from the Swin-Transformer library:
https://github.com/facebookresearch/dino
https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
"""
import os
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from math import sqrt
from functools import partial
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear.

    Dropout with probability ``drop`` follows the activation and the second
    Linear. ``hidden_features`` and ``out_features`` default to
    ``in_features``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C); H and W must be multiples of
            ``window_size`` for the reshape to succeed.
        window_size (int): side length of each window.

    Returns:
        Tensor of shape (num_windows*B, window_size, window_size, C), ordered
        batch-major, then window row, then window column.
    """
    batch, height, width, channels = x.shape
    rows, cols = height // window_size, width // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # (B, rows, ws, cols, ws, C) -> (B, rows, cols, ws, ws, C)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)
def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by window partitioning into feature maps.

    Args:
        windows: tensor of shape (num_windows*B, window_size, window_size, C).
        window_size (int): side length of each window.
        H (int): height of the full feature map.
        W (int): width of the full feature map.

    Returns:
        Tensor of shape (B, H, W, C).
    """
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    grid = windows.view(batch, H // window_size, W // window_size, window_size, window_size, -1)
    # (B, rows, cols, ws, ws, C) -> (B, rows, ws, cols, ws, C) -> (B, H, W, C)
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, H, W, -1)
class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0

    ``forward`` returns both the projected output and the attention weights
    (after softmax, before attention dropout).
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        win_h, win_w = window_size[0], window_size[1]
        # learnable bias table: one row per relative (dy, dx) offset, one column per head
        num_offsets = (2 * win_h - 1) * (2 * win_w - 1)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(num_offsets, num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # pair-wise relative position index for each token inside the window
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        offsets = offsets.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        offsets[:, :, 0] += win_h - 1  # shift so offsets start from 0
        offsets[:, :, 1] += win_w - 1
        offsets[:, :, 0] *= 2 * win_w - 1  # row-major flattening of (dy, dx)
        self.register_buffer("relative_position_index", offsets.sum(-1))  # Wh*Ww, Wh*Ww

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        head_dim = C // self.num_heads
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q * self.scale) @ k.transpose(-2, -1)

        tokens = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(tokens, tokens, -1).permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + bias.unsqueeze(0)

        if mask is not None:
            n_windows = mask.shape[0]
            attn = attn.view(B_ // n_windows, n_windows, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn_weights = attn
        out = (self.attn_drop(attn) @ v).transpose(1, 2).reshape(B_, N, C)
        out = self.proj_drop(self.proj(out))
        return out, attn_weights

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        """Approximate FLOPs for one window with a token length of ``N``."""
        head_dim = self.dim // self.num_heads
        return (
            N * self.dim * 3 * self.dim            # qkv projection
            + self.num_heads * N * head_dim * N    # q @ k^T
            + self.num_heads * N * N * head_dim    # attn @ v
            + N * self.dim * self.dim              # output projection
        )

    @staticmethod
    def compute_macs(module, input, output):
        """Forward-hook style counter: add ``flops(N) * B`` to ``module.__flops__``."""
        batch, tokens, _ = input[0].shape
        module.__flops__ += module.flops(tokens) * batch
class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.H = input_resolution[0]
        self.W = input_resolution[1]
        # Shifted-window attention masks cached per feature-map side length H.
        # forward() treats the token grid as square, so W == H.
        self.attn_mask_dict = {}

    def create_attn_mask(self, H, W):
        """Build the additive SW-MSA mask for an H x W map (padded up to window multiples).

        Returns a tensor of shape (num_windows, window_size**2, window_size**2)
        holding 0 where query and key come from the same shifted region and
        -100 elsewhere.
        """
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1))  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        """Apply (shifted-)window attention and the MLP, each with a residual connection.

        Args:
            x: tokens of shape (B, L, C); L must be a perfect square because the
               grid side is taken as H = W = int(sqrt(L)).

        Returns:
            (x, attn): updated tokens of shape (B, L, C) and the attention map
            returned by ``self.attn``.
        """
        B, L, C = x.shape
        H = int(sqrt(L))
        W = H
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            # The original `H is self.attn_mask_dict.keys()` was always False, so the
            # mask was rebuilt on every call. Look it up properly, and rebuild it if
            # the cached copy lives on a different device than the input.
            attn_mask = self.attn_mask_dict.get(H)
            if attn_mask is None or attn_mask.device != x.device:
                attn_mask = self.create_attn_mask(H, W).to(x.device)
                self.attn_mask_dict[H] = attn_mask
        else:
            shifted_x = x
            attn_mask = None
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA
        attn_windows, attn = self.attn(x_windows, attn_mask)  # nW*B, window_size*window_size, C
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()
        x = x.view(B, H * W, C)
        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, attn

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        """Estimate the block's FLOPs at the construction-time ``input_resolution``."""
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
class PatchMerging(nn.Module):
    r"""Patch Merging Layer: each 2x2 token neighbourhood becomes one token.

    Spatial resolution is halved and the channel count goes from C to 2*C
    (concatenate to 4*C, normalize, then linearly reduce).

    Args:
        input_resolution (tuple[int]): Resolution of input feature (used by ``flops``).
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge 2x2 groups of tokens.

        Args:
            x: tensor of shape (B, L, C); the tokens are laid out on a square
               grid of side int(sqrt(L)).

        Returns:
            tensor of shape (B, ceil(side/2) ** 2, 2*C).
        """
        batch, n_tokens, channels = x.shape
        side = int(sqrt(n_tokens))
        grid = x.view(batch, side, side, channels)
        if side % 2 == 1:
            # zero-pad one extra row and column so the grid splits into 2x2 groups
            grid = F.pad(grid, (0, 0, 0, side % 2, 0, side % 2))
        # order: (row0, col0), (row1, col0), (row0, col1), (row1, col1)
        quadrants = [grid[:, r::2, c::2, :] for c in (0, 1) for r in (0, 1)]
        merged = torch.cat(quadrants, -1).view(batch, -1, 4 * channels)
        return self.reduction(self.norm(merged))

    def extra_repr(self) -> str:
        return "input_resolution={}, dim={}".format(self.input_resolution, self.dim)

    def flops(self):
        """Estimate FLOPs at the construction-time ``input_resolution``."""
        H, W = self.input_resolution
        input_term = H * W * self.dim
        reduction_term = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return input_term + reduction_term
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate, either
            shared by all blocks or one entry per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        # A tuple used to fall through to the scalar branch and was handed whole to
        # each block (where `drop_path > 0.` raises TypeError); accept both sequence types.
        per_block_rates = isinstance(drop_path, (list, tuple))
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 # alternate regular (W-MSA) and shifted (SW-MSA) windows
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if per_block_rates else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        """Run all blocks, then the optional downsample; returns the tokens."""
        for blk in self.blocks:
            x, _ = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def forward_with_features(self, x):
        """Like ``forward`` but also return each block's output (before downsampling)."""
        fea = []
        for blk in self.blocks:
            x, _ = blk(x)
            fea.append(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x, fea

    def forward_with_attention(self, x):
        """Like ``forward`` but also return each block's attention map."""
        attns = []
        for blk in self.blocks:
            x, attn = blk(x)
            attns.append(attn)
        if self.downsample is not None:
            x = self.downsample(x)
        return x, attns

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        """Sum of the blocks' and the downsample layer's FLOP estimates."""
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
class PatchEmbed(nn.Module):
    """Image to Patch Embedding via a strided convolution.

    ``forward`` returns a spatial map of shape (B, embed_dim, H', W'). When a
    norm layer is configured, it is applied per token and the result is
    reshaped back to that spatial layout.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        # NOTE: input size is not checked against img_size.
        feat = self.proj(x)
        if self.norm is not None:
            b, c, h, w = feat.shape
            tokens = self.norm(feat.flatten(2).transpose(1, 2))  # B, H'*W', C
            feat = tokens.transpose(1, 2).reshape(b, c, h, w)
        return feat

    def flops(self):
        """Estimate FLOPs of the projection (and norm) at the configured img_size."""
        rows, cols = self.patches_resolution
        kernel_area = self.patch_size[0] * self.patch_size[1]
        total = rows * cols * self.embed_dim * self.in_chans * kernel_area
        if self.norm is not None:
            total += rows * cols * self.embed_dim
        return total
class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
          https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size.
        patch_size (int | tuple(int)): Patch size.
        in_chans (int): Number of input channels.
        num_classes (int): Number of classes for classification head.
        embed_dim (int): Embedding dimension.
        depths (tuple(int)): Depth of Swin Transformer layers.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate.
        drop_path_rate (float): Stochastic depth rate.
        norm_layer (nn.Module): normalization layer.
        ape (bool): If True, add absolute position embedding to the patch embedding.
        patch_norm (bool): If True, add normalization after patch embedding.
        return_all_tokens (bool): Default for ``forward``: if True, return the pooled
            feature prepended to all normalized tokens rather than the pooled feature only.
        use_mean_pooling (bool): Accepted for interface compatibility; not used by this class.
        masked_im_modeling (bool): If True, create ``masked_embed``, the vector that
            ``mask_model`` writes into masked patch positions.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False, patch_norm=True,
                 return_all_tokens=False, use_mean_pooling=True, masked_im_modeling=False):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        self.return_all_tokens = return_all_tokens
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None)
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)
        # masked image modeling
        self.masked_im_modeling = masked_im_modeling
        if masked_im_modeling:
            self.masked_embed = nn.Parameter(torch.zeros(1, embed_dim))

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        # todo: to be implemented
        return {'relative_position_bias_table'}

    def forward(self, x, return_all_tokens=None, mask=None):
        """Encode images into a pooled feature (optionally plus all tokens).

        Args:
            x: image batch passed to ``self.patch_embed``.
            return_all_tokens: overrides ``self.return_all_tokens`` when not None.
            mask: optional boolean mask over the patch grid, applied via ``mask_model``
                (requires ``masked_im_modeling=True``).

        Returns:
            (B, C) pooled features, or (B, 1 + L, C) with the pooled feature first
            followed by the normalized final-stage tokens.
        """
        # patch linear embedding
        x = self.patch_embed(x)
        # mask image modeling
        if mask is not None:
            x = self.mask_model(x, mask)
        x = x.flatten(2).transpose(1, 2)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x_region = self.norm(x)  # B L C
        x = self.avgpool(x_region.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return_all_tokens = self.return_all_tokens if \
            return_all_tokens is None else return_all_tokens
        if return_all_tokens:
            return torch.cat([x.unsqueeze(1), x_region], dim=1)
        return x

    def get_selfattention(self, x, n=1):
        """Return the last block's attention map when n == 1, else a list over all blocks."""
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        if n == 1:
            return self.get_last_selfattention(x)
        else:
            return self.get_all_selfattention(x)

    def get_last_selfattention(self, x):
        for i, layer in enumerate(self.layers):
            if i < len(self.layers) - 1:
                x = layer(x)
            else:
                x, attns = layer.forward_with_attention(x)
                return attns[-1]

    def get_all_selfattention(self, x):
        attn_out = []
        for layer in self.layers:
            x, attns = layer.forward_with_attention(x)
            attn_out += attns
        return attn_out

    def get_intermediate_layers(self, x, n=1, return_patch_avgpool=False):
        """Return features from the last ``n`` blocks (across stages).

        Each entry is the average-pooled token (B, C) when ``return_patch_avgpool``
        is set, otherwise the pooled token prepended to that block's tokens.
        Blocks in the last stage are passed through ``self.norm`` first.

        Raises:
            ValueError: if ``n`` is not in ``[1, sum(self.depths)]`` (this used to
                surface as an UnboundLocalError).
        """
        num_blks = sum(self.depths)
        if not 1 <= n <= num_blks:
            raise ValueError(f"n must be in [1, {num_blks}], got {n}")
        start_idx = num_blks - n
        sum_cur = 0
        for i, d in enumerate(self.depths):
            sum_cur_new = sum_cur + d
            if start_idx >= sum_cur and start_idx < sum_cur_new:
                start_stage = i
                start_blk = start_idx - sum_cur
            sum_cur = sum_cur_new
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        # we will return the averaged token features from the `n` last blocks
        # note: there is no [CLS] token in Swin Transformer
        output = []
        for i, layer in enumerate(self.layers):
            x, fea = layer.forward_with_features(x)
            if i >= start_stage:
                for x_ in fea[start_blk:]:
                    if i == len(self.layers) - 1:  # use the norm in the last stage
                        x_ = self.norm(x_)
                    x_avg = torch.flatten(self.avgpool(x_.transpose(1, 2)), 1)  # B C
                    if return_patch_avgpool:
                        x_o = x_avg
                    else:
                        x_o = torch.cat((x_avg.unsqueeze(1), x_), dim=1)
                    output.append(x_o)
                # later stages contribute all of their blocks
                start_blk = 0
        return output

    def flops(self):
        """Estimate total FLOPs, printing the per-layer GFLOPs."""
        flops = 0
        flops += self.patch_embed.flops()
        # dist.get_rank() raises when no process group exists; in that case
        # (single-process use) print unconditionally.
        verbose = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
            if verbose:
                print(f"GFLOPs layer_{i}: {layer.flops() / 1e9}")
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops

    def init_weights(self, pretrained='', pretrained_layers=None, verbose=True):
        """Load matching weights from a checkpoint file, resizing position tables if needed.

        Args:
            pretrained: path to a state-dict file; nothing happens if it is not a file.
            pretrained_layers: top-level module names to load, or ``['*']`` for all.
                Defaults to an empty list (the old default ``[]`` raised IndexError).
            verbose: log every key that gets initialized.
        """
        if pretrained_layers is None:
            pretrained_layers = []
        if os.path.isfile(pretrained):
            # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
            pretrained_dict = torch.load(pretrained, map_location='cpu')
            logging.info(f'=> loading pretrained model {pretrained}')
            model_dict = self.state_dict()
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items()
                if k in model_dict.keys()
            }
            load_all = len(pretrained_layers) > 0 and pretrained_layers[0] == '*'
            need_init_state_dict = {}
            for k, v in pretrained_dict.items():
                # NOTE(review): the last two clauses make need_init True for any key
                # that does not contain both substrings, i.e. practically always.
                # Preserved as-is; confirm whether `and` was intended.
                need_init = (
                    k.split('.')[0] in pretrained_layers
                    or load_all
                    or 'relative_position_index' not in k
                    or 'attn_mask' not in k
                )
                if need_init:
                    if verbose:
                        logging.info(f'=> init {k} from {pretrained}')
                    if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():
                        relative_position_bias_table_pretrained = v
                        relative_position_bias_table_current = model_dict[k]
                        L1, nH1 = relative_position_bias_table_pretrained.size()
                        L2, nH2 = relative_position_bias_table_current.size()
                        if nH1 != nH2:
                            logging.info(f"Error in loading {k}, passing")
                        else:
                            if L1 != L2:
                                logging.info(
                                    '=> load_pretrained: resized variant: {} to {}'
                                    .format((L1, nH1), (L2, nH2))
                                )
                                S1 = int(L1 ** 0.5)
                                S2 = int(L2 ** 0.5)
                                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                                    size=(S2, S2),
                                    mode='bicubic')
                                v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
                    if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():
                        absolute_pos_embed_pretrained = v
                        absolute_pos_embed_current = model_dict[k]
                        _, L1, C1 = absolute_pos_embed_pretrained.size()
                        _, L2, C2 = absolute_pos_embed_current.size()
                        # was `C1 != C1`, which is always False
                        if C1 != C2:
                            logging.info(f"Error in loading {k}, passing")
                        else:
                            if L1 != L2:
                                logging.info(
                                    '=> load_pretrained: resized variant: {} to {}'
                                    .format((1, L1, C1), (1, L2, C2))
                                )
                                S1 = int(L1 ** 0.5)
                                S2 = int(L2 ** 0.5)
                                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
                                v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)
                    need_init_state_dict[k] = v
            self.load_state_dict(need_init_state_dict, strict=False)

    def freeze_pretrained_layers(self, frozen_layers=None):
        """Set ``requires_grad=False`` on the named modules/parameters (``['*']`` = all)."""
        if frozen_layers is None:
            frozen_layers = []
        # `== '*'` instead of identity comparison with a string literal
        freeze_all = len(frozen_layers) > 0 and frozen_layers[0] == '*'
        for name, module in self.named_modules():
            if (
                name.split('.')[0] in frozen_layers
                or '.'.join(name.split('.')[0:2]) in frozen_layers
                or freeze_all
            ):
                for _name, param in module.named_parameters():
                    param.requires_grad = False
                logging.info(
                    '=> set param {} requires grad to False'
                    .format(name)
                )
        for name, param in self.named_parameters():
            # Parenthesized so the requires_grad check applies to both selectors.
            if (
                (name.split('.')[0] in frozen_layers or freeze_all)
                and param.requires_grad is True
            ):
                param.requires_grad = False
                logging.info(
                    '=> set param {} requires grad to False'
                    .format(name)
                )
        return self

    def get_num_layers(self):
        """Total number of transformer blocks across all stages."""
        return sum(self.depths)

    def mask_model(self, x, mask):
        """Overwrite masked positions of the (B, C, H, W) patch map with ``masked_embed``.

        A coarser mask is upsampled by integer repetition to the map's H, W.
        """
        # extend mask for hierarchical features
        if x.shape[-2:] != mask.shape[-2:]:
            htimes, wtimes = np.array(x.shape[-2:]) // np.array(mask.shape[-2:])
            mask = mask.repeat_interleave(htimes, -2).repeat_interleave(wtimes, -1)
        # mask embed (writes through the permuted view into x)
        x.permute(0, 2, 3, 1)[mask, :] = self.masked_embed.to(x.dtype)
        return x
@register_model
def swin_tiny(window_size=7, **kwargs):
    """Swin-T: embed_dim 96, depths (2, 2, 6, 2); default drop_path_rate 0.1."""
    drop_path_rate = kwargs.pop('drop_path_rate', 0.1)
    return SwinTransformer(window_size=window_size, embed_dim=96,
                           depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                           mlp_ratio=4, qkv_bias=True,
                           drop_path_rate=drop_path_rate, **kwargs)
@register_model
def swin_small(window_size=7, **kwargs):
    """Swin-S: embed_dim 96, depths (2, 2, 18, 2); default drop_path_rate 0.2."""
    drop_path_rate = kwargs.pop('drop_path_rate', 0.2)
    return SwinTransformer(window_size=window_size, embed_dim=96,
                           depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24],
                           mlp_ratio=4, qkv_bias=True,
                           drop_path_rate=drop_path_rate, **kwargs)
@register_model
def swin_base(window_size=7, **kwargs):
    """Swin-B: embed_dim 128, depths (2, 2, 18, 2); default drop_path_rate 0.2."""
    drop_path_rate = kwargs.pop('drop_path_rate', 0.2)
    return SwinTransformer(window_size=window_size, embed_dim=128,
                           depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],
                           mlp_ratio=4, qkv_bias=True,
                           drop_path_rate=drop_path_rate, **kwargs)
@register_model
def swin_large(window_size=7, **kwargs):
    """Swin-L: embed_dim 192, depths (2, 2, 18, 2); default drop_path_rate 0.2."""
    drop_path_rate = kwargs.pop('drop_path_rate', 0.2)
    return SwinTransformer(window_size=window_size, embed_dim=192,
                           depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48],
                           mlp_ratio=4, qkv_bias=True,
                           drop_path_rate=drop_path_rate, **kwargs)
================================================
FILE: downstream_tasks/detection/models/vision_transformer.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Mostly copy-paste from DINO and timm library:
https://github.com/facebookresearch/dino
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from utils import trunc_normal_
from timm.models.registry import register_model
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: zero out whole samples of ``x`` with probability ``drop_prob``.

    Surviving samples are rescaled by ``1 / (1 - drop_prob)`` so the expectation
    is unchanged. ``x`` is returned untouched when not training or when
    ``drop_prob`` is 0 or None (``DropPath``'s default).

    Args:
        x: tensor whose first dimension indexes samples.
        drop_prob: probability of dropping each sample.
        training: only drop when True.
    """
    # `not drop_prob` also covers None, which previously crashed on `1 - None`.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.:
        # Every sample is dropped; x.div(0) * 0 would give NaN. Multiply by zero
        # instead so the autograd graph stays connected.
        return x * 0.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    output = x.div(keep_prob) * random_tensor
    return output
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (stochastic depth on residual branches).

    The module's ``training`` flag decides whether samples are dropped.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
class Mlp(nn.Module):
    """Feed-forward network: fc1 -> act -> dropout -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features`` when falsy.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with separate q/v biases (BEiT-style; k has no bias).

    ``window_size`` is accepted for interface compatibility but not used. An
    optional per-head additive bias can be passed to ``forward``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # Projection without a fused bias; forward() assembles a (q_bias, 0, v_bias) bias.
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            all_head_dim = head_dim * self.num_heads
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, x_rel_pos_bias = None):
        """Attend over the tokens of ``x`` (B, N, C).

        Args:
            x: input tokens of shape (B, N, C).
            x_rel_pos_bias: optional additive bias broadcast over the batch
                (e.g. shape (num_heads, N, N)), added before softmax.

        Returns:
            (out, attn): ``out`` of shape (B, N, C) and the post-dropout
            attention probabilities of shape (B, num_heads, N, N).
        """
        batch, n_tokens, channels = x.shape
        fused_bias = None
        if self.q_bias is not None:
            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            fused_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=fused_bias)
        q, k, v = qkv.reshape(batch, n_tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
        scores = q @ k.transpose(-2, -1) * self.scale
        if x_rel_pos_bias is not None:
            scores = scores + x_rel_pos_bias.unsqueeze(0)
        probs = self.attn_drop(scores.softmax(dim=-1))
        out = (probs @ v).transpose(1, 2).reshape(batch, n_tokens, channels)
        out = self.proj_drop(self.proj(out))
        return out, probs
class Block(nn.Module):
    """Pre-norm transformer block: x + attn(norm1(x)), then x + mlp(norm2(x)).

    When ``init_values > 0`` each residual branch is scaled by a learnable
    per-channel vector (``gamma_1`` / ``gamma_2``) initialized to ``init_values``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=None, init_values=0):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size)
        self.drop_path = nn.Identity() if not drop_path > 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        use_layer_scale = init_values > 0
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True) if use_layer_scale else None
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True) if use_layer_scale else None

    def forward(self, x, x_rel_pos_bias=None, return_attention=False):
        """Run the block; with ``return_attention`` return only the attention map."""
        attn_out, attn = self.attn(self.norm1(x), x_rel_pos_bias)
        if return_attention:
            return attn
        scaled = self.gamma_1 is not None
        if scaled:
            attn_out = self.gamma_1 * attn_out
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        if scaled:
            mlp_out = self.gamma_2 * mlp_out
        return x + self.drop_path(mlp_out)
#class PatchEmbed(nn.Module):
# """ Image to Patch Embedding
# """
# def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
# super().__init__()
# num_patches = (img_size // patch_size) * (img_size // patch_size)
# self.img_size = img_size
# self.patch_size = patch_size
# self.num_patches = num_patches
# print("#################patch in!!!")
# self.patch_shape = (img_size // patch_size, img_size // patch_size)
#
# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
#
# def forward(self, x):
# B, C, H, W = x.shape
# return self.proj(x)
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    ``img_size`` may be an int, a one-element sequence (square image) or a
    two-element sequence. The first entry sets ``num_patches_w`` and
    ``patch_shape[0]``; the second sets ``num_patches_h`` and ``patch_shape[1]``.
    ``forward`` returns the raw conv output (B, embed_dim, H/patch, W/patch).
    """

    def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        # VisionTransformer's default img_size is [224]; indexing img_size[1] on it
        # raised IndexError, so expand ints and single-element sizes to a square.
        if isinstance(img_size, int):
            img_size = [img_size, img_size]
        elif len(img_size) == 1:
            img_size = [img_size[0], img_size[0]]
        self.num_patches_w = img_size[0] // patch_size
        self.num_patches_h = img_size[1] // patch_size
        num_patches = self.num_patches_w * self.num_patches_h
        self.patch_shape = (img_size[0] // patch_size, img_size[1] // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        print("##############patch here!!!")
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, mask=None):
        # `mask` is accepted but unused; the unpack rejects inputs that are not 4-D.
        B, C, H, W = x.shape
        return self.proj(x)
class VisionTransformer(nn.Module):
""" Vision Transformer """
def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), return_all_tokens=False,
init_values=0, use_sincos_pos_emb=False, use_abs_pos_emb=False, use_rel_pos_bias=False, use_mean_pooling=False, masked_im_modeling=False):
super().__init__()
self.num_features = self.embed_dim = embed_dim
self.return_all_tokens = return_all_tokens
print("############use_abs_pos:", use_abs_pos_emb)
print("############use_sincos_pos:", use_sincos_pos_emb)
print("############use_rel_pos_bias:", use_rel_pos_bias)
self.use_abs_pos_emb = use_abs_pos_emb
self.use_sincos_pos_emb = use_sincos_pos_emb
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
self.use_rel_pos_bias = use_rel_pos_bias
if self.use_rel_pos_bias:
print("=================use RelativePositionBias===================")
window_size=self.patch_embed.patch_shape
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
for i in range(depth)])
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
# Classifier head
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if use_abs_pos_emb:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
# masked image modeling
self.masked_im_modeling = masked_im_modeling
if masked_im_modeling:
self.masked_embed = nn.Parameter(torch.zeros(1, embed_dim))
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def interpolate_pos_encoding(self, x, w, h):
npatch = x.shape[1] - 1
print("############self.pos_embed:", self.pos_embed)
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
class_pos_embed = self.pos_embed[:, 0]
patch_pos_embed = self.pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_embed.patch_size
h0 = h // self.patch_embed.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode='bicubic',
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def prepare_tokens(self, x, mask=None):
    """Convert a batch of images into the token sequence fed to the blocks.

    Pipeline: ``self.patch_embed`` -> optional ``self.mask_model`` (when a
    mask is given) -> flatten spatial dims and move channels last -> prepend
    the [CLS] token -> add positional embeddings (if ``self.pos_embed`` is
    set) -> ``self.pos_drop``.
    NOTE(review): assumes patch_embed returns (B, dim, H', W'), as implied by
    ``flatten(2).transpose(1, 2)`` — confirm.
    """
    batch_size, _, width, height = x.shape
    tokens = self.patch_embed(x)
    if mask is not None:
        tokens = self.mask_model(tokens, mask)
    tokens = tokens.flatten(2).transpose(1, 2)
    cls = self.cls_token.expand(batch_size, -1, -1)
    tokens = torch.cat((cls, tokens), dim=1)
    if self.pos_embed is not None:
        tokens = tokens + self.interpolate_pos_encoding(tokens, width, height)
    return self.pos_drop(tokens)
def forward(self, x, return_all_tokens=None, mask=None):
    """Encode images; return the [CLS] token, or every token on request.

    A ``mask`` is mandatory when masked image modelling is enabled.  If
    ``self.fc_norm`` is set (mean pooling) the [CLS] slot is overwritten in
    place with ``fc_norm`` of the mean patch token.  ``return_all_tokens=None``
    falls back to ``self.return_all_tokens``.
    """
    if self.masked_im_modeling:
        assert mask is not None
        tokens = self.prepare_tokens(x, mask=mask)
    else:
        tokens = self.prepare_tokens(x)
    for block in self.blocks:
        tokens = block(tokens)
    tokens = self.norm(tokens)
    if self.fc_norm is not None:
        tokens[:, 0] = self.fc_norm(tokens[:, 1:, :].mean(1))
    if return_all_tokens is None:
        return_all_tokens = self.return_all_tokens
    return tokens if return_all_tokens else tokens[:, 0]
def get_last_selfattention(self, x):
    """Run every block but the last normally and return the last block's attention.

    The final block is called with ``return_attention=True`` and its result
    is returned unchanged; with no blocks at all the method returns None.
    """
    tokens = self.prepare_tokens(x)
    last_index = len(self.blocks) - 1
    for index, block in enumerate(self.blocks):
        if index == last_index:
            # return attention of the last block
            return block(tokens, return_attention=True)
        tokens = block(tokens)
    return None
def get_intermediate_layers(self, x, n=1):
    """Return ``self.norm`` applied to the outputs of the last ``n`` blocks.

    Outputs are listed from the earliest kept block to the final one.
    """
    tokens = self.prepare_tokens(x)
    first_kept = len(self.blocks) - n
    outputs = []
    for index, block in enumerate(self.blocks):
        tokens = block(tokens)
        if index >= first_kept:
            outputs.append(self.norm(tokens))
    return outputs
def get_num_layers(self):
    """Number of transformer blocks in the encoder."""
    num_blocks = len(self.blocks)
    return num_blocks
def mask_model(self, x, mask):
    """Overwrite masked patch embeddings with ``self.masked_embed``, in place.

    NOTE(review): assumes ``x`` is (B, dim, H', W') and ``mask`` is a boolean
    (B, H', W') tensor, as implied by ``permute(0, 2, 3, 1)[mask, :]`` —
    confirm against the caller.  ``self.masked_embed`` is created in
    ``__init__`` with shape (1, embed_dim) and is broadcast to every masked
    position.
    """
    # The assignment goes through a permuted *view*, so it writes into the
    # storage of ``x`` itself; the caller's tensor is modified and returned.
    x.permute(0, 2, 3, 1)[mask, :] = self.masked_embed.to(x.dtype)
    return x
def vit_tiny(patch_size=16, **kwargs):
    """ViT-Tiny: 192-dim embeddings, 12 blocks, 3 heads, MLP ratio 4, qkv bias.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    return VisionTransformer(patch_size=patch_size,
                             embed_dim=192,
                             depth=12,
                             num_heads=3,
                             mlp_ratio=4,
                             qkv_bias=True,
                             **kwargs)
def vit_small(patch_size=16, **kwargs):
    """ViT-Small: 384-dim embeddings, 12 blocks, 6 heads, MLP ratio 4, qkv bias.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    return VisionTransformer(patch_size=patch_size,
                             embed_dim=384,
                             depth=12,
                             num_heads=6,
                             mlp_ratio=4,
                             qkv_bias=True,
                             **kwargs)
def vit_base(patch_size=16, **kwargs):
    """ViT-Base: 768-dim embeddings, 12 blocks, 12 heads, MLP ratio 4, qkv bias.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    return VisionTransformer(patch_size=patch_size,
                             embed_dim=768,
                             depth=12,
                             num_heads=12,
                             mlp_ratio=4,
                             qkv_bias=True,
                             **kwargs)
def vit_large(patch_size=16, **kwargs):
    """ViT-Large: 1024-dim embeddings, 24 blocks, 16 heads, MLP ratio 4, qkv bias.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    return VisionTransformer(patch_size=patch_size,
                             embed_dim=1024,
                             depth=24,
                             num_heads=16,
                             mlp_ratio=4,
                             qkv_bias=True,
                             **kwargs)
================================================
FILE: downstream_tasks/detection/scripts/run_eval.sh
================================================
#!/usr/bin/env bash
# Evaluate a trained detector (bbox and segm metrics) on 8 GPUs.
# Required environment variables:
#   CONFIG - path to the detection config file
#   MODEL  - path to the checkpoint to evaluate
# Positional arguments from the 6th onward are forwarded to test.py.
# Variables are quoted so paths containing spaces are passed intact.
echo "EVAL MODEL:$MODEL"
python -m torch.distributed.launch --nproc_per_node=8 \
    evaluation/object_detection/test.py \
    "$CONFIG" \
    "$MODEL" \
    --launcher pytorch \
    --eval bbox segm \
    --cfg-options model.backbone.use_checkpoint=True \
    "${@:6}"
================================================
FILE: downstream_tasks/detection/scripts/run_train_maskrcnn_vit_base.sh
================================================
#!/usr/bin/env bash
# Launch Mask R-CNN fine-tuning with a CAE ViT-Base backbone, 8 GPUs per node.
# Required environment variables:
#   NNODES, RANK, ADDRESS, PORT - multi-node launch settings
#   OUTPUT_DIR                  - work directory for logs and checkpoints
#   PRETRAINED                  - path to the pretrained backbone weights
# Positional arguments from the 6th onward are forwarded to train.py.
# Variables are quoted so paths containing spaces are passed intact.
python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes="$NNODES" \
    --node_rank="$RANK" \
    --master_addr="$ADDRESS" \
    --master_port="$PORT" \
    evaluation/object_detection/train.py \
    evaluation/object_detection/configs/mask_rcnn/vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py \
    --launcher pytorch \
    --work-dir "$OUTPUT_DIR" \
    --no-validate \
    --deterministic \
    --cfg-options model.backbone.use_checkpoint=True \
    model.pretrained="$PRETRAINED" \
    "${@:6}"
================================================
FILE: downstream_tasks/detection/scripts/run_train_maskrcnn_vit_large.sh
================================================
#!/usr/bin/env bash
# Launch Mask R-CNN fine-tuning with a CAE ViT-Large backbone, 8 GPUs per node.
# Required environment variables:
#   NNODES, RANK, ADDRESS, PORT - multi-node launch settings
#   OUTPUT_DIR                  - work directory for logs and checkpoints
#   PRETRAINED                  - path to the pretrained backbone weights
# Positional arguments from the 6th onward are forwarded to train.py.
# Variables are quoted so paths containing spaces are passed intact.
python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes="$NNODES" \
    --node_rank="$RANK" \
    --master_addr="$ADDRESS" \
    --master_port="$PORT" \
    evaluation/object_detection/train.py \
    evaluation/object_detection/configs/mask_rcnn/vit_large_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00002_lrdr0.85_dp0.2.py \
    --launcher pytorch \
    --work-dir "$OUTPUT_DIR" \
    --no-validate \
    --deterministic \
    --cfg-options model.backbone.use_checkpoint=True \
    model.pretrained="$PRETRAINED" \
    "${@:6}"
================================================
FILE: downstream_tasks/detection/utils.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Mostly copy-paste from torchvision references or other public repos like DETR:
https://github.com/facebookresearch/detr/blob/master/util/misc.py
"""
import os
import sys
import time
import math
import json
import random
import datetime
import subprocess
import numpy as np
import torch
import torch.distributed as dist
from collections import defaultdict, deque
from pathlib import Path
from torch import nn
from PIL import ImageFilter, ImageOps, Image, ImageDraw
class GaussianBlur(object):
    """
    Randomly blur a PIL image with a Gaussian filter.

    The filter is applied when ``random.random() <= p``; its radius is drawn
    uniformly from [radius_min, radius_max].  Otherwise the image is
    returned unchanged.
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        if random.random() <= self.prob:
            radius = random.uniform(self.radius_min, self.radius_max)
            return img.filter(ImageFilter.GaussianBlur(radius=radius))
        return img
class Solarization(object):
    """
    Solarize a PIL image with probability ``p``; otherwise pass it through.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        apply_it = random.random() < self.p
        return ImageOps.solarize(img) if apply_it else img
class PermutePatch(object):
    """
    Shuffle the ``psz`` x ``psz`` tiles of a PIL image.

    Tiles are cropped row by row, shuffled with ``random.shuffle`` and pasted
    back, in the same row-by-row order of positions, onto a new RGB image of
    the original size.
    """

    def __init__(self, psz):
        self.psz = psz

    def __call__(self, img):
        width, height = img.size
        step = self.psz
        corners = [(left, top)
                   for top in range(0, height, step)
                   for left in range(0, width, step)]
        tiles = [img.crop((left, top, left + step, top + step))
                 for left, top in corners]
        random.shuffle(tiles)
        canvas = Image.new('RGB', (width, height))
        for (left, top), tile in zip(corners, tiles):
            canvas.paste(tile, (left, top))
        return canvas
class HideAndSeek(object):
    """
    Black out a random fraction of the ``psz`` x ``psz`` patches of a PIL image.

    The image is split into ``(width // psz) * (height // psz)`` full patches
    (a partial strip at the right/bottom edge is never selected), and
    ``int(num_patches * ratio)`` of them, chosen without replacement, are
    filled with black.  The input image is drawn on in place and returned.
    """

    def __init__(self, ratio, psz):
        self.ratio = ratio
        self.psz = psz

    @staticmethod
    def _sample_patch_coords(numw, numh, mask_num):
        """Pick ``mask_num`` distinct patches and return their (column, row) indices."""
        mask_patch = np.random.choice(np.arange(numw * numh), mask_num, replace=False)
        # Patches are numbered row-major, so the stride between rows is the
        # number of patches per row (numw).  Decoding with numh would place
        # patches outside the image on non-square inputs.
        return mask_patch % numw, mask_patch // numw

    def __call__(self, img):
        imgwidth, imgheight = img.size
        numw, numh = imgwidth // self.psz, imgheight // self.psz
        mask_num = int(numw * numh * self.ratio)
        mask_w, mask_h = self._sample_patch_coords(numw, numh, mask_num)
        draw = ImageDraw.Draw(img)
        for mw, mh in zip(mask_w, mask_h):
            draw.rectangle((mw * self.psz,
                            mh * self.psz,
                            (mw + 1) * self.psz,
                            (mh + 1) * self.psz), fill="black")
        return img
def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
    """Load backbone weights into ``model`` from a local file or a public checkpoint.

    Args:
        model: module that receives the weights.
        pretrained_weights: path to a local checkpoint, or the special value
            ``'download'`` (DINO weights) or ``'supervised'`` (DeiT weights).
        checkpoint_key: if not None and present in a local checkpoint, the
            entry stored under this key is used as the state dict.
        model_name, patch_size: select which public checkpoint to download.

    Local checkpoints: a leading ``module.`` and then a leading ``backbone.``
    are stripped from every key before ``load_state_dict(strict=False)``.
    When nothing applies, a message is printed and the model keeps its
    random initialisation.
    """
    def _strip_prefix(key, prefix):
        # Remove the prefix only at the start of the key; str.replace would
        # also mangle names such as "submodule.weight" into "subweight".
        return key[len(prefix):] if key.startswith(prefix) else key

    if pretrained_weights and os.path.isfile(pretrained_weights):
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        if checkpoint_key is not None and checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {_strip_prefix(k, "module."): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {_strip_prefix(k, "backbone."): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
        return
    elif pretrained_weights == 'download':
        url = None
        if model_name == "vit_small" and patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        elif model_name == "vit_small" and patch_size == 8:
            url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
        elif model_name == "vit_base" and patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        elif model_name == "vit_base" and patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        if url is not None:
            print("Since no pretrained weights are provided, we load the pretrained weights from {}.".format(url))
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
            model.load_state_dict(state_dict, strict=True)
            return
    elif pretrained_weights == 'supervised':
        url = None
        if model_name == "vit_small" and patch_size == 16:
            url = "deit_small_patch16_224-cd65a155.pth"
        elif model_name == "vit_base" and patch_size == 16:
            url = "deit_base_patch16_224-b5f2ef4d.pth"
        if url is not None:
            print("Since no pretrained weights are provided, we load the pretrained weights from {}.".format(url))
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/deit/" + url)
            msg = model.load_state_dict(state_dict['model'], strict=False)
            print('Supervised weights found at {} and loaded with msg: {}'.format(url, msg))
            return
    print("There is no reference weights available for this model => We use random weights.")
def clip_gradients(model, clip):
    """Rescale each parameter's gradient so its L2 norm is at most ``clip``.

    Clipping is per parameter (not global) and done in place on ``p.grad``.
    Returns the pre-clipping gradient norms, as Python floats, of every
    parameter that has a gradient.
    """
    norms = []
    for _, param in model.named_parameters():
        if param.grad is None:
            continue
        grad_norm = param.grad.data.norm(2)
        norms.append(grad_norm.item())
        scale = clip / (grad_norm + 1e-6)
        if scale < 1:
            param.grad.data.mul_(scale)
    return norms
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    """Drop the gradients of every parameter whose name contains "last_layer".

    Only active while ``epoch < freeze_last_layer``; afterwards it is a no-op.
    """
    if epoch < freeze_last_layer:
        for name, param in model.named_parameters():
            if "last_layer" in name:
                param.grad = None
def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
    """
    Re-start from checkpoint.

    Each keyword argument maps a checkpoint key to an object with
    ``load_state_dict`` (e.g. ``state_dict=model``).  Loading is first tried
    with ``strict=False``; if that raises ``TypeError`` it is retried
    without the keyword.  Entries of ``run_variables`` that also exist in the
    checkpoint are overwritten in place.  Nothing happens when ``ckp_path``
    is not an existing file.
    """
    if not os.path.isfile(ckp_path):
        return
    print("Found checkpoint at {}".format(ckp_path))
    checkpoint = torch.load(ckp_path, map_location="cpu")

    for key, target in kwargs.items():
        if key not in checkpoint or target is None:
            print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
            continue
        try:
            msg = target.load_state_dict(checkpoint[key], strict=False)
            print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
        except TypeError:
            try:
                target.load_state_dict(checkpoint[key])
                print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
            except ValueError:
                print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))

    # re-load variables important for the run
    if run_variables is None:
        return
    for var_name in run_variables:
        if var_name in checkpoint:
            run_variables[var_name] = checkpoint[var_name]
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Per-iteration schedule: optional linear warm-up, then cosine decay.

    Returns a NumPy array of length ``epochs * niter_per_ep``.  The first
    ``warmup_epochs * niter_per_ep`` values go linearly from
    ``start_warmup_value`` to ``base_value``; the remaining ones follow half
    a cosine period from ``base_value`` towards ``final_value``.
    """
    warmup_iters = warmup_epochs * niter_per_ep
    warmup_schedule = (np.linspace(start_warmup_value, base_value, warmup_iters)
                       if warmup_epochs > 0 else np.array([]))
    decay_steps = np.arange(epochs * niter_per_ep - warmup_iters)
    decay_schedule = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * decay_steps / len(decay_steps)))
    schedule = np.concatenate((warmup_schedule, decay_schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
def bool_flag(s):
    """
    Parse boolean arguments from the command line.

    Case-insensitive: "on"/"true"/"1" give True, "off"/"false"/"0" give False.
    Anything else raises ``argparse.ArgumentTypeError`` so that argparse
    reports a clean usage error.
    """
    # argparse is not imported at module level; without this the error
    # branch raised NameError instead of ArgumentTypeError.
    import argparse

    FALSY_STRINGS = {"off", "false", "0"}
    TRUTHY_STRINGS = {"on", "true", "1"}
    value = s.lower()
    if value in FALSY_STRINGS:
        return False
    if value in TRUTHY_STRINGS:
        return True
    raise argparse.ArgumentTypeError("invalid value for a boolean flag")
def fix_random_seeds(seed=31):
    """
    Fix random seeds: Python ``random``, NumPy, PyTorch (CPU and every CUDA
    device), and the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    ``median``, ``avg``, ``max`` and ``value`` look only at the last
    ``window_size`` updates; ``global_avg`` is ``total / count`` over all
    updates, each weighted by its ``n``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.6f} ({global_avg:.6f})" if fmt is None else fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if is_dist_avail_and_initialized():
            stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(stats)
            count, total = stats.tolist()
            self.count = int(count)
            self.total = total

    @property
    def median(self):
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        stats = dict(median=self.median, avg=self.avg, global_avg=self.global_avg,
                     max=self.max, value=self.value)
        return self.fmt.format(**stats)
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.  With fewer than two processes the input is
    returned unchanged.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # Sorted keys give every rank the same stacking order.
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[name] for name in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
class MetricLogger(object):
    """Group of named ``SmoothedValue`` meters with periodic progress logging.

    Meters are created lazily by ``update`` and can be read back as
    attributes (``logger.loss``).
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one observation per keyword; tensors are converted with ``.item()``."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only called when normal lookup fails.  Read ``meters`` through
        # __dict__ so an instance without it (e.g. a bare object created by
        # copy/pickle before its state is restored) raises AttributeError
        # instead of recursing into __getattr__('meters') forever.
        meters = self.__dict__.get("meters", {})
        if attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` and print progress every ``print_freq`` steps.

        ``iterable`` must support ``len``.  Each log line shows the step, ETA,
        all meters, iteration and data-loading times and, when CUDA is
        available, peak allocated memory in MB.  A total-time summary is
        printed once the iterable is exhausted.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.6f}')
        data_time = SmoothedValue(fmt='{avg:.6f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) keeps an empty iterable from raising ZeroDivisionError.
        print('{} Total time: {} ({:.6f} s / it)'.format(
            header, total_time_str, total_time / max(len(iterable), 1)))
def get_sha():
    """Describe the git checkout that contains this file.

    Returns "sha: <commit>, status: <state>, branch: <name>".  If any git
    call fails (or git is missing), the fields not yet filled keep their
    defaults: 'N/A' for sha/branch and "clean" for status.
    """
    repo_dir = os.path.dirname(os.path.abspath(__file__))

    def git(*args):
        output = subprocess.check_output(['git', *args], cwd=repo_dir)
        return output.decode('ascii').strip()

    sha, diff, branch = 'N/A', "clean", 'N/A'
    try:
        sha = git('rev-parse', 'HEAD')
        # Output discarded; presumably run so the index is refreshed before
        # diff-index below (TODO confirm).
        subprocess.check_output(['git', 'diff'], cwd=repo_dir)
        diff = "has uncommited changes" if git('diff-index', 'HEAD') else "clean"
        branch = git('rev-parse', '--abbrev-ref', 'HEAD')
    except Exception:
        pass
    return f"sha: {sha}, status: {diff}, branch: {branch}"
def is_dist_avail_and_initialized():
    """True only if torch.distributed is available and a process group is initialised."""
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """World size of the default process group, or 1 outside distributed mode."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Rank of this process in the default group, or 0 outside distributed mode."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Whether this process has rank 0 (always true outside distributed mode)."""
    rank = get_rank()
    return rank == 0
def save_on_master(*args, **kwargs):
    """``torch.save`` wrapper that only writes from the main (rank 0) process."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
def setup_for_distributed(is_master):
    """
    Silence ``print`` on non-master processes by replacing ``builtins.print``.

    After the call, ``print(..., force=True)`` still prints on every process.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_master:
            original_print(*args, **kwargs)

    builtins.print = print
def init_distributed_mode(args):
    """Initialise torch.distributed with NCCL and bind this process to one GPU.

    The launch mode is detected from the environment, in this order:
      * RANK and WORLD_SIZE set (torch.distributed.launch): rank, world size
        and GPU index come from RANK, WORLD_SIZE and LOCAL_RANK;
      * SLURM_PROCID set: rank from SLURM_PROCID, GPU index = rank modulo the
        number of visible CUDA devices;
      * otherwise, if CUDA is available: a single process (rank 0, GPU 0,
        world size 1) with MASTER_ADDR/MASTER_PORT set to 127.0.0.1:29500;
      * otherwise the program exits with status 1.

    Reads ``args.dist_url`` as the init method and writes ``args.rank``,
    ``args.gpu`` and, except in the SLURM branch, ``args.world_size``.  Ends
    with a barrier and silences ``print`` on every rank other than 0.
    """
    # launched with torch.distributed.launch
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    # launched with submitit on a slurm cluster
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
        # NOTE(review): args.world_size is not set in this branch; presumably
        # the caller provides it - confirm before relying on this path.
    # launched naively with `python main_dino.py`
    # we manually add MASTER_ADDR and MASTER_PORT to env variables
    elif torch.cuda.is_available():
        print('Will run the code on one GPU.')
        args.rank, args.gpu, args.world_size = 0, 0, 1
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
    else:
        print('Does not support training without GPU.')
        sys.exit(1)
    dist.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.cuda.set_device(args.gpu)
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    dist.barrier()
    # Only rank 0 keeps printing (others need print(..., force=True)).
    setup_for_distributed(args.rank == 0)
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Returns a list with one 0-dim tensor per k: the percentage of samples
    whose target is among the k highest-scoring classes of ``output``.
    """
    batch_size = target.size(0)
    _, top_pred = output.topk(max(topk), 1, True, True)
    top_pred = top_pred.t()
    hits = top_pred.eq(target.reshape(1, -1).expand_as(top_pred))
    results = []
    for k in topk:
        results.append(hits[:k].reshape(-1).float().sum(0) * 100. / batch_size)
    return results
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    """Fill ``tensor`` in place from N(mean, std) truncated to [a, b] and return it."""
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
class LARS(torch.optim.Optimizer):
    """
    Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py

    SGD with momentum plus layer-wise adaptive rate scaling (LARS).  Weight
    decay and the LARS trust ratio are applied only to parameters with more
    than one dimension (biases and norm parameters are skipped).  The
    optional ``weight_decay_filter`` / ``lars_adaptation_filter`` callables
    receive a parameter and return True to additionally exclude it from
    weight decay / LARS adaptation, as in the Barlow Twins implementation.
    """

    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=None, lars_adaptation_filter=None):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    @staticmethod
    def _excluded(filter_fn, p):
        """True when an optional filter is set and asks to skip parameter ``p``."""
        return filter_fn is not None and filter_fn(p)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (standard ``torch.optim`` protocol).

        Returns:
            The loss returned by ``closure``, or None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad
                if dp is None:
                    continue
                if p.ndim != 1 and not self._excluded(g['weight_decay_filter'], p):
                    dp = dp.add(p, alpha=g['weight_decay'])
                if p.ndim != 1 and not self._excluded(g['lars_adaptation_filter'], p):
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Trust ratio eta * ||p|| / ||dp||, falling back to 1 when
                    # either norm is zero.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm), one), one)
                    dp = dp.mul(q)
                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
        return loss
def create_ds_config(args):
    """Write a DeepSpeed JSON config to ``<args.output_dir>/deepspeed_config.json``.

    The path is stored in ``args.deepspeed_config``.  The config derives the
    global batch size from ``args.batch_size`` times the world size, uses
    ``args.batch_size`` as the per-GPU micro batch, configures an Adam
    optimizer in AdamW mode with ``args.lr`` / ``args.weight_decay``, and
    enables fp16.
    """
    args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
    with open(args.deepspeed_config, mode="w") as handle:
        config = {
            "train_batch_size": args.batch_size * get_world_size(),
            "train_micro_batch_size_per_gpu": args.batch_size,
            "steps_per_print": 1000,
            "optimizer": {
                "type": "Adam",
                "adam_w_mode": True,
                "params": {
                    "lr": args.lr,
                    "weight_decay": args.weight_decay,
                    "bias_correction": True,
                    "betas": [0.9, 0.999],
                    "eps": 1e-8,
                },
            },
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "initial_scale_power": 7,
                "loss_scale_window": 128,
            },
        }
        handle.write(json.dumps(config, indent=2))
class MultiCropWrapper(nn.Module):
    """
    Perform forward pass separately on each resolution input.
    The inputs corresponding to a single resolution are clubbed and single
    forward is run on the same resolution inputs. Hence we do several
    forward passes = number of different resolutions used. We then
    concatenate all the output features and run the head forward on these
    concatenated features.

    "Resolution" is the size of each crop's last dimension; only
    *consecutive* crops of equal size are batched together.  Optional masks
    are concatenated the same way and passed to the backbone as ``mask=``.
    """

    def __init__(self, backbone, head=None):
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        self.head = nn.Identity() if head is None else head

    def forward(self, x, mask=None, return_backbone_feat=False,
                **kwargs):
        if not isinstance(x, list):
            x = [x]
            mask = [mask] if mask is not None else None
        crop_sizes = torch.tensor([crop.shape[-1] for crop in x])
        _, run_lengths = torch.unique_consecutive(crop_sizes, return_counts=True)
        group_ends = torch.cumsum(run_lengths, 0)
        start = 0
        for group_index, end in enumerate(group_ends):
            batch = torch.cat(x[start: end])
            if mask is not None:
                kwargs.update(dict(mask=torch.cat(mask[start: end])))
            features = self.backbone(batch, **kwargs)
            output = features if group_index == 0 else torch.cat((output, features))
            start = end
        # Run the head forward on the concatenated features.
        projected = self.head(output)
        if return_backbone_feat:
            return output, projected
        return projected
def get_params_groups(model):
    """Split trainable parameters into a weight-decayed and a non-decayed group.

    Biases (names ending in ".bias") and every 1-D parameter go into the
    second group, which carries ``weight_decay=0.``; parameters with
    ``requires_grad=False`` are left out entirely.
    """
    regularized, not_regularized = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        skip_decay = name.endswith(".bias") or len(param.shape) == 1
        (not_regularized if skip_decay else regularized).append(param)
    return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
def has_batchnorms(model):
    """Whether ``model`` contains a BatchNorm1d/2d/3d or SyncBatchNorm module."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(module, bn_types) for module in model.modules())
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.

    Returns the gathered tensors concatenated along dim 0; requires an
    initialised process group.
    """
    world_size = torch.distributed.get_world_size()
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)
class PCA():
    """
    Class to compute and apply PCA.

    ``train_pca`` keeps the ``dim`` leading eigenvectors of a covariance
    matrix and scales each component by ``eigenvalue ** -whit`` (``whit=0.5``
    whitens fully).  ``apply`` projects row vectors, subtracting ``self.mean``
    first when it has been set.
    """

    def __init__(self, dim=256, whit=0.5):
        self.dim = dim
        self.whit = whit
        self.mean = None

    def train_pca(self, cov):
        """
        Takes a covariance matrix (np.ndarray) as input.
        """
        d, v = np.linalg.eigh(cov)
        # Clamp tiny eigenvalues so the whitening step cannot blow up.
        eps = d.max() * 1e-5
        n_0 = (d < eps).sum()
        if n_0 > 0:
            d[d < eps] = eps
        # total energy
        totenergy = d.sum()
        # sort eigenvectors with eigenvalues order
        idx = np.argsort(d)[::-1][:self.dim]
        d = d[idx]
        v = v[:, idx]
        print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0))
        # for the whitening
        d = np.diag(1. / d**self.whit)
        # principal components
        self.dvt = np.dot(d, v.T)

    def apply(self, x):
        """Project the rows of ``x`` (NumPy array or torch tensor).

        The input is never modified (the mean is subtracted out of place).
        Torch inputs are projected with float32 weights created on
        ``x.device``, so CPU and GPU tensors share one code path.
        """
        # input is from numpy
        if isinstance(x, np.ndarray):
            if self.mean is not None:
                x = x - self.mean
            return np.dot(self.dvt, x.T).T
        # input is from torch (CPU or GPU)
        dvt = torch.as_tensor(self.dvt, dtype=torch.float32, device=x.device)
        if self.mean is not None:
            x = x - torch.as_tensor(self.mean, dtype=torch.float32, device=x.device)
        return torch.mm(dvt, x.transpose(0, 1)).transpose(0, 1)
def compute_ap(ranks, nres):
    """
    Computes average precision for given ranked indexes.
    Arguments
    ---------
    ranks : zero-based ranks of positive images
    nres  : number of positive images
    Returns
    -------
    ap : average precision, accumulated as trapezoids of the PR curve
    """
    recall_step = 1. / nres
    ap = 0
    for j, rank in enumerate(ranks):
        # precision just before and just after reaching this positive
        precision_0 = 1. if rank == 0 else float(j) / rank
        precision_1 = float(j + 1) / (rank + 1)
        ap += (precision_0 + precision_1) * recall_step / 2.
    return ap
def compute_map(ranks, gnd, kappas=None):
    """
    Computes the mAP for a given set of returned results.
    Usage:
           map, aps, pr, prs = compute_map(ranks, gnd, kappas)
           map: mean average precision over queries that have positives
           aps: average precision for each query
           pr:  mean precision at each kappa
           prs: precision at each kappa for each query
    Notes:
    1) ranks starts from 0, ranks.shape = db_size X #queries
    2) The junk results (e.g., the query itself) should be declared in the gnd struct array
    3) If there are no positive images for some query, that query is excluded from the evaluation
       (its aps / prs entries are NaN).  If no query has positives, map and pr are NaN.
    4) kappas defaults to no cut-offs (empty pr / prs columns).
    """
    kappas = [] if kappas is None else kappas
    mean_ap = 0.
    nq = len(gnd)  # number of queries
    aps = np.zeros(nq)
    pr = np.zeros(len(kappas))
    prs = np.zeros((nq, len(kappas)))
    nempty = 0
    for i in np.arange(nq):
        qgnd = np.array(gnd[i]['ok'])
        # no positive images, skip from the average
        if qgnd.shape[0] == 0:
            aps[i] = float('nan')
            prs[i, :] = float('nan')
            nempty += 1
            continue
        try:
            qgndj = np.array(gnd[i]['junk'])
        except KeyError:
            qgndj = np.empty(0)
        # sorted positions of positive and junk images (0 based)
        pos = np.arange(ranks.shape[0])[np.isin(ranks[:, i], qgnd)]
        junk = np.arange(ranks.shape[0])[np.isin(ranks[:, i], qgndj)]
        k = 0
        ij = 0
        if len(junk):
            # decrease positions of positives based on the number of
            # junk images appearing before them
            ip = 0
            while ip < len(pos):
                while ij < len(junk) and pos[ip] > junk[ij]:
                    k += 1
                    ij += 1
                pos[ip] = pos[ip] - k
                ip += 1
        # compute ap
        ap = compute_ap(pos, len(qgnd))
        mean_ap = mean_ap + ap
        aps[i] = ap
        # compute precision @ k
        pos += 1  # get it to 1-based
        for j in np.arange(len(kappas)):
            kq = min(max(pos), kappas[j])
            prs[i, j] = (pos <= kq).sum() / kq
        pr = pr + prs[i, :]
    n_valid = nq - nempty
    if n_valid > 0:
        mean_ap = mean_ap / n_valid
        pr = pr / n_valid
    else:
        # No query had positives: report NaN instead of dividing by zero.
        mean_ap = float('nan')
        pr = np.full(len(kappas), np.nan)
    return mean_ap, aps, pr, prs
================================================
FILE: downstream_tasks/semantic_segmentation/README.md
================================================
# ADE20k Semantic segmentation with CAE
## Getting started
1. Install the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) library and some required packages.
```bash
pip install mmcv-full==1.3.0 mmsegmentation==0.11.0
pip install scipy timm==0.3.2
```
2. Install [apex](https://github.com/NVIDIA/apex) for mixed-precision training
```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
3. Follow the guide in [mmseg](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/dataset_prepare.md) to prepare the ADE20k dataset.
## Fine-tuning
Command format:
```
tools/dist_train.sh <CONFIG_PATH> <NUM_GPUS> --work-dir <SAVE_PATH> --seed 0 --deterministic --options model.pretrained=<PRETRAINED_MODEL_PATH>
```
For example, using a CAE-base backbone with UperNet:
```bash
bash tools/dist_train.sh \
configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_4e-4.py 8 \
--work-dir /path/to/save --seed 0 --deterministic \
    --options model.pretrained=/path/to/pretrained/model.pth
```
More config files can be found at [`configs_local/cae/upernet`](configs_local/cae/upernet).
## Evaluation
Command format:
```
tools/dist_test.sh <CONFIG_PATH> <CHECKPOINT_PATH> <NUM_GPUS> --eval mIoU
```
For example, evaluate a CAE-base backbone with UperNet:
```bash
bash tools/dist_test.sh configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_4e-4.py \
    /path/to/checkpoint.pth 8 --eval mIoU
```
Note that evaluation is also run automatically during training.
## Results (pretrained models are trained on ImageNet-1K without labels)
| Backbone | #Pretrained Epoch | mIoU | Config |
| -------- | ----------------- | ---- | ---------------------------------------- |
| ViT-B | 300 | 48.1 | [3e-4](./configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_3e-4.py) |
| ViT-B | 800 | 49.7 | [2e-4](./configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_2e-4.py) |
| ViT-B | 1600 | 50.3 | [1e-4](./configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_1e-4.py) |
| ViT-L | 1600 | 54.9 | [4e-5](./configs_local/cae/upernet/upernet_cae_large_24_512_slide_160k_ade20k_pt_decay095_4e-5_dp015.py) |
We find that a better pretrained model benefits from a smaller learning rate. However, the choice of learning rate does not change the results significantly: for example, the 800-epoch pretrained ViT-B still obtains 49.6 mIoU (averaged over two runs) with lr=4e-4.
## Acknowledgment
This code is built using the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) library, [Timm](https://github.com/rwightman/pytorch-image-models) library, the [Swin](https://github.com/microsoft/Swin-Transformer) repository, [XCiT](https://github.com/facebookresearch/xcit) and the [SETR](https://github.com/fudan-zvg/SETR) repository.
================================================
FILE: downstream_tasks/semantic_segmentation/backbone/beit.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------'
import math
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
import numpy as np
from mmcv_custom import load_checkpoint
from mmseg.utils import get_root_logger
from mmseg.models.builder import BACKBONES
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around timm's ``drop_path``; it forwards ``self.training``
    so the helper knows whether the module is in training mode.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f'p={self.drop_prob}'
class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features``.
    Dropout is applied only after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        # No dropout after the activation, following the original BERT
        # implementation (the upstream code leaves that call commented out).
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with an optional per-layer relative position bias.

    Args:
        dim (int): input/output channel width.
        num_heads (int): number of attention heads.
        qkv_bias (bool): if True, learn separate q and v biases. The k bias is
            always zero (zeros are concatenated in ``forward``).
        qk_scale (float | None): overrides the default ``head_dim ** -0.5``
            scaling applied to q.
        attn_drop (float): dropout on the attention weights.
        proj_drop (float): dropout after the output projection.
        window_size (tuple[int, int] | None): patch grid ``(Wh, Ww)``. When
            given, a learnable relative-position bias table is created for a
            sequence of ``Wh * Ww`` patch tokens plus one leading cls token.
        attn_head_dim (int | None): overrides the per-head width
            (default ``dim // num_heads``).
    """
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        # Bias-free projection; the optional q/v biases are added in forward via F.linear.
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            # (2*Wh-1)*(2*Ww-1) patch-to-patch offsets, plus 3 extra slots for
            # cls->token, token->cls and cls->cls.
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            # Row-major encoding of the 2D offset into a single table index.
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            # Row 0 / column 0 belong to the cls token; the three trailing table
            # slots are used for them. The order of these assignments matters:
            # the [0, 0] corner is written last.
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)

            # trunc_normal_(self.relative_position_bias_table, std=.0)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        """Attend over the token sequence ``x`` of shape ``(B, N, C)``.

        ``rel_pos_bias``, if given, is added as-is to the ``(B, nH, N, N)``
        attention logits. Presumably it is ``(nH, N, N)`` from a shared
        RelativePositionBias; confirm with the caller. When this layer owns a
        bias table, N must equal ``Wh * Ww + 1`` for the bias to broadcast.
        """
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # k gets a constant zero bias; only q and v biases are learned.
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, nH, N, head_dim
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))`` then ``x + mlp(norm2(x))``.

    If ``init_values`` is not None, each branch output is multiplied by a
    learnable per-channel scale (``gamma_1`` / ``gamma_2``) initialised to
    ``init_values`` (LayerScale). Both residual branches go through
    ``DropPath`` when ``drop_path > 0``; otherwise that step is the identity.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        # Submodules are created in a fixed order; this keeps RNG consumption
        # and state-dict layout stable.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        self.drop_path = nn.Identity() if not drop_path > 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)

        if init_values is None:
            self.gamma_1 = None
            self.gamma_2 = None
        else:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x, rel_pos_bias=None):
        attn_out = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
        if self.gamma_1 is not None:
            attn_out = self.gamma_1 * attn_out
        x = x + self.drop_path(attn_out)

        mlp_out = self.mlp(self.norm2(x))
        if self.gamma_2 is not None:
            mlp_out = self.gamma_2 * mlp_out
        return x + self.drop_path(mlp_out)
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A single ``Conv2d`` whose kernel and stride both equal ``patch_size``
    projects every patch to ``embed_dim`` channels. The input size is not
    checked against ``img_size``. ``img_size`` only sets the nominal
    ``patch_shape`` and ``num_patches`` attributes.

    ``forward`` returns ``(tokens, (Hp, Wp))``. ``tokens`` has shape
    ``(B, Hp * Wp, embed_dim)`` and ``(Hp, Wp)`` is the patch grid actually
    produced for this input.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]
        self.patch_shape = (grid_h, grid_w)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_w * grid_h
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        # The unpacking requires a 4D (B, C, H, W) input.
        _batch, _chans, _height, _width = x.shape
        feat = self.proj(x)
        grid = (feat.shape[2], feat.shape[3])
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, grid
class HybridEmbed(nn.Module):
    """CNN feature-map embedding.

    Runs ``backbone``, takes its last output feature map, flattens it into a
    token sequence and projects the tokens to ``embed_dim``.

    To be a drop-in replacement for ``PatchEmbed``, ``forward`` returns
    ``(tokens, (Hp, Wp))`` and the module exposes ``patch_shape``. The BEiT
    backbones unpack that tuple and read ``patch_shape`` (for position/
    relative-position embeddings). Previously only the tensor was returned,
    which made the hybrid path crash.

    Args:
        backbone (nn.Module): CNN whose output is indexable, with ``[-1]``
            being a ``(B, C, H, W)`` feature map.
        img_size (int | tuple): nominal input size used to probe the backbone.
        feature_size (int | tuple | None): backbone output grid. If None it is
            measured by a dummy forward pass.
        in_chans (int): input image channels.
        embed_dim (int): output token width.
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        # Same meaning as PatchEmbed.patch_shape: the nominal token grid (H, W).
        self.patch_shape = (int(feature_size[0]), int(feature_size[1]))
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        x = self.backbone(x)[-1]
        Hp, Wp = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x, (Hp, Wp)
class RelativePositionBias(nn.Module):
    """Learnable relative position bias for a ``Wh x Ww`` patch grid plus a cls token.

    The bias table holds one row per relative 2D offset between patches,
    ``(2*Wh-1) * (2*Ww-1)`` rows, plus three extra rows for cls->token,
    token->cls and cls->cls. Each row has one value per head. BEiT builds one
    of these when ``use_shared_rel_pos_bias=True``.

    ``forward()`` takes no input and returns a ``(num_heads, Wh*Ww+1, Wh*Ww+1)``
    tensor to be added to the attention logits.
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        win_h, win_w = window_size[0], window_size[1]
        self.num_relative_distance = (2 * win_h - 1) * (2 * win_w - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))

        # Pairwise (dh, dw) offsets between all patches of the window.
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat = grid.flatten(1)  # 2, Wh*Ww
        delta = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        # Shift each offset to be non-negative, then encode (dh, dw) row-major.
        delta[:, :, 0] += win_h - 1
        delta[:, :, 1] += win_w - 1
        delta[:, :, 0] *= 2 * win_w - 1

        n_tokens = win_h * win_w + 1
        index = torch.zeros(size=(n_tokens, n_tokens), dtype=delta.dtype)
        index[1:, 1:] = delta.sum(-1)
        # cls row, cls column, then the cls/cls corner (written last on purpose).
        index[0, 0:] = self.num_relative_distance - 3
        index[0:, 0] = self.num_relative_distance - 2
        index[0, 0] = self.num_relative_distance - 1
        self.register_buffer("relative_position_index", index)

    def forward(self):
        n_tokens = self.window_size[0] * self.window_size[1] + 1
        flat_index = self.relative_position_index.view(-1)
        bias = self.relative_position_bias_table[flat_index].view(n_tokens, n_tokens, -1)
        return bias.permute(2, 0, 1).contiguous()  # nH, N, N
def get_sinusoid_encoding_table(n_position, d_hid, token=False):
    """Build a 1D sinusoid position-encoding table.

    For position ``p`` and channel ``j``, the angle is
    ``p / 10000 ** (2 * (j // 2) / d_hid)``. Even channels hold its sine and
    odd channels its cosine.

    Args:
        n_position (int): number of positions (rows).
        d_hid (int): encoding width.
        token (bool): if True, append one all-zero row after the positional rows.

    Returns:
        torch.FloatTensor of shape ``(1, n_position + int(token), d_hid)``.
    """
    positions = np.arange(n_position, dtype=np.float64)[:, None]
    # Each sin/cos channel pair (2i, 2i+1) shares the same frequency.
    exponents = 2 * (np.arange(d_hid) // 2) / d_hid
    sinusoid_table = positions / np.power(10000, exponents)

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if token:
        # numpy's keyword is ``axis``; ``dim=0`` raised a TypeError here.
        sinusoid_table = np.concatenate([sinusoid_table, np.zeros([1, d_hid])], axis=0)

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
@BACKBONES.register_module()
class BEiT(nn.Module):
    """Vision Transformer backbone (BEiT/CAE style) that outputs a 4-level feature pyramid.

    The transformer runs on a cls token followed by patch tokens (or hybrid
    CNN tokens). After each block whose index is in ``out_indices``, the patch
    tokens go through ``self.norm`` (identity unless ``out_with_norm``) and are
    reshaped to a ``(B, C, Hp, Wp)`` map. The i-th selected map is then
    resampled by ``fpn{i+1}``:

    * ``patch_size=16``: 4x up, 2x up, identity, 2x down
    * ``patch_size=8``: 2x up, identity, 2x down, 4x down

    Only these two patch sizes are supported, and at most four indices can be
    selected (one per FPN op).

    Absolute position embedding:

    * ``use_abs_pos_emb=True``: learnable, truncated-normal init.
    * ``use_abs_pos_emb=False, use_sincos_pos_embed=True``: fixed,
      non-trainable 2D sin-cos table (kept as built, not re-randomised).
    * both False: none.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_sincos_pos_embed=True,
                 use_shared_rel_pos_bias=False, out_indices=(3, 5, 7, 11), out_with_norm=False):
        super().__init__()
        if patch_size not in (8, 16):
            # forward_features always applies fpn1..fpn4, which only exist for these sizes.
            raise ValueError(f'BEiT supports patch_size 8 or 16 only, got {patch_size!r}')
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.out_indices = out_indices

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.use_abs_pos_emb = use_abs_pos_emb
        self.use_sincos_pos_embed = use_sincos_pos_embed
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        elif use_sincos_pos_embed:
            self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
            for i in range(depth)])

        # Only the learnable embedding gets a random init. Re-initialising the
        # fixed sin-cos table here would replace it with frozen noise.
        if use_abs_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)

        if patch_size == 16:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                nn.SyncBatchNorm(embed_dim),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn3 = nn.Identity()
            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
        else:  # patch_size == 8
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Identity()
            self.fpn3 = nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
            )
            self.fpn4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=4, stride=4),
            )

        self.norm = norm_layer(embed_dim) if out_with_norm else nn.Identity()

        self.apply(self._init_weights)
        self.fix_init_weight()

    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., decode=False):
        """Return a frozen 2D sin-cos position embedding for ``patch_embed.patch_shape``.

        The result is an ``nn.Parameter`` with ``requires_grad=False`` of shape
        ``(1, 1 + h * w, embed_dim)``. Its first row (cls token) is zeros. Each
        patch row concatenates sin/cos of both grid coordinates, with
        ``embed_dim // 4`` frequencies each. ``decode`` is unused.
        """
        h, w = self.patch_embed.patch_shape
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        pos_embed.requires_grad = False
        return pos_embed

    def fix_init_weight(self):
        """Scale each block's residual output projections by ``1 / sqrt(2 * layer_id)`` (1-based)."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks, start=1):
            rescale(layer.attn.proj.weight.data, layer_id)
            rescale(layer.mlp.fc2.weight.data, layer_id)

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, unit-weight LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None. When given, the checkpoint is loaded
                non-strictly after the default init.

        Raises:
            TypeError: if ``pretrained`` is neither a str nor None.
        """
        if pretrained is not None and not isinstance(pretrained, str):
            raise TypeError('pretrained must be a str or None')
        self.apply(self._init_weights)
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        """Return a tuple of multi-scale maps, one per selected block (max four)."""
        B, _, _, _ = x.shape
        x, (Hp, Wp) = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        features = []
        for i, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
            else:
                x = blk(x, rel_pos_bias)
            if i in self.out_indices:
                # Drop the cls token and fold patch tokens back into a (B, C, Hp, Wp) map.
                xp = self.norm(x[:, 1:, :]).permute(0, 2, 1).reshape(B, -1, Hp, Wp)
                features.append(xp.contiguous())

        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        # Indexing (not zip) so that more than four outputs fails loudly.
        return tuple(ops[i](feat) for i, feat in enumerate(features))

    def forward(self, x):
        return self.forward_features(x)
================================================
FILE: downstream_tasks/semantic_segmentation/backbone/beit_fapn.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
import math
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from mmcv_custom import load_checkpoint
from mmseg.utils import get_root_logger
from mmseg.models.builder import BACKBONES
from mmcv.ops import DeformConv2d
from mmcv.cnn import xavier_init
class DropPath(nn.Module):
    """Stochastic depth: drop the whole residual path per sample.

    Thin ``nn.Module`` wrapper around timm's ``drop_path``. It forwards
    ``self.training`` so that dropping can follow the module's train/eval
    mode, and it exposes the probability in ``repr``.

    Args:
        drop_prob (float | None): probability of dropping the path.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f'p={self.drop_prob}'
class Mlp(nn.Module):
    """Transformer feed-forward network: Linear -> activation -> Linear -> Dropout.

    Following the original BEiT/BERT implementation, dropout is applied only
    once, after the second projection. There is no dropout between the two
    linear layers.

    Args:
        in_features (int): input width.
        hidden_features (int | None): hidden width; falsy means ``in_features``.
        out_features (int | None): output width; falsy means ``in_features``.
        act_layer (type): activation module class, instantiated with no args.
        drop (float): dropout probability applied to the output.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with an optional per-layer relative position bias.

    Args:
        dim (int): input/output channel width.
        num_heads (int): number of attention heads.
        qkv_bias (bool): if True, learn separate q and v biases. The k bias is
            always zero (zeros are concatenated in ``forward``).
        qk_scale (float | None): overrides the default ``head_dim ** -0.5``
            scaling applied to q.
        attn_drop (float): dropout on the attention weights.
        proj_drop (float): dropout after the output projection.
        window_size (tuple[int, int] | None): patch grid ``(Wh, Ww)``. When
            given, a learnable relative-position bias table is created for a
            sequence of ``Wh * Ww`` patch tokens plus one leading cls token.
        attn_head_dim (int | None): overrides the per-head width
            (default ``dim // num_heads``).
    """
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        # Bias-free projection; the optional q/v biases are added in forward via F.linear.
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            # (2*Wh-1)*(2*Ww-1) patch-to-patch offsets, plus 3 extra slots for
            # cls->token, token->cls and cls->cls.
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            # Row-major encoding of the 2D offset into a single table index.
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            # Row 0 / column 0 belong to the cls token; the three trailing table
            # slots are used for them. The order of these assignments matters:
            # the [0, 0] corner is written last.
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)

            # trunc_normal_(self.relative_position_bias_table, std=.0)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        """Attend over the token sequence ``x`` of shape ``(B, N, C)``.

        ``rel_pos_bias``, if given, is added as-is to the ``(B, nH, N, N)``
        attention logits. Presumably it is ``(nH, N, N)`` from a shared
        RelativePositionBias; confirm with the caller. When this layer owns a
        bias table, N must equal ``Wh * Ww + 1`` for the bias to broadcast.
        """
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # k gets a constant zero bias; only q and v biases are learned.
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, nH, N, head_dim
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))`` then ``x + mlp(norm2(x))``.

    If ``init_values`` is not None, each branch output is multiplied by a
    learnable per-channel scale (``gamma_1`` / ``gamma_2``) initialised to
    ``init_values`` (LayerScale). Both residual branches go through
    ``DropPath`` when ``drop_path > 0``; otherwise that step is the identity.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        # Submodules are created in a fixed order; this keeps RNG consumption
        # and state-dict layout stable.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        self.drop_path = nn.Identity() if not drop_path > 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)

        if init_values is None:
            self.gamma_1 = None
            self.gamma_2 = None
        else:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x, rel_pos_bias=None):
        attn_out = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
        if self.gamma_1 is not None:
            attn_out = self.gamma_1 * attn_out
        x = x + self.drop_path(attn_out)

        mlp_out = self.mlp(self.norm2(x))
        if self.gamma_2 is not None:
            mlp_out = self.gamma_2 * mlp_out
        return x + self.drop_path(mlp_out)
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A single ``Conv2d`` whose kernel and stride both equal ``patch_size``
    projects every patch to ``embed_dim`` channels. The input size is not
    checked against ``img_size``. ``img_size`` only sets the nominal
    ``patch_shape`` and ``num_patches`` attributes.

    ``forward`` returns ``(tokens, (Hp, Wp))``. ``tokens`` has shape
    ``(B, Hp * Wp, embed_dim)`` and ``(Hp, Wp)`` is the patch grid actually
    produced for this input.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]
        self.patch_shape = (grid_h, grid_w)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_w * grid_h
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        # The unpacking requires a 4D (B, C, H, W) input.
        _batch, _chans, _height, _width = x.shape
        feat = self.proj(x)
        grid = (feat.shape[2], feat.shape[3])
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, grid
class HybridEmbed(nn.Module):
    """CNN feature-map embedding.

    Runs ``backbone``, takes its last output feature map, flattens it into a
    token sequence and projects the tokens to ``embed_dim``.

    To be a drop-in replacement for ``PatchEmbed``, ``forward`` returns
    ``(tokens, (Hp, Wp))`` and the module exposes ``patch_shape``. The BEiT
    backbones unpack that tuple and read ``patch_shape`` (for position/
    relative-position embeddings). Previously only the tensor was returned,
    which made the hybrid path crash.

    Args:
        backbone (nn.Module): CNN whose output is indexable, with ``[-1]``
            being a ``(B, C, H, W)`` feature map.
        img_size (int | tuple): nominal input size used to probe the backbone.
        feature_size (int | tuple | None): backbone output grid. If None it is
            measured by a dummy forward pass.
        in_chans (int): input image channels.
        embed_dim (int): output token width.
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        # Same meaning as PatchEmbed.patch_shape: the nominal token grid (H, W).
        self.patch_shape = (int(feature_size[0]), int(feature_size[1]))
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        x = self.backbone(x)[-1]
        Hp, Wp = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x, (Hp, Wp)
class RelativePositionBias(nn.Module):
    """Learnable relative position bias for a ``Wh x Ww`` patch grid plus a cls token.

    The bias table holds one row per relative 2D offset between patches,
    ``(2*Wh-1) * (2*Ww-1)`` rows, plus three extra rows for cls->token,
    token->cls and cls->cls. Each row has one value per head.

    ``forward()`` takes no input and returns a ``(num_heads, Wh*Ww+1, Wh*Ww+1)``
    tensor to be added to the attention logits.
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        win_h, win_w = window_size[0], window_size[1]
        self.num_relative_distance = (2 * win_h - 1) * (2 * win_w - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))

        # Pairwise (dh, dw) offsets between all patches of the window.
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat = grid.flatten(1)  # 2, Wh*Ww
        delta = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        # Shift each offset to be non-negative, then encode (dh, dw) row-major.
        delta[:, :, 0] += win_h - 1
        delta[:, :, 1] += win_w - 1
        delta[:, :, 0] *= 2 * win_w - 1

        n_tokens = win_h * win_w + 1
        index = torch.zeros(size=(n_tokens, n_tokens), dtype=delta.dtype)
        index[1:, 1:] = delta.sum(-1)
        # cls row, cls column, then the cls/cls corner (written last on purpose).
        index[0, 0:] = self.num_relative_distance - 3
        index[0:, 0] = self.num_relative_distance - 2
        index[0, 0] = self.num_relative_distance - 1
        self.register_buffer("relative_position_index", index)

    def forward(self):
        n_tokens = self.window_size[0] * self.window_size[1] + 1
        flat_index = self.relative_position_index.view(-1)
        bias = self.relative_position_bias_table[flat_index].view(n_tokens, n_tokens, -1)
        return bias.permute(2, 0, 1).contiguous()  # nH, N, N
class FeatureSelectionModule(nn.Module):
    """FaPN feature selection module (FSM).

    A sigmoid channel-attention gate, computed from the globally
    average-pooled input, re-weights ``x``. The re-weighted map is added back
    to ``x`` and projected from ``in_c`` to ``out_c`` channels with a 1x1
    conv, BatchNorm and ReLU.

    Args:
        in_c (int): input channels.
        out_c (int): output channels.
        norm (str): unused; kept for signature compatibility.

    NOTE: layers are moved to CUDA eagerly at construction. BEiT_FaPN in this
    file keeps FeatureAlign (which owns this module) in a plain Python list,
    so ``model.cuda()`` would not move it. Do not drop the ``.cuda()`` calls
    without fixing that registration first.
    """
    def __init__(self, in_c, out_c, norm="GM"):
        super(FeatureSelectionModule, self).__init__()
        self.conv_attn = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, bias=False).cuda(),  # without norm and activation
        )
        self.sigmoid = nn.Sigmoid().cuda()
        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, bias=False).cuda(),
            nn.BatchNorm2d(out_c).cuda(),
            nn.ReLU(inplace=True).cuda()
        )
        # Initialise the Conv2d *inside* the Sequential. The Sequential itself
        # has no ``weight``, so passing it to xavier_init never reached the
        # conv's weight (silent no-op or AttributeError, depending on mmcv).
        xavier_init(self.conv_attn[0])
        for m in self.conv:
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

    def forward(self, x):
        # assumes x is (B, C, H, W); pooling over the full H x W gives a (B, C, 1, 1) gate.
        attn = self.sigmoid(self.conv_attn(F.avg_pool2d(x, x.size()[2:])))
        feat = torch.mul(x, attn)
        x = x + feat
        feat = self.conv(x)
        return feat
class FeatureAlign(nn.Module):
    """FaPN feature alignment module (FAM).

    ``feat_l`` passes through a FeatureSelectionModule. ``feat_s`` is
    bilinearly resized to ``feat_l``'s spatial size whenever the two differ.
    Offsets predicted from the concatenation of both drive a 3x3 deformable
    conv (8 deform groups) that warps the resized ``feat_s``. The ReLU'd
    result is added to the selected ``feat_l`` features. Both inputs are cast
    to float32 first.

    Args:
        in_c (int): channels of ``feat_l``.
        out_c (int): channels of ``feat_s`` and of the output.
        norm: unused; kept for signature compatibility.

    NOTE: layers are moved to CUDA eagerly (``.cuda()``), presumably because
    BEiT_FaPN stores these modules in a plain list rather than a ModuleList.
    """
    def __init__(self, in_c, out_c, norm=None):
        super(FeatureAlign, self).__init__()
        self.lateral_conv = FeatureSelectionModule(in_c, out_c, norm="")
        self.relu = nn.ReLU(inplace=True).cuda()
        # kernel 3x3 * 8 deform groups * 2 (x, y) = 144 offset channels.
        offset_channels = 3 * 3 * 8 * 2
        self.offset = nn.Conv2d(out_c * 2, offset_channels, kernel_size=1, stride=1, padding=0, bias=False).cuda()
        self.deform_conv2d = DeformConv2d(
            out_c, out_c, kernel_size=3, stride=1, padding=1, dilation=1, deform_groups=8).cuda()

    def forward(self, feat_l, feat_s):
        feat_l, feat_s = feat_l.float(), feat_s.float()
        target_hw = feat_l.size()[2:]
        if feat_s.size()[2:] == target_hw:
            feat_up = feat_s
        else:
            feat_up = F.interpolate(feat_s, target_hw, mode='bilinear', align_corners=False)
        feat_arm = self.lateral_conv(feat_l)
        offset = self.offset(torch.cat([feat_arm, feat_up], dim=1)).float()
        aligned = self.relu(self.deform_conv2d(feat_up, offset))
        return aligned + feat_arm
@BACKBONES.register_module()
class BEiT_FaPN(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False,
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
out_indices=[3, 5, 7, 11]):
super().__init__()
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.out_conv = []
self.FaPN = []
for i in range(len(out_indices)):
self.out_conv.append(nn.Conv2d(embed_dim, embed_dim, kernel_size=3, stride=1, padding=1, bias=False).cuda())
self.FaPN.append(FeatureAlign(embed_dim, embed_dim))
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.out_indices = out_indices
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
else:
self.rel_pos_bias = None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.use_checkpoint = use_checkpoint
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
for i in range(depth)])
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
# trunc_normal_(self.mask_token, std=.02)
self.out_indices = out_indices
if patch_size == 16:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
nn.SyncBatchNorm(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
elif patch_size == 8:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Identity()
self.fpn3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.fpn4 = nn.Sequential(
nn.MaxPool2d(kernel_size=4, stride=4),
)
self.apply(self._init_weights)
self.fix_init_weight()
def fix_init_weight(self):
    """Shrink the residual-branch output projections according to depth.

    For the block at 1-based depth ``d``, both ``attn.proj.weight`` and
    ``mlp.fc2.weight`` are divided in place by ``sqrt(2 * d)``.
    """
    for depth, blk in enumerate(self.blocks, start=1):
        factor = math.sqrt(2.0 * depth)
        for weight in (blk.attn.proj.weight.data, blk.mlp.fc2.weight.data):
            weight.div_(factor)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def init_weights(self, pretrained=None):
    """Initialize the weights in backbone.

    Every submodule is re-initialised first. If ``pretrained`` is a path,
    that checkpoint is then loaded non-strictly on top.

    Args:
        pretrained (str, optional): Path to pre-trained weights.
            Defaults to None.

    Raises:
        TypeError: if ``pretrained`` is neither a str nor None.
    """
    def _init_weights(m):
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    if pretrained is not None and not isinstance(pretrained, str):
        raise TypeError('pretrained must be a str or None')
    self.apply(_init_weights)
    if pretrained is not None:
        load_checkpoint(self, pretrained, strict=False, logger=get_root_logger())
def get_num_layers(self):
    """Return how many transformer blocks the backbone stacks."""
    num_layers = len(self.blocks)
    return num_layers
@torch.jit.ignore
def no_weight_decay(self):
    """Return the parameter names to exclude from weight decay."""
    skip = {'cls_token', 'pos_embed'}
    return skip
def forward_features(self, x):
    """Run the transformer and return FaPN-fused multi-scale feature maps.

    Tokens (minus the leading cls token) from every block whose index is in
    ``self.out_indices`` are reshaped to ``(B, C, Hp, Wp)`` maps. They are then
    resampled by ``fpn1``..``fpn4`` in order. Result ordering:
      * position ``j`` for j = 0, 1, 2 holds
        ``self.out_conv[j + 1](self.FaPN[j + 1](features[j], features[j + 1]))``;
      * position 3 holds ``features[3]`` unchanged.

    Args:
        x: input batch, unpacked below as ``B, C, H, W``.

    Returns:
        tuple of 4 tensors.

    NOTE(review): ``self.FaPN`` and ``self.out_conv`` must be created in
    ``__init__`` (indexable with 1..3) — TODO confirm they are.
    NOTE(review): the fusion loop is bounded by ``len(ops)`` (4), so this
    assumes ``out_indices`` has exactly four entries.
    NOTE(review): each fusion step reads the un-fused ``features[i]`` rather
    than the previous step's output; confirm this is intended.
    """
    B, C, H, W = x.shape
    x, (Hp, Wp) = self.patch_embed(x)
    batch_size, seq_len, _ = x.size()
    cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    x = torch.cat((cls_tokens, x), dim=1)
    if self.pos_embed is not None:
        # No interpolation: the token count must match the embedding length.
        x = x + self.pos_embed
    x = self.pos_drop(x)
    # Shared relative position bias (if enabled) is computed once per forward.
    rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
    features = []
    for i, blk in enumerate(self.blocks):
        if self.use_checkpoint:
            x = checkpoint.checkpoint(blk, x, rel_pos_bias)
        else:
            x = blk(x, rel_pos_bias)
        if i in self.out_indices:
            # Drop the cls token and fold the token sequence back into a
            # (B, C, Hp, Wp) spatial map.
            xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
            features.append(xp.contiguous())
    ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
    for i in range(len(features)):
        features[i] = ops[i](features[i])
    # The last map is kept as-is; the others are fused pairwise from the
    # coarse end (i = 3, 2, 1), then the list is reversed back to index order.
    features_fapn = []
    features_fapn.append(features[-1])
    for i in range(len(ops)-1, 0, -1):
        new_feature = self.FaPN[i](features[i-1], features[i])
        new_feature = self.out_conv[i](new_feature)
        features_fapn.append(new_feature)
    features_fapn.reverse()
    return tuple(features_fapn)
def forward(self, x):
    """Return the multi-scale features computed by ``forward_features``."""
    return self.forward_features(x)
================================================
FILE: downstream_tasks/semantic_segmentation/backbone/cae.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------'
import math
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
import numpy as np
from mmcv_custom import load_checkpoint
from mmseg.utils import get_root_logger
from mmseg.models.builder import BACKBONES
class DropPath(nn.Module):
    """Per-sample stochastic depth for the main path of residual blocks.

    Thin module wrapper around timm's ``drop_path``. The module's current
    ``training`` flag is forwarded on every call.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        is_training = self.training
        return drop_path(x, self.drop_prob, is_training)

    def extra_repr(self) -> str:
        return f'p={self.drop_prob}'
class Mlp(nn.Module):
    """Feed-forward network: Linear -> activation -> Linear -> Dropout.

    Args:
        in_features: input channel count.
        hidden_features: hidden width; falls back to ``in_features``.
        out_features: output channel count; falls back to ``in_features``.
        act_layer: activation class, instantiated without arguments.
        drop: dropout probability, applied once after the second linear.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width = hidden_features or in_features
        n_out = out_features or in_features
        self.fc1 = nn.Linear(in_features, width)
        self.act = act_layer()
        self.fc2 = nn.Linear(width, n_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Following the original BERT implementation, there is no dropout
        # between the activation and fc2 — only after fc2.
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with an optional per-layer relative position bias.

    Args:
        dim: channel count of the input and output tokens.
        num_heads: number of attention heads.
        qkv_bias: if True, learnable biases are created for q and v; k always
            receives a zero bias. The fused ``qkv`` linear itself has no bias.
        qk_scale: scale applied to q; if falsy, ``head_dim ** -0.5`` is used.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        window_size: optional (Wh, Ww) patch grid. When truthy, a learnable
            relative position bias table is built for Wh*Ww patch tokens plus
            one extra leading token.
        attn_head_dim: optional per-head width overriding ``dim // num_heads``.
    """
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        # Bias-free fused projection; the optional q/v biases (with a zero k
        # bias in between) are applied in forward via F.linear.
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None
        if window_size:
            self.window_size = window_size
            # (2*Wh-1)*(2*Ww-1) patch-to-patch offsets plus 3 extra slots for
            # pairs that involve the extra leading token.
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls
            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            # Row-major flattening of the 2-D offset into one table index.
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            # Order matters: row 0 is filled, then column 0 overwrites its
            # first entry, and finally [0, 0] gets its own slot.
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1
            self.register_buffer("relative_position_index", relative_position_index)
            # trunc_normal_(self.relative_position_bias_table, std=.0)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x, rel_pos_bias=None):
        """Apply self-attention to a token sequence.

        Args:
            x: tokens, unpacked as (B, N, C). With ``window_size`` set, N must
                equal Wh*Ww + 1 because the bias is gathered for that size.
            rel_pos_bias: optional extra bias added to the
                (B, num_heads, N, N) attention logits.

        Returns:
            Tensor of shape (B, N, dim).
        """
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        # -> (3, B, num_heads, N, head_dim)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        if self.relative_position_bias_table is not None:
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)
        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """Pre-norm transformer block with attention and MLP residual branches.

    If ``init_values`` is given, each branch output is multiplied by a
    learnable per-channel scale (``gamma_1`` / ``gamma_2``) initialised to
    ``init_values``. The scaling happens before stochastic depth and the
    residual add.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        # Stochastic depth on both residual branches; a no-op when disabled.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_dim, act_layer=act_layer, drop=drop)
        if init_values is None:
            self.gamma_1 = None
            self.gamma_2 = None
        else:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x, rel_pos_bias=None):
        scaled = self.gamma_1 is not None
        attn_out = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
        if scaled:
            attn_out = self.gamma_1 * attn_out
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        if scaled:
            mlp_out = self.gamma_2 * mlp_out
        return x + self.drop_path(mlp_out)
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    A strided convolution splits the image into non-overlapping
    ``patch_size`` patches and projects each one to ``embed_dim``.
    ``forward`` returns the flattened token sequence together with the patch
    grid ``(Hp, Wp)`` of the actual input. The input is not checked against
    ``img_size``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.patch_shape = grid
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid[0] * grid[1]
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        # Unpacking enforces a 4-D (B, C, H, W) input.
        # FIXME look at relaxing size constraints
        _b, _c, _h, _w = x.shape
        feat = self.proj(x)
        grid_hw = (feat.shape[2], feat.shape[3])
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, grid_hw
class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.

    The last feature map produced by ``backbone`` is flattened into tokens
    and linearly projected to ``embed_dim``. Like ``PatchEmbed``, it exposes
    ``num_patches`` and ``patch_shape``, and ``forward`` returns
    ``(tokens, (Hp, Wp))``. The ViT backbones unpack exactly that.

    Args:
        backbone: CNN module; its output is indexed with ``[-1]`` (presumably a
            list/tuple of (B, C, H, W) maps — confirm for your backbone).
        img_size: reference input size used to probe the feature size.
        feature_size: feature-map size. If None, it is probed by running the
            backbone once on zeros in eval mode; the training mode is restored
            afterwards.
        in_chans: input channel count used for the probe.
        embed_dim: output token width.
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        # Same grid attribute as PatchEmbed; needed for relative position bias
        # and 2-D sin-cos position embeddings.
        self.patch_shape = (int(feature_size[0]), int(feature_size[1]))
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        feat = self.backbone(x)[-1]
        Hp, Wp = feat.shape[2], feat.shape[3]
        x = feat.flatten(2).transpose(1, 2)
        x = self.proj(x)
        # Match PatchEmbed's return contract: tokens plus the feature grid size.
        return x, (Hp, Wp)
class RelativePositionBias(nn.Module):
    """Standalone relative position bias whose output can be passed to attention layers.

    Holds a learnable (num_relative_distance, num_heads) table and a fixed
    index buffer. Calling the module gathers a
    (num_heads, Wh*Ww + 1, Wh*Ww + 1) bias. Position 0 is an extra leading
    token with three dedicated table slots: its row, its column, and [0, 0].

    Args:
        window_size: (Wh, Ww) patch grid the bias is built for.
        num_heads: number of attention heads.
    """
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        # (2*Wh-1)*(2*Ww-1) patch-to-patch offsets plus 3 extra-token slots.
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        # Row-major flattening of the 2-D offset into one table index.
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = \
            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # Order matters: row 0, then column 0 (overwriting [0, 0]), then [0, 0].
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1
        self.register_buffer("relative_position_index", relative_position_index)
        # trunc_normal_(self.relative_position_bias_table, std=.02)
    def forward(self):
        """Return the gathered bias of shape (num_heads, Wh*Ww + 1, Wh*Ww + 1)."""
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
def get_sinusoid_encoding_table(n_position, d_hid, token=False):
    ''' Sinusoid position encoding table

    Args:
        n_position: number of positions (rows of the table).
        d_hid: embedding width (columns of the table).
        token: if True, append one all-zero row after the position rows.

    Returns:
        float32 tensor of shape (1, n_position [+ 1], d_hid). Column ``j``
        holds ``pos / 10000 ** (2 * (j // 2) / d_hid)``, passed through sin
        for even ``j`` and cos for odd ``j``.
    '''
    # Vectorised: broadcast positions (rows) against per-column angle rates.
    positions = np.arange(n_position, dtype=np.float64)[:, None]
    hid = np.arange(d_hid)
    sinusoid_table = positions / np.power(10000, 2 * (hid // 2) / d_hid)
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    if token:
        # numpy.concatenate takes ``axis`` (``dim`` raises TypeError).
        sinusoid_table = np.concatenate([sinusoid_table, np.zeros([1, d_hid])], axis=0)
    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
@BACKBONES.register_module()
class CAE(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    ViT backbone for dense prediction. Tokens from the blocks listed in
    ``out_indices`` are optionally normalised (``out_with_norm``) and reshaped
    to ``(B, C, Hp, Wp)`` maps. The i-th map is then resampled by ``fpn{i+1}``
    to build a feature pyramid (only ``patch_size`` 16 and 8 define these
    ops).

    Absolute position embedding:
      * ``use_abs_pos_emb=True``: learnable, truncated-normal initialised.
      * ``use_abs_pos_emb=False`` and ``use_sincos_pos_embed=True``: fixed
        2-D sin-cos table (``requires_grad=False``), kept as built.
      * otherwise: none.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_sincos_pos_embed=True, use_shared_rel_pos_bias=False,
                 out_indices=(3, 5, 7, 11), out_with_norm=False):
        super().__init__()
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # Private list copy: the default is an immutable tuple and a caller's
        # list is never aliased.
        self.out_indices = list(out_indices)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        elif use_sincos_pos_embed:
            self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)
        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
            for i in range(depth)])
        # Only the learnable embedding is randomly initialised. Running
        # trunc_normal_ on the frozen sin-cos table would replace it with
        # untrainable noise.
        if self.pos_embed is not None and use_abs_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        if patch_size == 16:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                nn.SyncBatchNorm(embed_dim),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn3 = nn.Identity()
            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif patch_size == 8:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Identity()
            self.fpn3 = nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
            )
            self.fpn4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=4, stride=4),
            )
        # NOTE: any other patch_size defines no fpn* ops, so forward_features
        # would fail with AttributeError.
        if not out_with_norm:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(embed_dim)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., decode=False):
        """Build a frozen 2-D sin-cos position embedding for the patch grid.

        Returns an ``nn.Parameter`` of shape (1, Hp*Wp + 1, embed_dim) with
        ``requires_grad=False``. The first row (cls slot) is all zeros; the
        other rows concatenate sin/cos of both grid coordinates.
        ``decode`` is unused.
        """
        h, w = self.patch_embed.patch_shape
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        pos_embed.requires_grad = False
        return pos_embed

    def fix_init_weight(self):
        """Divide each block's attn.proj and mlp.fc2 weights by sqrt(2 * depth)."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Truncated-normal Linear weights / zero biases; LayerNorm to (1, 0)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights, loaded
                non-strictly after re-initialisation. Defaults to None.

        Raises:
            TypeError: if ``pretrained`` is neither a str nor None.
        """
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        """Return a tuple of multi-scale maps, one per entry of ``out_indices``.

        Map ``i`` is the normalised patch tokens of the i-th selected block,
        reshaped to (B, C, Hp, Wp) and passed through ``fpn{i+1}``.
        """
        B, C, H, W = x.shape
        x, (Hp, Wp) = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            # No interpolation: the token count must match the embedding length.
            x = x + self.pos_embed
        x = self.pos_drop(x)
        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        features = []
        for i, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
            else:
                x = blk(x, rel_pos_bias)
            if i in self.out_indices:
                # Drop the cls token, then fold tokens into a (B, C, Hp, Wp) map.
                xp = self.norm(x[:, 1:, :]).permute(0, 2, 1).reshape(B, -1, Hp, Wp)
                features.append(xp.contiguous())
        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        for i in range(len(features)):
            features[i] = ops[i](features[i])
        return tuple(features)

    def forward(self, x):
        """Return the multi-scale feature tuple from ``forward_features``."""
        x = self.forward_features(x)
        return x
================================================
FILE: downstream_tasks/semantic_segmentation/backbone/fapn.py
================================================
class FeatureSelectionModule(nn.Module):
    """FaPN feature selection module.

    Steps:
      1. Compute a per-channel attention vector from the globally
         average-pooled input (1x1 conv + sigmoid).
      2. Re-weight the input by that vector and add it back as a residual
         (``x + x * attn``).
      3. Project to ``out_c`` channels with 1x1 conv + BatchNorm + ReLU.

    Args:
        in_c: input channel count.
        out_c: output channel count.
        norm: unused; kept for interface compatibility.
    """

    def __init__(self, in_c, out_c, norm="GM"):
        super(FeatureSelectionModule, self).__init__()
        self.conv_attn = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, bias=False),  # without norm and activation
        )
        self.sigmoid = nn.Sigmoid()
        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )
        # Xavier init of the (bias-free) convs, using torch's initialisers
        # directly. mmcv's ``xavier_init`` is not imported by this module, and
        # it reads ``module.weight`` itself, so it cannot reach the conv that
        # is wrapped in ``self.conv_attn``.
        for m in self.conv_attn.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
        for m in self.conv.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        # Pool with a kernel covering the whole map -> one value per channel
        # (assumes a 4-D (B, C, H, W) input).
        attn = self.sigmoid(self.conv_attn(F.avg_pool2d(x, x.size()[2:])))
        feat = torch.mul(x, attn)
        x = x + feat
        feat = self.conv(x)
        return feat
class FeatureAlign(nn.Module):
    """FaPN feature alignment module.

    Steps:
      1. Resize ``feat_s`` bilinearly to the spatial size of ``feat_l`` when
         the sizes differ.
      2. Refine ``feat_l`` with a FeatureSelectionModule (``feat_arm``).
      3. Predict deformable-conv offsets from ``cat([feat_arm, feat_up])``.
      4. Return ``relu(deform_conv(feat_up, offset)) + feat_arm``.

    Args:
        in_c: channel count of ``feat_l``.
        out_c: channel count of ``feat_s`` and of the output.
        norm: unused; kept for interface compatibility.
    """

    def __init__(self, in_c, out_c, norm=None):
        super(FeatureAlign, self).__init__()
        self.lateral_conv = FeatureSelectionModule(in_c, out_c, norm="")
        self.relu = nn.ReLU(inplace=True)
        deform_groups = 8
        kernel_size = 3
        # DeformConv2d needs an offset tensor with
        # 2 * deform_groups * kernel_size**2 channels (a 2-D offset per kernel
        # sampling location and group), not ``out_c`` channels.
        self.offset = nn.Conv2d(out_c * 2, 2 * deform_groups * kernel_size * kernel_size,
                                kernel_size=1, stride=1, padding=0, bias=False)
        self.deform_conv2d = DeformConv2d(out_c, out_c, kernel_size=kernel_size, stride=1, padding=1,
                                          dilation=1, deform_groups=deform_groups)

    def forward(self, feat_l, feat_s):
        """Align ``feat_s`` onto the grid of ``feat_l`` and fuse the two.

        Args:
            feat_l: feature map whose spatial size defines the output size
                (presumably the finer pyramid level).
            feat_s: feature map with ``out_c`` channels (presumably coarser).
        """
        HW = feat_l.size()[2:]
        if feat_l.size()[2:] != feat_s.size()[2:]:
            feat_up = F.interpolate(feat_s, HW, mode='bilinear', align_corners=False)
        else:
            feat_up = feat_s
        feat_arm = self.lateral_conv(feat_l)
        offset = self.offset(torch.cat([feat_arm, feat_up], dim=1))
        feat_align = self.relu(self.deform_conv2d(feat_up, offset))
        return feat_align + feat_arm
================================================
FILE: downstream_tasks/semantic_segmentation/backbone/mae.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------'
import math
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
import numpy as np
from mmcv_custom import load_checkpoint
from mmseg.utils import get_root_logger
from mmseg.models.builder import BACKBONES
class DropPath(nn.Module):
    """Per-sample stochastic depth for the main path of residual blocks.

    Thin module wrapper around timm's ``drop_path``. The module's current
    ``training`` flag is forwarded on every call.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        is_training = self.training
        return drop_path(x, self.drop_prob, is_training)

    def extra_repr(self) -> str:
        return f'p={self.drop_prob}'
class Mlp(nn.Module):
    """Feed-forward network: Linear -> activation -> Linear -> Dropout.

    Args:
        in_features: input channel count.
        hidden_features: hidden width; falls back to ``in_features``.
        out_features: output channel count; falls back to ``in_features``.
        act_layer: activation class, instantiated without arguments.
        drop: dropout probability, applied once after the second linear.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width = hidden_features or in_features
        n_out = out_features or in_features
        self.fc1 = nn.Linear(in_features, width)
        self.act = act_layer()
        self.fc2 = nn.Linear(width, n_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Following the original BERT implementation, there is no dropout
        # between the activation and fc2 — only after fc2.
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention (MAE variant) with optional relative position bias.

    The fused ``qkv`` projection always carries a bias (presumably to match
    MAE checkpoints — confirm). ``qkv_bias`` is accepted for interface
    compatibility but ignored.

    Args:
        dim: channel count of the input and output tokens.
        num_heads: number of attention heads.
        qkv_bias: ignored (see above).
        qk_scale: scale applied to q; if falsy, ``head_dim ** -0.5`` is used.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        window_size: optional (Wh, Ww) patch grid. When truthy, a learnable
            relative position bias table is built for Wh*Ww patch tokens plus
            one extra leading token.
        attn_head_dim: optional per-head width overriding ``dim // num_heads``.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=True)
        if window_size:
            self.window_size = window_size
            # (2*Wh-1)*(2*Ww-1) patch-to-patch offsets plus 3 extra-token slots.
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls
            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            # Order matters: row 0, then column 0 (overwriting [0, 0]), then [0, 0].
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1
            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        """Apply self-attention to ``x`` (unpacked as (B, N, C)); returns (B, N, dim).

        With ``window_size`` set, N must equal Wh*Ww + 1. ``rel_pos_bias``,
        if given, is added to the (B, num_heads, N, N) attention logits.
        """
        B, N, C = x.shape
        # Infer the per-head width (-1) instead of assuming C // num_heads, so
        # a custom ``attn_head_dim`` (all_head_dim != C) works.
        # -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        if self.relative_position_bias_table is not None:
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)
        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """Pre-norm transformer block with attention and MLP residual branches.

    If ``init_values`` is given, each branch output is multiplied by a
    learnable per-channel scale (``gamma_1`` / ``gamma_2``) initialised to
    ``init_values``. The scaling happens before stochastic depth and the
    residual add.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        # Stochastic depth on both residual branches; a no-op when disabled.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_dim, act_layer=act_layer, drop=drop)
        if init_values is None:
            self.gamma_1 = None
            self.gamma_2 = None
        else:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x, rel_pos_bias=None):
        scaled = self.gamma_1 is not None
        attn_out = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
        if scaled:
            attn_out = self.gamma_1 * attn_out
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        if scaled:
            mlp_out = self.gamma_2 * mlp_out
        return x + self.drop_path(mlp_out)
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    A strided convolution splits the image into non-overlapping
    ``patch_size`` patches and projects each one to ``embed_dim``.
    ``forward`` returns the flattened token sequence together with the patch
    grid ``(Hp, Wp)`` of the actual input. The input is not checked against
    ``img_size``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.patch_shape = grid
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid[0] * grid[1]
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        # Unpacking enforces a 4-D (B, C, H, W) input.
        # FIXME look at relaxing size constraints
        _b, _c, _h, _w = x.shape
        feat = self.proj(x)
        grid_hw = (feat.shape[2], feat.shape[3])
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, grid_hw
class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.

    The last feature map produced by ``backbone`` is flattened into tokens
    and linearly projected to ``embed_dim``. Like ``PatchEmbed``, it exposes
    ``num_patches`` and ``patch_shape``, and ``forward`` returns
    ``(tokens, (Hp, Wp))``. The ViT backbones unpack exactly that.

    Args:
        backbone: CNN module; its output is indexed with ``[-1]`` (presumably a
            list/tuple of (B, C, H, W) maps — confirm for your backbone).
        img_size: reference input size used to probe the feature size.
        feature_size: feature-map size. If None, it is probed by running the
            backbone once on zeros in eval mode; the training mode is restored
            afterwards.
        in_chans: input channel count used for the probe.
        embed_dim: output token width.
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        # Same grid attribute as PatchEmbed; needed for relative position bias
        # and 2-D sin-cos position embeddings.
        self.patch_shape = (int(feature_size[0]), int(feature_size[1]))
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        feat = self.backbone(x)[-1]
        Hp, Wp = feat.shape[2], feat.shape[3]
        x = feat.flatten(2).transpose(1, 2)
        x = self.proj(x)
        # Match PatchEmbed's return contract: tokens plus the feature grid size.
        return x, (Hp, Wp)
class RelativePositionBias(nn.Module):
    """Standalone relative position bias whose output can be passed to attention layers.

    Holds a learnable (num_relative_distance, num_heads) table and a fixed
    index buffer. Calling the module gathers a
    (num_heads, Wh*Ww + 1, Wh*Ww + 1) bias. Position 0 is an extra leading
    token with three dedicated table slots: its row, its column, and [0, 0].

    Args:
        window_size: (Wh, Ww) patch grid the bias is built for.
        num_heads: number of attention heads.
    """
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        # (2*Wh-1)*(2*Ww-1) patch-to-patch offsets plus 3 extra-token slots.
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        # Row-major flattening of the 2-D offset into one table index.
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = \
            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # Order matters: row 0, then column 0 (overwriting [0, 0]), then [0, 0].
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1
        self.register_buffer("relative_position_index", relative_position_index)
        # trunc_normal_(self.relative_position_bias_table, std=.02)
    def forward(self):
        """Return the gathered bias of shape (num_heads, Wh*Ww + 1, Wh*Ww + 1)."""
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
def get_sinusoid_encoding_table(n_position, d_hid, token=False):
    ''' Sinusoid position encoding table

    Args:
        n_position (int): number of positions (rows of the table).
        d_hid (int): embedding width (columns of the table).
        token (bool): if True, append one all-zero row after the position
            rows, so the table has ``n_position + 1`` rows.

    Returns:
        torch.FloatTensor of shape ``(1, rows, d_hid)``. For position ``p``,
        column ``2i`` holds ``sin(p / 10000 ** (2i / d_hid))`` and column
        ``2i + 1`` holds ``cos(p / 10000 ** (2i / d_hid))``.
    '''
    positions = np.arange(n_position)[:, None]           # (n_position, 1)
    exponents = 2 * (np.arange(d_hid) // 2) / d_hid      # (d_hid,)
    sinusoid_table = positions / np.power(10000, exponents)[None, :]
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    if token:
        # numpy's keyword is `axis`; the previous `dim=0` raised TypeError.
        sinusoid_table = np.concatenate([sinusoid_table, np.zeros([1, d_hid])], axis=0)
    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
@BACKBONES.register_module()
class MAE(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    ViT backbone that returns four feature maps for dense-prediction heads.
    The outputs of the blocks listed in ``out_indices`` (cls token dropped)
    are reshaped to ``(B, embed_dim, Hp, Wp)`` and passed through
    ``fpn1``..``fpn4``. For ``patch_size=16`` these rescale by 4x, 2x, 1x and
    0.5x; for ``patch_size=8`` by 2x, 1x, 0.5x and 0.25x.

    Args:
        img_size (int | tuple): input size the position embedding is built for.
        patch_size (int): patch size; only 8 and 16 have FPN adapters.
        in_chans (int): number of input channels.
        num_classes (int): stored on the module, not used by this backbone.
        embed_dim (int): token width.
        depth (int): number of transformer blocks.
        num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate,
        init_values: forwarded to every ``Block``.
        drop_path_rate (float): stochastic-depth rate, increased linearly
            from 0 to this value across the blocks.
        hybrid_backbone (nn.Module | None): if given, tokens come from a
            ``HybridEmbed`` around this CNN instead of ``PatchEmbed``.
        norm_layer: norm factory; defaults to ``LayerNorm(eps=1e-6)``.
        use_checkpoint (bool): run blocks through ``checkpoint.checkpoint``.
        use_abs_pos_emb (bool): True -> learnable absolute position embedding
            (truncated-normal init); False -> fixed 2D sin-cos embedding with
            ``requires_grad=False``.
        use_rel_pos_bias (bool): give each block its own relative position
            bias over the patch grid.
        use_shared_rel_pos_bias (bool): build one ``RelativePositionBias``
            shared by all blocks.
        out_indices (Sequence[int]): indices of blocks whose outputs are
            returned; at most four (one per FPN adapter).
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
                 out_indices=[3, 5, 7, 11]):
        super().__init__()
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # Copy so the shared mutable default is never aliased by an instance.
        self.out_indices = list(out_indices)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            # Learnable embedding for the cls token plus every patch.
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            # Fixed (requires_grad=False) 2D sin-cos table with a zero cls row.
            self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
            for i in range(depth)])

        # Only the learnable embedding gets a random init: re-initialising the
        # frozen sin-cos table would replace it with fixed random noise.
        if self.use_abs_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)

        # Adapters that turn the single-scale ViT output into a 4-level pyramid.
        if patch_size == 16:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                nn.SyncBatchNorm(embed_dim),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn3 = nn.Identity()
            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif patch_size == 8:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Identity()
            self.fpn3 = nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
            )
            self.fpn4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=4, stride=4),
            )
        else:
            # Without adapters forward_features would fail on missing fpn1..fpn4.
            raise NotImplementedError(
                f'MAE backbone only provides FPN adapters for patch_size 8 or 16, got {patch_size}')

        self.apply(self._init_weights)
        self.fix_init_weight()

    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., decode=False):
        """Build a frozen 2D sin-cos position embedding for the patch grid.

        The embedding is ``[sin(w*omega), cos(w*omega), sin(h*omega), cos(h*omega)]``
        per patch (each part ``embed_dim // 4`` wide), preceded by an all-zero
        row for the cls token. Returns an ``nn.Parameter`` of shape
        ``(1, Hp*Wp + 1, embed_dim)`` with ``requires_grad=False``.
        ``decode`` is accepted but unused.
        """
        h, w = self.patch_embed.patch_shape
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        pos_embed.requires_grad = False
        return pos_embed

    def fix_init_weight(self):
        """Divide each block's attention-proj and MLP-fc2 weights by sqrt(2 * layer_id), layer_id from 1."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: if ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            self.apply(self._init_weights)
            logger = get_root_logger()
            # strict=False: checkpoint keys that do not match are tolerated.
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            self.apply(self._init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        """Run the ViT and return a tuple of pyramid feature maps.

        Each block listed in ``out_indices`` contributes one map, reshaped to
        ``(B, embed_dim, Hp, Wp)`` and passed through the FPN adapter with the
        same position (``fpn1`` for the first collected block, and so on).
        """
        B = x.shape[0]
        x, (Hp, Wp) = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            # Requires Hp * Wp == num_patches (the grid the embedding was built for).
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        features = []
        for i, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
            else:
                x = blk(x, rel_pos_bias)
            if i in self.out_indices:
                # Drop the cls token and fold tokens back into a (B, C, Hp, Wp) map.
                xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
                features.append(xp.contiguous())

        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        return tuple(ops[i](feat) for i, feat in enumerate(features))

    def forward(self, x):
        """Alias for ``forward_features``."""
        x = self.forward_features(x)
        return x
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/ade20k.py
================================================
# dataset settings
# ADE20K (ADEChallengeData2016) data config: 512x512 random-crop training and
# single-scale, no-flip testing on the validation split.
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
# Per-channel mean/std on the 0-255 pixel scale.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    # NOTE(review): reduce_zero_label presumably maps label 0 to the ignore
    # value and shifts the remaining labels down by one; confirm in mmseg.
    dict(type='LoadAnnotations', reduce_zero_label=True),
    # Random rescale in [0.5, 2.0] x the (2048, 512) base scale.
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # Pad images with 0 and label maps with 255 up to crop_size.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    # NOTE(review): presumably per-GPU batch size and dataloader workers.
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=train_pipeline),
    # val and test both use the validation split.
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/ade20k_640x640.py
================================================
# dataset settings
# ADE20K data config at 640x640: same structure as the 512x512 variant, with
# the crop and the base scale (2560, 640) enlarged.
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
# Per-channel mean/std on the 0-255 pixel scale.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (640, 640)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    # NOTE(review): reduce_zero_label presumably maps label 0 to the ignore
    # value and shifts the remaining labels down by one; confirm in mmseg.
    dict(type='LoadAnnotations', reduce_zero_label=True),
    # Random rescale in [0.5, 2.0] x the (2560, 640) base scale.
    dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # Pad images with 0 and label maps with 255 up to crop_size.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2560, 640),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    # NOTE(review): presumably per-GPU batch size and dataloader workers.
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=train_pipeline),
    # val and test both use the validation split.
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/chase_db1.py
================================================
# dataset settings
# CHASE_DB1 data config: 128x128 random crops for training; the training set
# is wrapped in a RepeatDataset.
dataset_type = 'ChaseDB1Dataset'
data_root = 'data/CHASE_DB1'
# Per-channel mean/std on the 0-255 pixel scale.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (960, 999)
crop_size = (128, 128)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    # Random rescale in [0.5, 2.0] x img_scale.
    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # Pad images with 0 and label maps with 255 up to crop_size.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    # NOTE(review): presumably per-GPU batch size and dataloader workers.
    samples_per_gpu=4,
    workers_per_gpu=4,
    # NOTE(review): presumably repeats the small training set 40000 times
    # so one "epoch" spans many iterations.
    train=dict(
        type='RepeatDataset',
        times=40000,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/training',
            ann_dir='annotations/training',
            pipeline=train_pipeline)),
    # val and test both use the validation split.
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/cityscapes.py
================================================
# dataset settings
# Cityscapes data config: 512x1024 random-crop training on gtFine/train and
# single-scale, no-flip testing on the val split.
dataset_type = 'CityscapesDataset'
data_root = 'data/cityscapes/'
# Per-channel mean/std on the 0-255 pixel scale.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    # Random rescale in [0.5, 2.0] x the (2048, 1024) base scale.
    dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # Pad images with 0 and label maps with 255 up to crop_size.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 1024),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    # NOTE(review): presumably per-GPU batch size and dataloader workers.
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='leftImg8bit/train',
        ann_dir='gtFine/train',
        pipeline=train_pipeline),
    # val and test both use the val split.
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='leftImg8bit/val',
        ann_dir='gtFine/val',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='leftImg8bit/val',
        ann_dir='gtFine/val',
        pipeline=test_pipeline))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/cityscapes_769x769.py
================================================
# Cityscapes at 769x769 crops. Inherits from ./cityscapes.py and overrides only
# the pipelines. Dataset type, paths and loader settings presumably come from
# the base file via the config loader's _base_ merge.
_base_ = './cityscapes.py'
# Per-channel mean/std on the 0-255 pixel scale.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (769, 769)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    # Random rescale in [0.5, 2.0] x the (2049, 1025) base scale.
    dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # Pad images with 0 and label maps with 255 up to crop_size.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2049, 1025),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
# Only the pipelines are overridden here.
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/coco-stuff10k.py
================================================
# dataset settings
# COCO-Stuff 10k data config: 512x512 random-crop training on train2014 and
# evaluation on test2014 (val and test share that split).
dataset_type = 'COCOStuffDataset'
data_root = 'data/coco_stuff10k'
# Per-channel mean/std on the 0-255 pixel scale.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    # reduce_zero_label is also set on every dataset entry below.
    dict(type='LoadAnnotations', reduce_zero_label=True),
    # Random rescale in [0.5, 2.0] x the (2048, 512) base scale.
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # Pad images with 0 and label maps with 255 up to crop_size.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    # NOTE(review): presumably per-GPU batch size and dataloader workers.
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        reduce_zero_label=True,
        img_dir='images/train2014',
        ann_dir='annotations/train2014',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        reduce_zero_label=True,
        img_dir='images/test2014',
        ann_dir='annotations/test2014',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        reduce_zero_label=True,
        img_dir='images/test2014',
        ann_dir='annotations/test2014',
        pipeline=test_pipeline))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/drive.py
================================================
# dataset settings
# DRIVE data config: 64x64 random crops for training; the training set is
# wrapped in a RepeatDataset.
dataset_type = 'DRIVEDataset'
data_root = 'data/DRIVE'
# Per-channel mean/std on the 0-255 pixel scale.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (584, 565)
crop_size = (64, 64)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    # Random rescale in [0.5, 2.0] x img_scale.
    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # Pad images with 0 and label maps with 255 up to crop_size.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    # NOTE(review): presumably per-GPU batch size and dataloader workers.
    samples_per_gpu=4,
    workers_per_gpu=4,
    # NOTE(review): presumably repeats the small training set 40000 times.
    train=dict(
        type='RepeatDataset',
        times=40000,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/training',
            ann_dir='annotations/training',
            pipeline=train_pipeline)),
    # val and test both use the validation split.
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/hrf.py
================================================
# dataset settings
# HRF data config: 256x256 random crops for training; the training set is
# wrapped in a RepeatDataset.
dataset_type = 'HRFDataset'
data_root = 'data/HRF'
# Per-channel mean/std on the 0-255 pixel scale.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (2336, 3504)
crop_size = (256, 256)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    # Random rescale in [0.5, 2.0] x img_scale.
    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # Pad images with 0 and label maps with 255 up to crop_size.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    # NOTE(review): presumably per-GPU batch size and dataloader workers.
    samples_per_gpu=4,
    workers_per_gpu=4,
    # NOTE(review): presumably repeats the small training set 40000 times.
    train=dict(
        type='RepeatDataset',
        times=40000,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/training',
            ann_dir='annotations/training',
            pipeline=train_pipeline)),
    # val and test both use the validation split.
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/pascal_context.py
================================================
# dataset settings
# PASCAL-Context (VOC2010) data config: 480x480 random crops for training.
# Train/val image lists come from the split files below.
dataset_type = 'PascalContextDataset'
data_root = 'data/VOCdevkit/VOC2010/'
# Per-channel mean/std on the 0-255 pixel scale.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (520, 520)
crop_size = (480, 480)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    # Random rescale in [0.5, 2.0] x img_scale.
    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # Pad images with 0 and label maps with 255 up to crop_size.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    # NOTE(review): presumably per-GPU batch size and dataloader workers.
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='JPEGImages',
        ann_dir='SegmentationClassContext',
        split='ImageSets/SegmentationContext/train.txt',
        pipeline=train_pipeline),
    # val and test both use the val split list.
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='JPEGImages',
        ann_dir='SegmentationClassContext',
        split='ImageSets/SegmentationContext/val.txt',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='JPEGImages',
        ann_dir='SegmentationClassContext',
        split='ImageSets/SegmentationContext/val.txt',
        pipeline=test_pipeline))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/pascal_voc12.py
================================================
# dataset settings
# PASCAL VOC 2012 data config: 512x512 random-crop training. Train/val image
# lists come from the split files below.
dataset_type = 'PascalVOCDataset'
data_root = 'data/VOCdevkit/VOC2012'
# Per-channel mean/std on the 0-255 pixel scale.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    # Random rescale in [0.5, 2.0] x the (2048, 512) base scale.
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # Pad images with 0 and label maps with 255 up to crop_size.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    # NOTE(review): presumably per-GPU batch size and dataloader workers.
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='JPEGImages',
        ann_dir='SegmentationClass',
        split='ImageSets/Segmentation/train.txt',
        pipeline=train_pipeline),
    # val and test both use the val split list.
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='JPEGImages',
        ann_dir='SegmentationClass',
        split='ImageSets/Segmentation/val.txt',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='JPEGImages',
        ann_dir='SegmentationClass',
        split='ImageSets/Segmentation/val.txt',
        pipeline=test_pipeline))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/pascal_voc12_aug.py
================================================
# VOC 2012 + augmented annotations. Inherits from ./pascal_voc12.py and
# overrides only the training annotation dirs and split lists.
_base_ = './pascal_voc12.py'
# dataset settings
# NOTE(review): ann_dir[i] presumably pairs with split[i], so the train set
# combines SegmentationClass/train.txt with SegmentationClassAug/aug.txt;
# confirm against the mmseg dataset builder.
data = dict(
    train=dict(
        ann_dir=['SegmentationClass', 'SegmentationClassAug'],
        split=[
            'ImageSets/Segmentation/train.txt',
            'ImageSets/Segmentation/aug.txt'
        ]))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/stare.py
================================================
# dataset settings
# STARE data config: 128x128 random crops for training; the training set is
# wrapped in a RepeatDataset.
dataset_type = 'STAREDataset'
data_root = 'data/STARE'
# Per-channel mean/std on the 0-255 pixel scale.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (605, 700)
crop_size = (128, 128)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    # Random rescale in [0.5, 2.0] x img_scale.
    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # Pad images with 0 and label maps with 255 up to crop_size.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    # NOTE(review): presumably per-GPU batch size and dataloader workers.
    samples_per_gpu=4,
    workers_per_gpu=4,
    # NOTE(review): presumably repeats the small training set 40000 times.
    train=dict(
        type='RepeatDataset',
        times=40000,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/training',
            ann_dir='annotations/training',
            pipeline=train_pipeline)),
    # val and test both use the validation split.
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/default_runtime.py
================================================
# yapf:disable
# Shared runtime settings: iteration-based text logging every 50 steps,
# NCCL distributed backend, no checkpoint to load or resume from.
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
# A single 'train' phase in the workflow.
workflow = [('train', 1)]
cudnn_benchmark = True
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/ann_r50-d8.py
================================================
# model settings
# ANN head on a dilated ResNet-50 V1c (dilations 1,1,2,4 / strides 1,2,1,1)
# with an auxiliary FCN head on stage 3. num_classes=19 in both heads.
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    # Decode head reads backbone outputs 2 and 3 (1024 and 2048 channels).
    decode_head=dict(
        type='ANNHead',
        in_channels=[1024, 2048],
        in_index=[2, 3],
        channels=512,
        project_channels=256,
        query_scales=(1, ),
        key_pool_scales=(1, 3, 6, 8),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # Auxiliary head on output 2; its loss is weighted 0.4.
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/apcnet_r50-d8.py
================================================
# model settings
# APCNet head on a dilated ResNet-50 V1c (dilations 1,1,2,4 / strides 1,2,1,1)
# with an auxiliary FCN head on stage 3. num_classes=19 in both heads.
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    # Decode head reads backbone output 3 (2048 channels). Its norm_cfg is
    # written inline but has the same value as norm_cfg above.
    decode_head=dict(
        type='APCHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # Auxiliary head on output 2; its loss is weighted 0.4.
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/ccnet_r50-d8.py
================================================
# model settings
# CCNet (criss-cross) head on a dilated ResNet-50 V1c (dilations 1,1,2,4 /
# strides 1,2,1,1) with an auxiliary FCN head on stage 3. num_classes=19.
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    # Decode head reads backbone output 3 (2048 channels).
    decode_head=dict(
        type='CCHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        recurrence=2,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # Auxiliary head on output 2; its loss is weighted 0.4.
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/cgnet.py
================================================
# model settings
# CGNet backbone with an FCNHead that has no extra convs (num_convs=0) and a
# class-weighted cross-entropy loss over 19 classes.
norm_cfg = {'type': 'SyncBN', 'eps': 1e-03, 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'backbone': {
        'type': 'CGNet',
        'norm_cfg': norm_cfg,
        'in_channels': 3,
        'num_channels': (32, 64, 128),
        'num_blocks': (3, 21),
        'dilations': (2, 4),
        'reductions': (8, 16),
    },
    'decode_head': {
        'type': 'FCNHead',
        'in_channels': 256,
        'in_index': 2,
        'channels': 256,
        'num_convs': 0,
        'concat_input': False,
        'dropout_ratio': 0,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
            # 19 weights, one per class (matches num_classes above)
            'class_weight': [
                2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899,
                9.369352, 10.289121, 9.953208, 4.3097677, 9.490387,
                7.674431, 9.396905, 10.347791, 6.3927646, 10.226669,
                10.241062, 10.280587, 10.396974, 10.055647,
            ],
        },
    },
    # model training and testing settings
    'train_cfg': {'sampler': None},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/danet_r50-d8.py
================================================
# model settings
# EncoderDecoder: dilated ResNet-50 (v1c) backbone, DAHead decode head on the
# last backbone output (in_index=3) and an auxiliary FCNHead on output 2.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        # last two stages keep stride 1 and dilate by 2 / 4 instead
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'DAHead',
        'in_channels': 2048,
        'in_index': 3,
        'channels': 512,
        'pam_channels': 64,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 1024,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/deeplabv3_r50-d8.py
================================================
# model settings
# EncoderDecoder: dilated ResNet-50 (v1c) backbone, ASPPHead decode head on the
# last backbone output (in_index=3) and an auxiliary FCNHead on output 2.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        # last two stages keep stride 1 and dilate by 2 / 4 instead
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'ASPPHead',
        'in_channels': 2048,
        'in_index': 3,
        'channels': 512,
        'dilations': (1, 12, 24, 36),
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 1024,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/deeplabv3_unet_s5-d16.py
================================================
# model settings
# EncoderDecoder: 5-stage UNet backbone (no pretrained weights), ASPPHead
# decode head on output 4 and an auxiliary FCNHead on output 3, for 2 classes.
# Inference uses 256x256 sliding windows with stride 170.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': None,
    'backbone': {
        'type': 'UNet',
        'in_channels': 3,
        'base_channels': 64,
        'num_stages': 5,
        'strides': (1, 1, 1, 1, 1),
        'enc_num_convs': (2, 2, 2, 2, 2),
        'dec_num_convs': (2, 2, 2, 2),
        'downsamples': (True, True, True, True),
        'enc_dilations': (1, 1, 1, 1, 1),
        'dec_dilations': (1, 1, 1, 1),
        'with_cp': False,
        'conv_cfg': None,
        'norm_cfg': norm_cfg,
        'act_cfg': {'type': 'ReLU'},
        'upsample_cfg': {'type': 'InterpConv'},
        'norm_eval': False,
    },
    'decode_head': {
        'type': 'ASPPHead',
        'in_channels': 64,
        'in_index': 4,
        'channels': 16,
        'dilations': (1, 12, 24, 36),
        'dropout_ratio': 0.1,
        'num_classes': 2,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 128,
        'in_index': 3,
        'channels': 64,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 2,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'slide', 'crop_size': 256, 'stride': 170},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/deeplabv3plus_r50-d8.py
================================================
# model settings
# EncoderDecoder: dilated ResNet-50 (v1c) backbone, DepthwiseSeparableASPPHead
# on the last backbone output (in_index=3; c1 branch takes 256 channels) and an
# auxiliary FCNHead on output 2.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        # last two stages keep stride 1 and dilate by 2 / 4 instead
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'DepthwiseSeparableASPPHead',
        'in_channels': 2048,
        'in_index': 3,
        'channels': 512,
        'dilations': (1, 12, 24, 36),
        'c1_in_channels': 256,
        'c1_channels': 48,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 1024,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/dmnet_r50-d8.py
================================================
# model settings
# EncoderDecoder: dilated ResNet-50 (v1c) backbone, DMHead decode head on the
# last backbone output (in_index=3) and an auxiliary FCNHead on output 2.
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        # The last two stages keep stride 1 and use dilation 2 / 4 instead.
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='DMHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        filter_sizes=(1, 3, 5, 7),
        dropout_ratio=0.1,
        num_classes=19,
        # Share the module-level norm_cfg (previously an identical inline
        # copy) so that switching norm_cfg, e.g. SyncBN -> BN, updates the
        # decode head together with the backbone and auxiliary head.
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # The auxiliary loss is weighted 0.4, versus 1.0 for the main head.
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/dnl_r50-d8.py
================================================
# model settings
# EncoderDecoder: dilated ResNet-50 (v1c) backbone, DNLHead decode head
# (embedded_gaussian mode, reduction 2) on the last backbone output and an
# auxiliary FCNHead on output 2.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        # last two stages keep stride 1 and dilate by 2 / 4 instead
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'DNLHead',
        'in_channels': 2048,
        'in_index': 3,
        'channels': 512,
        'dropout_ratio': 0.1,
        'reduction': 2,
        'use_scale': True,
        'mode': 'embedded_gaussian',
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 1024,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/emanet_r50-d8.py
================================================
# model settings
# EncoderDecoder: dilated ResNet-50 (v1c) backbone, EMAHead decode head on the
# last backbone output (in_index=3) and an auxiliary FCNHead on output 2.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        # last two stages keep stride 1 and dilate by 2 / 4 instead
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'EMAHead',
        'in_channels': 2048,
        'in_index': 3,
        'channels': 256,
        'ema_channels': 512,
        'num_bases': 64,
        'num_stages': 3,
        'momentum': 0.1,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 1024,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/encnet_r50-d8.py
================================================
# model settings
# EncoderDecoder: dilated ResNet-50 (v1c) backbone, EncHead decode head on
# backbone outputs 1-3 with an extra sigmoid SE loss (weight 0.2), and an
# auxiliary FCNHead on output 2.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        # last two stages keep stride 1 and dilate by 2 / 4 instead
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'EncHead',
        'in_channels': [512, 1024, 2048],
        'in_index': (1, 2, 3),
        'channels': 512,
        'num_codes': 32,
        'use_se_loss': True,
        'add_lateral': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
        'loss_se_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': True,
            'loss_weight': 0.2,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 1024,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/fast_scnn.py
================================================
# model settings
# Fast-SCNN backbone with a DepthwiseSeparableFCNHead on the last output
# (in_index=-1) and two auxiliary FCNHeads on outputs -2 and -3. All three
# losses use sigmoid cross-entropy with weight 0.4.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True, 'momentum': 0.01}
model = {
    'type': 'EncoderDecoder',
    'backbone': {
        'type': 'FastSCNN',
        'downsample_dw_channels': (32, 48),
        'global_in_channels': 64,
        'global_block_channels': (64, 96, 128),
        'global_block_strides': (2, 2, 1),
        'global_out_channels': 128,
        'higher_in_channels': 64,
        'lower_in_channels': 128,
        'fusion_out_channels': 128,
        'out_indices': (0, 1, 2),
        'norm_cfg': norm_cfg,
        'align_corners': False,
    },
    'decode_head': {
        'type': 'DepthwiseSeparableFCNHead',
        'in_channels': 128,
        'channels': 128,
        'concat_input': False,
        'num_classes': 19,
        'in_index': -1,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': True,
            'loss_weight': 0.4,
        },
    },
    'auxiliary_head': [
        {
            'type': 'FCNHead',
            'in_channels': 128,
            'channels': 32,
            'num_convs': 1,
            'num_classes': 19,
            'in_index': -2,
            'norm_cfg': norm_cfg,
            'concat_input': False,
            'align_corners': False,
            'loss_decode': {
                'type': 'CrossEntropyLoss',
                'use_sigmoid': True,
                'loss_weight': 0.4,
            },
        },
        {
            'type': 'FCNHead',
            'in_channels': 64,
            'channels': 32,
            'num_convs': 1,
            'num_classes': 19,
            'in_index': -3,
            'norm_cfg': norm_cfg,
            'concat_input': False,
            'align_corners': False,
            'loss_decode': {
                'type': 'CrossEntropyLoss',
                'use_sigmoid': True,
                'loss_weight': 0.4,
            },
        },
    ],
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/fcn_hr18.py
================================================
# model settings
# HRNet-W18 backbone whose four branch outputs are resized and concatenated
# ('resize_concat') into a single 1x1 FCNHead; dropout_ratio=-1 on the head.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://msra/hrnetv2_w18',
    'backbone': {
        'type': 'HRNet',
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'extra': {
            'stage1': {
                'num_modules': 1,
                'num_branches': 1,
                'block': 'BOTTLENECK',
                'num_blocks': (4,),
                'num_channels': (64,),
            },
            'stage2': {
                'num_modules': 1,
                'num_branches': 2,
                'block': 'BASIC',
                'num_blocks': (4, 4),
                'num_channels': (18, 36),
            },
            'stage3': {
                'num_modules': 4,
                'num_branches': 3,
                'block': 'BASIC',
                'num_blocks': (4, 4, 4),
                'num_channels': (18, 36, 72),
            },
            'stage4': {
                'num_modules': 3,
                'num_branches': 4,
                'block': 'BASIC',
                'num_blocks': (4, 4, 4, 4),
                'num_channels': (18, 36, 72, 144),
            },
        },
    },
    'decode_head': {
        'type': 'FCNHead',
        'in_channels': [18, 36, 72, 144],
        'in_index': (0, 1, 2, 3),
        # concatenated width of all four branches (= 270)
        'channels': 18 + 36 + 72 + 144,
        'input_transform': 'resize_concat',
        'kernel_size': 1,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': -1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/fcn_r50-d8.py
================================================
# model settings
# EncoderDecoder: dilated ResNet-50 (v1c) backbone, two-conv FCNHead on the
# last backbone output (concatenating its input) and an auxiliary FCNHead on
# output 2.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        # last two stages keep stride 1 and dilate by 2 / 4 instead
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'FCNHead',
        'in_channels': 2048,
        'in_index': 3,
        'channels': 512,
        'num_convs': 2,
        'concat_input': True,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 1024,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/fcn_unet_s5-d16.py
================================================
# model settings
# EncoderDecoder: 5-stage UNet backbone (no pretrained weights), FCNHead
# decode head on output 4 and an auxiliary FCNHead on output 3, for 2 classes.
# Inference uses 256x256 sliding windows with stride 170.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': None,
    'backbone': {
        'type': 'UNet',
        'in_channels': 3,
        'base_channels': 64,
        'num_stages': 5,
        'strides': (1, 1, 1, 1, 1),
        'enc_num_convs': (2, 2, 2, 2, 2),
        'dec_num_convs': (2, 2, 2, 2),
        'downsamples': (True, True, True, True),
        'enc_dilations': (1, 1, 1, 1, 1),
        'dec_dilations': (1, 1, 1, 1),
        'with_cp': False,
        'conv_cfg': None,
        'norm_cfg': norm_cfg,
        'act_cfg': {'type': 'ReLU'},
        'upsample_cfg': {'type': 'InterpConv'},
        'norm_eval': False,
    },
    'decode_head': {
        'type': 'FCNHead',
        'in_channels': 64,
        'in_index': 4,
        'channels': 64,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 2,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 128,
        'in_index': 3,
        'channels': 64,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 2,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'slide', 'crop_size': 256, 'stride': 170},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/fpn_r50.py
================================================
# model settings
# EncoderDecoder: non-dilated ResNet-50 (v1c) backbone (strides 1,2,2,2), an
# FPN neck producing four 256-channel maps, and an FPNHead over all four.
# No auxiliary head.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        'dilations': (1, 1, 1, 1),
        'strides': (1, 2, 2, 2),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'neck': {
        'type': 'FPN',
        'in_channels': [256, 512, 1024, 2048],
        'out_channels': 256,
        'num_outs': 4,
    },
    'decode_head': {
        'type': 'FPNHead',
        'in_channels': [256, 256, 256, 256],
        'in_index': [0, 1, 2, 3],
        'feature_strides': [4, 8, 16, 32],
        'channels': 128,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/gcnet_r50-d8.py
================================================
# model settings
# EncoderDecoder: dilated ResNet-50 (v1c) backbone, GCHead decode head
# (attention pooling, channel_add fusion) on the last backbone output and an
# auxiliary FCNHead on output 2.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        # last two stages keep stride 1 and dilate by 2 / 4 instead
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'GCHead',
        'in_channels': 2048,
        'in_index': 3,
        'channels': 512,
        'ratio': 0.25,
        'pooling_type': 'att',
        'fusion_types': ('channel_add',),
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 1024,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/lraspp_m-v3-d8.py
================================================
# model settings
# MobileNetV3-large backbone (outputs 1, 3, 16) feeding an LRASPPHead that
# consumes the three outputs separately ('multiple_select').
norm_cfg = {'type': 'SyncBN', 'eps': 0.001, 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'backbone': {
        'type': 'MobileNetV3',
        'arch': 'large',
        'out_indices': (1, 3, 16),
        'norm_cfg': norm_cfg,
    },
    'decode_head': {
        'type': 'LRASPPHead',
        'in_channels': (16, 24, 960),
        'in_index': (0, 1, 2),
        'channels': 128,
        'input_transform': 'multiple_select',
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'act_cfg': {'type': 'ReLU'},
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/nonlocal_r50-d8.py
================================================
# model settings
# EncoderDecoder: dilated ResNet-50 (v1c) backbone, NLHead decode head
# (embedded_gaussian mode, reduction 2) on the last backbone output and an
# auxiliary FCNHead on output 2.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        # last two stages keep stride 1 and dilate by 2 / 4 instead
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'NLHead',
        'in_channels': 2048,
        'in_index': 3,
        'channels': 512,
        'dropout_ratio': 0.1,
        'reduction': 2,
        'use_scale': True,
        'mode': 'embedded_gaussian',
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 1024,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/ocrnet_hr18.py
================================================
# model settings
# Two-stage CascadeEncoderDecoder on an HRNet-W18 backbone: a 1x1 FCNHead
# (loss weight 0.4) followed by an OCRHead (loss weight 1.0). Both heads
# resize and concatenate all four HRNet branch outputs.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'CascadeEncoderDecoder',
    'num_stages': 2,
    'pretrained': 'open-mmlab://msra/hrnetv2_w18',
    'backbone': {
        'type': 'HRNet',
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'extra': {
            'stage1': {
                'num_modules': 1,
                'num_branches': 1,
                'block': 'BOTTLENECK',
                'num_blocks': (4,),
                'num_channels': (64,),
            },
            'stage2': {
                'num_modules': 1,
                'num_branches': 2,
                'block': 'BASIC',
                'num_blocks': (4, 4),
                'num_channels': (18, 36),
            },
            'stage3': {
                'num_modules': 4,
                'num_branches': 3,
                'block': 'BASIC',
                'num_blocks': (4, 4, 4),
                'num_channels': (18, 36, 72),
            },
            'stage4': {
                'num_modules': 3,
                'num_branches': 4,
                'block': 'BASIC',
                'num_blocks': (4, 4, 4, 4),
                'num_channels': (18, 36, 72, 144),
            },
        },
    },
    'decode_head': [
        {
            'type': 'FCNHead',
            'in_channels': [18, 36, 72, 144],
            # concatenated width of all four branches (= 270)
            'channels': 18 + 36 + 72 + 144,
            'in_index': (0, 1, 2, 3),
            'input_transform': 'resize_concat',
            'kernel_size': 1,
            'num_convs': 1,
            'concat_input': False,
            'dropout_ratio': -1,
            'num_classes': 19,
            'norm_cfg': norm_cfg,
            'align_corners': False,
            'loss_decode': {
                'type': 'CrossEntropyLoss',
                'use_sigmoid': False,
                'loss_weight': 0.4,
            },
        },
        {
            'type': 'OCRHead',
            'in_channels': [18, 36, 72, 144],
            'in_index': (0, 1, 2, 3),
            'input_transform': 'resize_concat',
            'channels': 512,
            'ocr_channels': 256,
            'dropout_ratio': -1,
            'num_classes': 19,
            'norm_cfg': norm_cfg,
            'align_corners': False,
            'loss_decode': {
                'type': 'CrossEntropyLoss',
                'use_sigmoid': False,
                'loss_weight': 1.0,
            },
        },
    ],
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/ocrnet_r50-d8.py
================================================
# model settings
# Two-stage CascadeEncoderDecoder on a dilated ResNet-50 (v1c) backbone: an
# FCNHead on output 2 (loss weight 0.4) followed by an OCRHead on output 3
# (loss weight 1.0).
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'CascadeEncoderDecoder',
    'num_stages': 2,
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        # last two stages keep stride 1 and dilate by 2 / 4 instead
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': [
        {
            'type': 'FCNHead',
            'in_channels': 1024,
            'in_index': 2,
            'channels': 256,
            'num_convs': 1,
            'concat_input': False,
            'dropout_ratio': 0.1,
            'num_classes': 19,
            'norm_cfg': norm_cfg,
            'align_corners': False,
            'loss_decode': {
                'type': 'CrossEntropyLoss',
                'use_sigmoid': False,
                'loss_weight': 0.4,
            },
        },
        {
            'type': 'OCRHead',
            'in_channels': 2048,
            'in_index': 3,
            'channels': 512,
            'ocr_channels': 256,
            'dropout_ratio': 0.1,
            'num_classes': 19,
            'norm_cfg': norm_cfg,
            'align_corners': False,
            'loss_decode': {
                'type': 'CrossEntropyLoss',
                'use_sigmoid': False,
                'loss_weight': 1.0,
            },
        },
    ],
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/pointrend_r50.py
================================================
# model settings
# Two-stage CascadeEncoderDecoder: non-dilated ResNet-50 (v1c) + FPN neck,
# a coarse FPNHead, then a PointHead refining on FPN level 0. The point
# sampling and subdivision parameters live in train_cfg / test_cfg.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'CascadeEncoderDecoder',
    'num_stages': 2,
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        'dilations': (1, 1, 1, 1),
        'strides': (1, 2, 2, 2),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'neck': {
        'type': 'FPN',
        'in_channels': [256, 512, 1024, 2048],
        'out_channels': 256,
        'num_outs': 4,
    },
    'decode_head': [
        {
            'type': 'FPNHead',
            'in_channels': [256, 256, 256, 256],
            'in_index': [0, 1, 2, 3],
            'feature_strides': [4, 8, 16, 32],
            'channels': 128,
            'dropout_ratio': -1,
            'num_classes': 19,
            'norm_cfg': norm_cfg,
            'align_corners': False,
            'loss_decode': {
                'type': 'CrossEntropyLoss',
                'use_sigmoid': False,
                'loss_weight': 1.0,
            },
        },
        {
            'type': 'PointHead',
            'in_channels': [256],
            'in_index': [0],
            'channels': 256,
            'num_fcs': 3,
            'coarse_pred_each_layer': True,
            'dropout_ratio': -1,
            'num_classes': 19,
            'align_corners': False,
            'loss_decode': {
                'type': 'CrossEntropyLoss',
                'use_sigmoid': False,
                'loss_weight': 1.0,
            },
        },
    ],
    # model training and testing settings
    'train_cfg': {
        'num_points': 2048,
        'oversample_ratio': 3,
        'importance_sample_ratio': 0.75,
    },
    'test_cfg': {
        'mode': 'whole',
        'subdivision_steps': 2,
        # NOTE(review): 8196 is not a power of two and may be a typo for
        # 8192; kept as-is, confirm against the reference PointRend setting.
        'subdivision_num_points': 8196,
        'scale_factor': 2,
    },
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/psanet_r50-d8.py
================================================
# model settings
# EncoderDecoder: dilated ResNet-50 (v1c) backbone, bi-directional PSAHead
# (97x97 mask) on the last backbone output and an auxiliary FCNHead on
# output 2.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        # last two stages keep stride 1 and dilate by 2 / 4 instead
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'PSAHead',
        'in_channels': 2048,
        'in_index': 3,
        'channels': 512,
        'mask_size': (97, 97),
        'psa_type': 'bi-direction',
        'compact': False,
        'shrink_factor': 2,
        'normalization_factor': 1.0,
        'psa_softmax': True,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 1024,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/pspnet_r50-d8.py
================================================
# model settings
# EncoderDecoder: dilated ResNet-50 (v1c) backbone, PSPHead decode head
# (pool scales 1/2/3/6) on the last backbone output and an auxiliary FCNHead
# on output 2.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        # last two stages keep stride 1 and dilate by 2 / 4 instead
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'PSPHead',
        'in_channels': 2048,
        'in_index': 3,
        'channels': 512,
        'pool_scales': (1, 2, 3, 6),
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 1024,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/pspnet_unet_s5-d16.py
================================================
# model settings
# EncoderDecoder: 5-stage UNet backbone (no pretrained weights), PSPHead
# decode head on output 4 and an auxiliary FCNHead on output 3, for 2 classes.
# Inference uses 256x256 sliding windows with stride 170.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': None,
    'backbone': {
        'type': 'UNet',
        'in_channels': 3,
        'base_channels': 64,
        'num_stages': 5,
        'strides': (1, 1, 1, 1, 1),
        'enc_num_convs': (2, 2, 2, 2, 2),
        'dec_num_convs': (2, 2, 2, 2),
        'downsamples': (True, True, True, True),
        'enc_dilations': (1, 1, 1, 1, 1),
        'dec_dilations': (1, 1, 1, 1),
        'with_cp': False,
        'conv_cfg': None,
        'norm_cfg': norm_cfg,
        'act_cfg': {'type': 'ReLU'},
        'upsample_cfg': {'type': 'InterpConv'},
        'norm_eval': False,
    },
    'decode_head': {
        'type': 'PSPHead',
        'in_channels': 64,
        'in_index': 4,
        'channels': 16,
        'pool_scales': (1, 2, 3, 6),
        'dropout_ratio': 0.1,
        'num_classes': 2,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 128,
        'in_index': 3,
        'channels': 64,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 2,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'slide', 'crop_size': 256, 'stride': 170},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/upernet_cae.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
# UPerNet over a 12-layer, 384-dim ViT-style backbone (patch size 16). All
# four UPerHead inputs and the auxiliary FCNHead input are 384 channels wide,
# matching embed_dim.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': None,
    'backbone': {
        # NOTE(review): backbone type is 'XCiT' even though this file is
        # named upernet_cae; presumably derived configs override `type`
        # (and the size fields) -- confirm before using this base directly.
        'type': 'XCiT',
        'patch_size': 16,
        'embed_dim': 384,
        'depth': 12,
        'num_heads': 8,
        'mlp_ratio': 4,
        'qkv_bias': True,
        'use_abs_pos_emb': True,
        'use_rel_pos_bias': False,
    },
    'decode_head': {
        'type': 'UPerHead',
        'in_channels': [384, 384, 384, 384],
        'in_index': [0, 1, 2, 3],
        'pool_scales': (1, 2, 3, 6),
        'channels': 512,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 384,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/upernet_r50.py
================================================
# model settings
# UPerNet: non-dilated ResNet-50 (v1c) backbone (strides 1,2,2,2), UPerHead
# over all four backbone outputs and an auxiliary FCNHead on output 2.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'open-mmlab://resnet50_v1c',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        'dilations': (1, 1, 1, 1),
        'strides': (1, 2, 2, 2),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'UPerHead',
        'in_channels': [256, 512, 1024, 2048],
        'in_index': [0, 1, 2, 3],
        'pool_scales': (1, 2, 3, 6),
        'channels': 512,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # auxiliary loss weighted 0.4 versus 1.0 for the main head
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 1024,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_160k.py
================================================
# optimizer: SGD with momentum 0.9 and weight decay 5e-4
optimizer = {'type': 'SGD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0005}
optimizer_config = {}
# learning policy: iteration-based poly decay (power 0.9) down to 1e-4
lr_config = {'policy': 'poly', 'power': 0.9, 'min_lr': 1e-4, 'by_epoch': False}
# runtime settings: 160k iterations; checkpoint and evaluate every 16k
runner = {'type': 'IterBasedRunner', 'max_iters': 160_000}
checkpoint_config = {'by_epoch': False, 'interval': 16_000}
evaluation = {'interval': 16_000, 'metric': 'mIoU'}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_20k.py
================================================
# optimizer: SGD with momentum 0.9 and weight decay 5e-4
optimizer = {'type': 'SGD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0005}
optimizer_config = {}
# learning policy: iteration-based poly decay (power 0.9) down to 1e-4
lr_config = {'policy': 'poly', 'power': 0.9, 'min_lr': 1e-4, 'by_epoch': False}
# runtime settings: 20k iterations; checkpoint and evaluate every 2k
runner = {'type': 'IterBasedRunner', 'max_iters': 20_000}
checkpoint_config = {'by_epoch': False, 'interval': 2_000}
evaluation = {'interval': 2_000, 'metric': 'mIoU'}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_320k.py
================================================
# Optimizer: SGD with momentum and L2 weight decay; default optimizer hook.
optimizer = {'type': 'SGD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0005}
optimizer_config = {}

# Poly learning-rate decay (power 0.9) down to 1e-4, stepped per iteration.
lr_config = {'policy': 'poly', 'power': 0.9, 'min_lr': 1e-4, 'by_epoch': False}

# 320k-iteration run; checkpoint and evaluate (mIoU) every 32k iterations.
runner = {'type': 'IterBasedRunner', 'max_iters': 320_000}
checkpoint_config = {'by_epoch': False, 'interval': 32_000}
evaluation = {'interval': 32_000, 'metric': 'mIoU'}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_40k.py
================================================
# Optimizer: SGD with momentum and L2 weight decay; default optimizer hook.
optimizer = {'type': 'SGD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0005}
optimizer_config = {}

# Poly learning-rate decay (power 0.9) down to 1e-4, stepped per iteration.
lr_config = {'policy': 'poly', 'power': 0.9, 'min_lr': 1e-4, 'by_epoch': False}

# 40k-iteration run; checkpoint and evaluate (mIoU) every 4k iterations.
runner = {'type': 'IterBasedRunner', 'max_iters': 40_000}
checkpoint_config = {'by_epoch': False, 'interval': 4_000}
evaluation = {'interval': 4_000, 'metric': 'mIoU'}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_80k.py
================================================
# Optimizer: SGD with momentum and L2 weight decay; default optimizer hook.
optimizer = {'type': 'SGD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0005}
optimizer_config = {}

# Poly learning-rate decay (power 0.9) down to 1e-4, stepped per iteration.
lr_config = {'policy': 'poly', 'power': 0.9, 'min_lr': 1e-4, 'by_epoch': False}

# 80k-iteration run; checkpoint and evaluate (mIoU) every 8k iterations.
runner = {'type': 'IterBasedRunner', 'max_iters': 80_000}
checkpoint_config = {'by_epoch': False, 'interval': 8_000}
evaluation = {'interval': 8_000, 'metric': 'mIoU'}
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/beit/upernet_beit_base_12_512_slide_160k_ade20k_pt_4e-4.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
# UPerNet with a BEiT-Base backbone (12 blocks, 768-d) on ADE20K (150
# classes) at 512x512; 160k iterations, peak lr 4e-4 with layer-wise decay.
_base_ = [
    '../../_base_/models/upernet_beit.py', '../../_base_/datasets/ade20k.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)
model = dict(
    backbone=dict(
        type='BEiT',
        img_size=512,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,
        use_sincos_pos_embed=False,
        use_rel_pos_bias=True,
        init_values=0.1,
        drop_path_rate=0.1,
        # blocks whose outputs feed the four UPerHead levels
        # (presumably 0-based block indices -- confirm in backbone/beit.py)
        out_indices=[3, 5, 7, 11]
    ),
    decode_head=dict(
        in_channels=[768, 768, 768, 768],  # one entry per out_indices level
        num_classes=150,
        channels=768,
    ),
    auxiliary_head=dict(
        in_channels=768,
        num_classes=150
    ),
    # sliding-window inference: 512x512 windows, stride 341, so adjacent
    # windows overlap by 171 px
    test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))
)
# Earlier alternative (disabled): plain AdamW without layer decay, with no
# weight decay for position embeddings, relative position bias tables and
# norm layers.
# optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
# paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
# 'relative_position_bias_table': dict(decay_mult=0.),
# 'norm': dict(decay_mult=0.)}))
# AdamW with layer-wise lr decay built by LayerDecayOptimizerConstructor
# (exported by mmcv_custom). _delete_=True makes mmcv replace, rather than
# merge with, the SGD optimizer inherited from schedule_160k.py.
optimizer = dict(_delete_=True, type='AdamW', lr=4e-4, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))
# Linear warmup over the first 1500 iterations, then linear (power=1.0)
# decay to 0; replaces the inherited poly schedule entirely (_delete_).
lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data=dict(samples_per_gpu=2)
#img_norm_cfg = dict(
# mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
#crop_size = (512, 512)
## test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))
# NOTE(review): presumably forwarded to distributed data parallel by the
# training code because some parameters may get no gradient -- confirm.
find_unused_parameters = True
# Optional multi-scale + flip test pipeline (disabled). Re-enabling it also
# requires the img_norm_cfg definition commented out above.
#test_pipeline = [
# dict(type='LoadImageFromFile'),
# dict(
# type='MultiScaleFlipAug',
# img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
# flip=True,
# transforms=[
# dict(type='SETR_Resize', keep_ratio=True,
# crop_size=crop_size, setr_multi_scale=True),
# dict(type='RandomFlip'),
# dict(type='Normalize', **img_norm_cfg),
# dict(type='ImageToTensor', keys=['img']),
# dict(type='Collect', keys=['img']),
# ])
#]
#data = dict(
# val=dict(pipeline=test_pipeline),
# test=dict(pipeline=test_pipeline),
# samples_per_gpu=2,
#)
# IterBasedRunnerAmp (mmcv_custom.apex_runner) additionally saves/restores
# apex amp state. This dict is merged with the base runner, so max_iters from
# schedule_160k.py is kept.
runner = dict(type='IterBasedRunnerAmp')
checkpoint_config = dict(by_epoch=False, interval=8000)
# do not use mmdet version fp16; mixed precision comes from apex through
# DistOptimizerHook(use_fp16=True) below.
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,  # optimizer step every iteration (no grad accumulation)
    grad_clip=None,
    coalesce=True,  # stored but not used by DistOptimizerHook
    bucket_size_mb=-1,  # stored but not used by DistOptimizerHook
    use_fp16=True,
)
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_1e-4.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
# UPerNet with a CAE-Base backbone (12 blocks, 768-d) on ADE20K (150 classes)
# at 512x512; 160k iterations, peak lr 1e-4 with layer-wise decay.
_base_ = [
    '../../_base_/models/upernet_cae.py', '../../_base_/datasets/ade20k.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)
model = dict(
    backbone=dict(
        type='CAE',
        img_size=512,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,
        use_rel_pos_bias=True,
        init_values=0.1,
        drop_path_rate=0.1,
        # blocks whose outputs feed the four UPerHead levels
        # (presumably 0-based block indices -- confirm in backbone/cae.py)
        out_indices=[3, 5, 7, 11]
    ),
    decode_head=dict(
        in_channels=[768, 768, 768, 768],  # one entry per out_indices level
        num_classes=150,
        channels=768,
    ),
    auxiliary_head=dict(
        in_channels=768,
        num_classes=150
    ),
    # sliding-window inference: 512x512 windows, stride 341, so adjacent
    # windows overlap by 171 px
    test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))
)
# AdamW with layer-wise lr decay built by LayerDecayOptimizerConstructor
# (exported by mmcv_custom). _delete_=True makes mmcv replace, rather than
# merge with, the SGD optimizer inherited from schedule_160k.py.
optimizer = dict(_delete_=True, type='AdamW', lr=1e-4, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))
# Linear warmup over the first 1500 iterations, then linear (power=1.0)
# decay to 0; replaces the inherited poly schedule entirely (_delete_).
lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data=dict(samples_per_gpu=2)
# NOTE(review): presumably forwarded to distributed data parallel by the
# training code because some parameters may get no gradient -- confirm.
find_unused_parameters = True
# Optional multi-scale + flip test pipeline (disabled). Re-enabling it also
# requires defining img_norm_cfg, which this file does not do.
#test_pipeline = [
# dict(type='LoadImageFromFile'),
# dict(
# type='MultiScaleFlipAug',
# img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
# flip=True,
# transforms=[
# dict(type='SETR_Resize', keep_ratio=True,
# crop_size=crop_size, setr_multi_scale=True),
# dict(type='RandomFlip'),
# dict(type='Normalize', **img_norm_cfg),
# dict(type='ImageToTensor', keys=['img']),
# dict(type='Collect', keys=['img']),
# ])
#]
#data = dict(
# val=dict(pipeline=test_pipeline),
# test=dict(pipeline=test_pipeline),
# samples_per_gpu=2,
#)
# IterBasedRunnerAmp (mmcv_custom.apex_runner) additionally saves/restores
# apex amp state. This dict is merged with the base runner, so max_iters from
# schedule_160k.py is kept.
runner = dict(type='IterBasedRunnerAmp')
checkpoint_config = dict(by_epoch=False, interval=8000)
# do not use mmdet version fp16; mixed precision comes from apex through
# DistOptimizerHook(use_fp16=True) below.
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,  # optimizer step every iteration (no grad accumulation)
    grad_clip=None,
    coalesce=True,  # stored but not used by DistOptimizerHook
    bucket_size_mb=-1,  # stored but not used by DistOptimizerHook
    use_fp16=True,
)
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_2e-4.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
# UPerNet with a CAE-Base backbone (12 blocks, 768-d) on ADE20K (150 classes)
# at 512x512; 160k iterations, peak lr 2e-4 with layer-wise decay.
_base_ = [
    '../../_base_/models/upernet_cae.py', '../../_base_/datasets/ade20k.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)
model = dict(
    backbone=dict(
        type='CAE',
        img_size=512,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,
        use_rel_pos_bias=True,
        init_values=0.1,
        drop_path_rate=0.1,
        # blocks whose outputs feed the four UPerHead levels
        # (presumably 0-based block indices -- confirm in backbone/cae.py)
        out_indices=[3, 5, 7, 11]
    ),
    decode_head=dict(
        in_channels=[768, 768, 768, 768],  # one entry per out_indices level
        num_classes=150,
        channels=768,
    ),
    auxiliary_head=dict(
        in_channels=768,
        num_classes=150
    ),
    # sliding-window inference: 512x512 windows, stride 341, so adjacent
    # windows overlap by 171 px
    test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))
)
# AdamW with layer-wise lr decay built by LayerDecayOptimizerConstructor
# (exported by mmcv_custom). _delete_=True makes mmcv replace, rather than
# merge with, the SGD optimizer inherited from schedule_160k.py.
optimizer = dict(_delete_=True, type='AdamW', lr=2e-4, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))
# Linear warmup over the first 1500 iterations, then linear (power=1.0)
# decay to 0; replaces the inherited poly schedule entirely (_delete_).
lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data=dict(samples_per_gpu=2)
# NOTE(review): presumably forwarded to distributed data parallel by the
# training code because some parameters may get no gradient -- confirm.
find_unused_parameters = True
# Optional multi-scale + flip test pipeline (disabled). Re-enabling it also
# requires defining img_norm_cfg, which this file does not do.
#test_pipeline = [
# dict(type='LoadImageFromFile'),
# dict(
# type='MultiScaleFlipAug',
# img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
# flip=True,
# transforms=[
# dict(type='SETR_Resize', keep_ratio=True,
# crop_size=crop_size, setr_multi_scale=True),
# dict(type='RandomFlip'),
# dict(type='Normalize', **img_norm_cfg),
# dict(type='ImageToTensor', keys=['img']),
# dict(type='Collect', keys=['img']),
# ])
#]
#data = dict(
# val=dict(pipeline=test_pipeline),
# test=dict(pipeline=test_pipeline),
# samples_per_gpu=2,
#)
# IterBasedRunnerAmp (mmcv_custom.apex_runner) additionally saves/restores
# apex amp state. This dict is merged with the base runner, so max_iters from
# schedule_160k.py is kept.
runner = dict(type='IterBasedRunnerAmp')
checkpoint_config = dict(by_epoch=False, interval=8000)
# do not use mmdet version fp16; mixed precision comes from apex through
# DistOptimizerHook(use_fp16=True) below.
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,  # optimizer step every iteration (no grad accumulation)
    grad_clip=None,
    coalesce=True,  # stored but not used by DistOptimizerHook
    bucket_size_mb=-1,  # stored but not used by DistOptimizerHook
    use_fp16=True,
)
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_3e-4.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
# UPerNet with a CAE-Base backbone (12 blocks, 768-d) on ADE20K (150 classes)
# at 512x512; 160k iterations, peak lr 3e-4 with layer-wise decay.
_base_ = [
    '../../_base_/models/upernet_cae.py', '../../_base_/datasets/ade20k.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)
model = dict(
    backbone=dict(
        type='CAE',
        img_size=512,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,
        use_rel_pos_bias=True,
        init_values=0.1,
        drop_path_rate=0.1,
        # blocks whose outputs feed the four UPerHead levels
        # (presumably 0-based block indices -- confirm in backbone/cae.py)
        out_indices=[3, 5, 7, 11]
    ),
    decode_head=dict(
        in_channels=[768, 768, 768, 768],  # one entry per out_indices level
        num_classes=150,
        channels=768,
    ),
    auxiliary_head=dict(
        in_channels=768,
        num_classes=150
    ),
    # sliding-window inference: 512x512 windows, stride 341, so adjacent
    # windows overlap by 171 px
    test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))
)
# AdamW with layer-wise lr decay built by LayerDecayOptimizerConstructor
# (exported by mmcv_custom). _delete_=True makes mmcv replace, rather than
# merge with, the SGD optimizer inherited from schedule_160k.py.
optimizer = dict(_delete_=True, type='AdamW', lr=3e-4, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))
# Linear warmup over the first 1500 iterations, then linear (power=1.0)
# decay to 0; replaces the inherited poly schedule entirely (_delete_).
lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data=dict(samples_per_gpu=2)
# NOTE(review): presumably forwarded to distributed data parallel by the
# training code because some parameters may get no gradient -- confirm.
find_unused_parameters = True
# Optional multi-scale + flip test pipeline (disabled). Re-enabling it also
# requires defining img_norm_cfg, which this file does not do.
#test_pipeline = [
# dict(type='LoadImageFromFile'),
# dict(
# type='MultiScaleFlipAug',
# img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
# flip=True,
# transforms=[
# dict(type='SETR_Resize', keep_ratio=True,
# crop_size=crop_size, setr_multi_scale=True),
# dict(type='RandomFlip'),
# dict(type='Normalize', **img_norm_cfg),
# dict(type='ImageToTensor', keys=['img']),
# dict(type='Collect', keys=['img']),
# ])
#]
#data = dict(
# val=dict(pipeline=test_pipeline),
# test=dict(pipeline=test_pipeline),
# samples_per_gpu=2,
#)
# IterBasedRunnerAmp (mmcv_custom.apex_runner) additionally saves/restores
# apex amp state. This dict is merged with the base runner, so max_iters from
# schedule_160k.py is kept.
runner = dict(type='IterBasedRunnerAmp')
checkpoint_config = dict(by_epoch=False, interval=8000)
# do not use mmdet version fp16; mixed precision comes from apex through
# DistOptimizerHook(use_fp16=True) below.
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,  # optimizer step every iteration (no grad accumulation)
    grad_clip=None,
    coalesce=True,  # stored but not used by DistOptimizerHook
    bucket_size_mb=-1,  # stored but not used by DistOptimizerHook
    use_fp16=True,
)
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/cae/upernet/upernet_cae_large_24_512_slide_160k_ade20k_pt_decay095_4e-5_dp015.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
# UPerNet with a CAE-Large backbone (24 blocks, 1024-d) on ADE20K (150
# classes) at 512x512; 160k iterations, peak lr 4e-5, layer decay 0.95,
# drop path 0.15.
_base_ = [
    '../../_base_/models/upernet_cae.py', '../../_base_/datasets/ade20k.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)
model = dict(
    backbone=dict(
        type='CAE',
        img_size=512,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,
        use_rel_pos_bias=True,
        init_values=1e-5,
        drop_path_rate=0.15,
        # blocks whose outputs feed the four UPerHead levels
        # (presumably 0-based block indices -- confirm in backbone/cae.py)
        out_indices=[7, 11, 15, 23],
    ),
    decode_head=dict(
        in_channels=[1024, 1024, 1024, 1024],  # one entry per out_indices level
        num_classes=150,
        channels=1024,
    ),
    auxiliary_head=dict(
        in_channels=1024,
        num_classes=150
    ),
    # sliding-window inference: 512x512 windows, stride 341, so adjacent
    # windows overlap by 171 px
    test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))
)
# Earlier alternative (disabled): plain AdamW without layer decay, with no
# weight decay for position embeddings, relative position bias tables and
# norm layers.
# optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
# paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
# 'relative_position_bias_table': dict(decay_mult=0.),
# 'norm': dict(decay_mult=0.)}))
# AdamW with layer-wise lr decay built by LayerDecayOptimizerConstructor
# (exported by mmcv_custom); num_layers matches the backbone depth. The
# _delete_=True makes mmcv replace, rather than merge with, the SGD optimizer
# inherited from schedule_160k.py.
optimizer = dict(_delete_=True, type='AdamW', lr=4e-5, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.95))
# Linear warmup over the first 1500 iterations, then linear (power=1.0)
# decay to 0; replaces the inherited poly schedule entirely (_delete_).
lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data=dict(samples_per_gpu=2)
#img_norm_cfg = dict(
# mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
#crop_size = (512, 512)
## test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))
# NOTE(review): presumably forwarded to distributed data parallel by the
# training code because some parameters may get no gradient -- confirm.
find_unused_parameters = True
# Optional multi-scale + flip test pipeline (disabled). Re-enabling it also
# requires the img_norm_cfg definition commented out above.
#test_pipeline = [
# dict(type='LoadImageFromFile'),
# dict(
# type='MultiScaleFlipAug',
# img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
# flip=True,
# transforms=[
# dict(type='SETR_Resize', keep_ratio=True,
# crop_size=crop_size, setr_multi_scale=True),
# dict(type='RandomFlip'),
# dict(type='Normalize', **img_norm_cfg),
# dict(type='ImageToTensor', keys=['img']),
# dict(type='Collect', keys=['img']),
# ])
#]
#data = dict(
# val=dict(pipeline=test_pipeline),
# test=dict(pipeline=test_pipeline),
# samples_per_gpu=2,
#)
# IterBasedRunnerAmp (mmcv_custom.apex_runner) additionally saves/restores
# apex amp state. This dict is merged with the base runner, so max_iters from
# schedule_160k.py is kept.
runner = dict(type='IterBasedRunnerAmp')
checkpoint_config = dict(by_epoch=False, interval=32000)
# do not use mmdet version fp16; mixed precision comes from apex through
# DistOptimizerHook(use_fp16=True) below.
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,  # optimizer step every iteration (no grad accumulation)
    grad_clip=None,
    coalesce=True,  # stored but not used by DistOptimizerHook
    bucket_size_mb=-1,  # stored but not used by DistOptimizerHook
    use_fp16=True,
)
================================================
FILE: downstream_tasks/semantic_segmentation/configs_local/mae/upernet_mae_large_12_512_slide_160k_ade20k_pt_4e-4.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
# UPerNet with a Large MAE backbone (24 blocks, 1024-d) on ADE20K (150
# classes) at 512x512; 160k iterations, peak lr 4e-4 with layer-wise decay.
# Reuses the BEiT UPerNet base model config with the backbone type switched
# to 'MAE'.
# NOTE(review): the filename says "large_12" but depth=24 below -- confirm
# which one is intended.
_base_ = [
    '../../_base_/models/upernet_beit.py', '../../_base_/datasets/ade20k.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)
model = dict(
    backbone=dict(
        type='MAE',
        img_size=512,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,
        use_rel_pos_bias=True,
        # NOTE(review): init_values=1 differs from the 0.1 / 1e-5 used by the
        # BEiT/CAE configs in this directory -- confirm intended.
        init_values=1,
        drop_path_rate=0.2,
        # blocks whose outputs feed the four UPerHead levels
        # (presumably 0-based block indices -- confirm in backbone/mae.py)
        out_indices=[7, 11, 15, 23],
    ),
    decode_head=dict(
        in_channels=[1024, 1024, 1024, 1024],  # one entry per out_indices level
        num_classes=150,
        channels=1024,
    ),
    auxiliary_head=dict(
        in_channels=1024,
        num_classes=150
    ),
    # sliding-window inference: 512x512 windows, stride 341, so adjacent
    # windows overlap by 171 px
    test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))
)
# Earlier alternative (disabled): plain AdamW without layer decay, with no
# weight decay for position embeddings, relative position bias tables and
# norm layers.
# optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
# paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
# 'relative_position_bias_table': dict(decay_mult=0.),
# 'norm': dict(decay_mult=0.)}))
# AdamW with layer-wise lr decay built by LayerDecayOptimizerConstructor
# (exported by mmcv_custom); num_layers matches the backbone depth. The
# _delete_=True makes mmcv replace, rather than merge with, the SGD optimizer
# inherited from schedule_160k.py.
optimizer = dict(_delete_=True, type='AdamW', lr=4e-4, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.65))
# Linear warmup over the first 1500 iterations, then linear (power=1.0)
# decay to 0; replaces the inherited poly schedule entirely (_delete_).
lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data=dict(samples_per_gpu=2)
#img_norm_cfg = dict(
# mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
#crop_size = (512, 512)
## test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))
# NOTE(review): presumably forwarded to distributed data parallel by the
# training code because some parameters may get no gradient -- confirm.
find_unused_parameters = True
# Optional multi-scale + flip test pipeline (disabled). Re-enabling it also
# requires the img_norm_cfg definition commented out above.
#test_pipeline = [
# dict(type='LoadImageFromFile'),
# dict(
# type='MultiScaleFlipAug',
# img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
# flip=True,
# transforms=[
# dict(type='SETR_Resize', keep_ratio=True,
# crop_size=crop_size, setr_multi_scale=True),
# dict(type='RandomFlip'),
# dict(type='Normalize', **img_norm_cfg),
# dict(type='ImageToTensor', keys=['img']),
# dict(type='Collect', keys=['img']),
# ])
#]
#data = dict(
# val=dict(pipeline=test_pipeline),
# test=dict(pipeline=test_pipeline),
# samples_per_gpu=2,
#)
# IterBasedRunnerAmp (mmcv_custom.apex_runner) additionally saves/restores
# apex amp state. This dict is merged with the base runner, so max_iters from
# schedule_160k.py is kept.
runner = dict(type='IterBasedRunnerAmp')
checkpoint_config = dict(by_epoch=False, interval=32000)
# do not use mmdet version fp16; mixed precision comes from apex through
# DistOptimizerHook(use_fp16=True) below.
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,  # optimizer step every iteration (no grad accumulation)
    grad_clip=None,
    coalesce=True,  # stored but not used by DistOptimizerHook
    bucket_size_mb=-1,  # stored but not used by DistOptimizerHook
    use_fp16=True,
)
================================================
FILE: downstream_tasks/semantic_segmentation/mmcv_custom/__init__.py
================================================
# -*- coding: utf-8 -*-
from .checkpoint import load_checkpoint
from .layer_decay_optimizer_constructor import LayerDecayOptimizerConstructor
from .resize_transform import SETR_Resize
from .apex_runner.optimizer import DistOptimizerHook
from .train_api import train_segmentor
# Names re-exported by ``mmcv_custom``. Importing this package also executes
# the submodules imported above, e.g. the ``@HOOKS.register_module()``
# decorator on DistOptimizerHook runs at import time.
__all__ = ['load_checkpoint', 'LayerDecayOptimizerConstructor', 'SETR_Resize', 'DistOptimizerHook', 'train_segmentor']
================================================
FILE: downstream_tasks/semantic_segmentation/mmcv_custom/apex_runner/__init__.py
================================================
# Copyright (c) Open-MMLab. All rights reserved.
from .checkpoint import save_checkpoint
from .apex_iter_based_runner import IterBasedRunnerAmp
# Exported names. Importing this package runs the ``@RUNNERS.register_module()``
# decorator in apex_iter_based_runner, which registers IterBasedRunnerAmp.
__all__ = [
    'save_checkpoint', 'IterBasedRunnerAmp',
]
================================================
FILE: downstream_tasks/semantic_segmentation/mmcv_custom/apex_runner/apex_iter_based_runner.py
================================================
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import platform
import shutil
import torch
from torch.optim import Optimizer
import mmcv
from mmcv.runner import RUNNERS, IterBasedRunner
from .checkpoint import save_checkpoint
try:
import apex
except:
print('apex is not installed')
@RUNNERS.register_module()
class IterBasedRunnerAmp(IterBasedRunner):
    """Iteration-based Runner with AMP support.

    This runner trains models iteration by iteration, like
    ``mmcv.runner.IterBasedRunner``. It differs in two ways:

    * checkpoints are written with this package's ``save_checkpoint``,
      which can also store ``apex.amp`` state;
    * :meth:`resume` restores that ``amp`` state when present.
    """

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=False):
        """Save checkpoint to file.

        Args:
            out_dir (str): Directory to save checkpoint files.
            filename_tmpl (str, optional): Checkpoint file template, formatted
                with ``self.iter + 1``. Defaults to 'iter_{}.pth'.
            meta (dict, optional): Metadata to be saved in checkpoint. The
                ``iter``/``epoch`` counters (1-based) and ``self.meta`` are
                merged into it. Defaults to None.
            save_optimizer (bool, optional): Whether save optimizer.
                Defaults to True.
            create_symlink (bool, optional): Whether to make
                ``out_dir/latest.pth`` point at the new checkpoint. On Windows
                the file is copied instead. Off by default because
                ``os.symlink`` is not supported in some environments.
                Defaults to False.

        Raises:
            TypeError: If ``meta`` is neither a dict nor None.
        """
        if meta is None:
            meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)
        elif isinstance(meta, dict):
            meta.update(iter=self.iter + 1, epoch=self.epoch + 1)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.iter + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        # module-level save_checkpoint from .checkpoint (adds amp state)
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # create_symlink used to be accepted but silently ignored; honour it
        # when explicitly requested. The default stays False.
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume counters, optimizer state and apex amp state.

        Args:
            checkpoint (str): Checkpoint passed to ``self.load_checkpoint``.
            resume_optimizer (bool): Whether to restore optimizer state when
                the checkpoint contains it. Defaults to True.
            map_location (str | callable): ``'default'`` maps tensors onto the
                current CUDA device when CUDA is available (otherwise uses
                ``load_checkpoint``'s own default); any other value is passed
                through to ``load_checkpoint``.

        Raises:
            TypeError: If ``self.optimizer`` is neither an ``Optimizer`` nor a
                dict of optimizers while optimizer state is being restored.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        self._inner_iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        # amp state is optional; it is only written when apex was usable.
        if 'amp' in checkpoint:
            apex.amp.load_state_dict(checkpoint['amp'])
            self.logger.info('load amp state dict')

        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
================================================
FILE: downstream_tasks/semantic_segmentation/mmcv_custom/apex_runner/checkpoint.py
================================================
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
from tempfile import TemporaryDirectory
import torch
from torch.optim import Optimizer
import mmcv
from mmcv.parallel import is_module_wrapper
from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
try:
import apex
except:
print('apex is not installed')
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint has up to 4 fields: ``meta``, ``state_dict``,
    ``optimizer`` and ``amp``. By default ``meta`` will contain version
    and time info. ``amp`` is only written when ``apex`` imported
    successfully and ``apex.amp`` is initialized, so fp32 runs (or machines
    without apex) can still checkpoint. Readers such as
    ``IterBasedRunnerAmp.resume`` already treat ``amp`` as optional.

    Args:
        model (Module): Module whose params are to be saved. Module wrappers
            are unwrapped first.
        filename (str): Checkpoint filename. A ``pavi://`` prefix uploads
            to pavi modelcloud instead of writing a local file.
        optimizer (:obj:`Optimizer` | dict, optional): Optimizer (or dict of
            optimizers) to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.

    Raises:
        TypeError: If ``meta`` is neither a dict nor None.
        ImportError: If a ``pavi://`` filename is used without pavi.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    # save amp state dict in the checkpoint when available. If the module-level
    # `import apex` failed, `apex` is undefined (NameError). If
    # amp.initialize() never ran, amp has no loss scalers yet
    # (AttributeError). Both cases previously crashed every save.
    try:
        checkpoint['amp'] = apex.amp.state_dict()
    except (NameError, AttributeError):
        pass

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
================================================
FILE: downstream_tasks/semantic_segmentation/mmcv_custom/apex_runner/optimizer.py
================================================
from mmcv.runner import OptimizerHook, HOOKS
try:
import apex
except:
print('apex is not installed')
@HOOKS.register_module()
class DistOptimizerHook(OptimizerHook):
    """Optimizer hook with gradient accumulation and optional apex AMP.

    Args:
        update_interval (int): Number of iterations whose gradients are
            accumulated before each ``optimizer.step()``. Each iteration's
            loss is divided by this value. Defaults to 1.
        grad_clip (dict, optional): Keyword arguments for
            ``OptimizerHook.clip_grads``. No clipping when None.
        coalesce (bool): Stored for config compatibility; not used here.
        bucket_size_mb (int): Stored for config compatibility; not used here.
        use_fp16 (bool): Backpropagate through ``apex.amp.scale_loss`` (apex
            must be installed and initialized). Defaults to False.
    """

    def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False):
        # Let OptimizerHook initialise its own state (it may define more than
        # grad_clip, depending on the mmcv version).
        super().__init__(grad_clip)
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb
        self.update_interval = update_interval
        self.use_fp16 = use_fp16

    def before_run(self, runner):
        runner.optimizer.zero_grad()

    def after_train_iter(self, runner):
        # Scale down so that accumulated gradients average over the interval.
        runner.outputs['loss'] /= self.update_interval
        if self.use_fp16:
            with apex.amp.scale_loss(runner.outputs['loss'], runner.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            runner.outputs['loss'].backward()
        if self.every_n_iters(runner, self.update_interval):
            if self.grad_clip is not None:
                # With apex the optimizer steps on amp's master params (fp32
                # copies under O2/O3, the model params themselves under O1).
                # Clip those so clipping affects the update actually applied.
                if self.use_fp16:
                    params = apex.amp.master_params(runner.optimizer)
                else:
                    params = runner.model.parameters()
                self.clip_grads(params)
            runner.optimizer.step()
            runner.optimizer.zero_grad()
================================================
FILE: downstream_tasks/semantic_segmentation/mmcv_custom/checkpoint.py
================================================
# Copyright (c) Open-MMLab. All rights reserved.
import io
import os
import os.path as osp
import pkgutil
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo
from torch.nn import functional as F
import mmcv
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.utils import mkdir_or_exist
from mmcv.runner import get_dist_info
from scipy import interpolate
import numpy as np
import math
# Environment variables / fallback used to locate the mmcv cache directory.
ENV_MMCV_HOME = 'MMCV_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'


def _get_mmcv_home():
    """Return the mmcv cache directory, creating it if it does not exist.

    ``$MMCV_HOME`` takes precedence. Otherwise the directory is
    ``<$XDG_CACHE_HOME or ~/.cache>/mmcv``. ``~`` is expanded in the result.
    """
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(cache_root, 'mmcv')
    home_dir = os.path.expanduser(os.getenv(ENV_MMCV_HOME, fallback_home))
    mkdir_or_exist(home_dir)
    return home_dir
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict. Module
            wrappers (at any nesting level) are unwrapped before loading.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.

    Raises:
        RuntimeError: If ``strict`` is True and there are unexpected keys,
            missing keys (BN ``num_batches_tracked`` excluded) or errors
            reported by ``_load_from_state_dict``. Raised on every rank, so
            non-zero ranks do not silently continue with a partial load.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    if not err_msg:
        return

    err_msg.insert(
        0, 'The model and loaded state dict do not match exactly\n')
    err_msg = '\n'.join(err_msg)
    if strict:
        # Previously only rank 0 raised; other ranks kept going with a
        # partially loaded model and could stall at the next collective.
        raise RuntimeError(err_msg)
    rank, _ = get_dist_info()
    if rank == 0:
        # Non-strict: report once, from rank 0 only.
        if logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)
def load_url_dist(url, model_dir=None, map_location="cpu"):
    """Load a checkpoint from ``url``, downloading it on local rank 0 first.

    The local rank comes from ``$LOCAL_RANK`` when set, else from the global
    rank reported by ``get_dist_info``. Local rank 0 fetches the checkpoint
    before a barrier (only when ``world_size > 1``). The remaining ranks call
    ``model_zoo.load_url`` afterwards, presumably hitting the now-populated
    cache.
    """
    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))

    def _fetch():
        return model_zoo.load_url(
            url, model_dir=model_dir, map_location=map_location)

    if local_rank == 0:
        checkpoint = _fetch()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _fetch()
    return checkpoint
def load_pavimodel_dist(model_path, map_location=None):
    """Load a checkpoint from pavi modelcloud, local rank 0 first.

    Every rank downloads ``model_path`` into its own temporary directory and
    ``torch.load``s it from there. Local rank 0 does this before the barrier
    (only when ``world_size > 1``), the other local ranks after it.

    Raises:
        ImportError: If ``pavi`` is not installed.
    """
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')
    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))

    def _download_and_load():
        remote_model = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            local_file = osp.join(tmp_dir, remote_model.name)
            remote_model.download(local_file)
            # load while the temporary directory still exists
            return torch.load(local_file, map_location=map_location)

    if local_rank == 0:
        checkpoint = _download_and_load()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _download_and_load()
    return checkpoint
def load_fileclient_dist(filename, backend, map_location):
    """Load a checkpoint through an mmcv ``FileClient``, local rank 0 first.

    Only the ``'ceph'`` backend is accepted. The raw bytes are wrapped in an
    in-memory buffer and passed to ``torch.load``.

    Raises:
        ValueError: If ``backend`` is not ``'ceph'``.
    """
    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))
    if backend not in ('ceph', ):
        raise ValueError(f'Load from Backend {backend} is not supported.')

    def _read():
        client = FileClient(backend=backend)
        payload = io.BytesIO(client.get(filename))
        return torch.load(payload, map_location=map_location)

    if local_rank == 0:
        checkpoint = _read()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _read()
    return checkpoint
def get_torchvision_models():
    """Merge the ``model_urls`` tables of all torchvision model modules.

    Sub-packages are skipped. Modules without a ``model_urls`` attribute
    contribute nothing.
    NOTE(review): newer torchvision releases dropped ``model_urls``, in which
    case this returns an empty dict.
    """
    urls = {}
    for module_info in pkgutil.walk_packages(torchvision.models.__path__):
        _, module_name, is_pkg = module_info
        if is_pkg:
            continue
        zoo_module = import_module(f'torchvision.models.{module_name}')
        urls.update(getattr(zoo_module, 'model_urls', {}))
    return urls
def get_external_models():
    """Return open-mmlab model URLs, merged with user overrides.

    Starts from mmcv's bundled ``model_zoo/open_mmlab.json``. Entries in
    ``open_mmlab.json`` under the mmcv home directory, if that file exists,
    take precedence.
    """
    mmcv_home = _get_mmcv_home()
    bundled = load_file(
        osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))
    assert isinstance(bundled, dict)
    user_json = osp.join(mmcv_home, 'open_mmlab.json')
    if not osp.exists(user_json):
        return bundled
    overrides = load_file(user_json)
    assert isinstance(overrides, dict)
    bundled.update(overrides)
    return bundled
def get_mmcls_models():
    """Return the mmcls model-name -> URL table bundled with mmcv."""
    return load_file(osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json'))
def get_deprecated_model_names():
    """Return mmcv's table mapping deprecated open-mmlab names to new ones."""
    json_path = osp.join(mmcv.__path__[0], 'model_zoo/deprecated.json')
    renames = load_file(json_path)
    assert isinstance(renames, dict)
    return renames
def _process_mmcls_checkpoint(checkpoint):
state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('backbone.'):
new_state_dict[k[9:]] = v
new_checkpoint = dict(state_dict=new_state_dict)
return new_checkpoint
def _load_checkpoint(filename, map_location=None):
"""Load checkpoint from somewhere (modelzoo, file, url).
Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str | None): Same as :func:`torch.load`. Default: None.
Returns:
dict | OrderedDict: The loaded checkpoint. It can be either an
OrderedDict storing model weights or a dict containing other
information, which depends on the checkpoint.
"""
if filename.startswith('modelzoo://'):
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead')
model_urls = get_torchvision_models()
model_name = filename[11:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('torchvision://'):
model_urls = get_torchvision_models()
model_name = filename[14:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('open-mmlab://'):
model_urls = get_external_models()
model_name = filename[13:]
deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls:
warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
f'of open-mmlab://{deprecated_urls[model_name]}')
model_name = deprecated_urls[model_name]
model_url = model_urls[model_name]
# check if is url
if model_url.startswith(('http://', 'https://')):
checkpoint = load_url_dist(model_url)
else:
filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
checkpoint = torch.load(filename, map_location=map_location)
elif filename.startswith('mmcls://'):
model_urls = get_mmcls_models()
model_name = filename[8:]
checkpoint = load_url_dist(model_urls[model_name])
checkpoint = _process_mmcls_checkpoint(checkpoint)
elif filename.startswith(('http://', 'https://')):
checkpoint = load_url_dist(filename)
elif filename.startswith('pavi://'):
model_path = filename[7:]
checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
elif filename.startswith('s3://'):
checkpoint = load_fileclient_dist(
filename, backend='ceph', map_location=map_location)
else:
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
checkpoint = torch.load(filename, map_location=map_location)
return checkpoint
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    """Build a per-iteration schedule: linear warmup then cosine decay.

    Args:
        base_value: Value at the end of warmup and start of the decay.
        final_value: Value the cosine decay approaches at the last iteration.
        epochs: Number of epochs covered by the schedule.
        niter_per_ep: Iterations per epoch.
        warmup_epochs: Warmup length in epochs. Ignored when
            ``warmup_steps > 0``.
        start_warmup_value: Value at the first warmup iteration.
        warmup_steps: Warmup length in iterations. Overrides
            ``warmup_epochs`` when positive.

    Returns:
        np.ndarray: ``epochs * niter_per_ep`` values, one per iteration.
    """
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps
    print("Set warmup steps = %d" % warmup_iters)

    # Gate on the resolved iteration count, not on ``warmup_epochs``.
    # Otherwise a positive ``warmup_steps`` with ``warmup_epochs == 0``
    # produced no warmup values, and the length assert below failed.
    if warmup_iters > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_schedule = np.array([])

    decay_iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = np.array(
        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(decay_iters))))
         for i in decay_iters])
    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI and adapt its keys to ``model``.

    After loading, the state dict is rewritten as follows:

    * A leading ``module.`` prefix is stripped.
    * If the smallest key starts with ``encoder`` (MoBY), only ``encoder.*``
      weights are kept, with the prefix removed.
    * If the last key starts with ``encoder_to_decoder`` or ``decoder``,
      ``encoder.*`` weights are un-prefixed, ``decoder.*`` weights are
      dropped and ``norm.*`` is renamed to ``fc_norm.*``.
    * MoCo ``base_encoder.`` and iBOT ``backbone.``/``module.backbone.``
      prefixes are stripped, and ``momentum_encoder`` weights are dropped.
    * A shared ``rel_pos_bias`` table is copied into every block.
    * ``relative_position_index`` entries are dropped.
    * Relative and absolute position embeddings are resized to the model.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.
            When None, warnings are printed instead.

    Returns:
        dict or OrderedDict: The loaded checkpoint. Its state dict may have
        been modified in place by the key rewriting above.

    Raises:
        RuntimeError: If the loaded object is not a dict.
        NotImplementedError: If the checkpoint has relative position bias
            tables and the model's patch grid is not square.
    """
    def _warn(msg):
        # ``logger`` defaults to None; fall back to print like load_state_dict
        if logger is not None:
            logger.warning(msg)
        else:
            print(msg)

    def _resize_rel_pos_bias(rel_pos_bias, num_extra_tokens, src_size,
                             dst_size, rank):
        """Resample a (src_size**2 + extra, heads) table to dst_size**2 + extra rows."""
        # Split by grid length rather than ``-num_extra_tokens``: with zero
        # extra tokens, ``t[-0:]`` would select the whole table.
        num_grid = rel_pos_bias.size(0) - num_extra_tokens
        extra_tokens = rel_pos_bias[num_grid:, :]
        rel_pos_bias = rel_pos_bias[:num_grid, :]

        def geometric_progression(a, r, n):
            return a * (1.0 - r ** n) / (1.0 - r)

        # bisection for the ratio q whose progression spans dst_size // 2
        left, right = 1.01, 1.5
        while right - left > 1e-6:
            q = (left + right) / 2.0
            gp = geometric_progression(1, q, src_size // 2)
            if gp > dst_size // 2:
                right = q
            else:
                left = q

        # source sample coordinates, geometrically spaced away from 0
        dis = []
        cur = 1
        for i in range(src_size // 2):
            dis.append(cur)
            cur += q ** (i + 1)
        r_ids = [-d for d in reversed(dis)]
        x = r_ids + [0] + dis
        y = r_ids + [0] + dis
        t = dst_size // 2.0
        dx = np.arange(-t, t + 0.1, 1.0)
        dy = np.arange(-t, t + 0.1, 1.0)
        if rank == 0:
            print("x = {}".format(x))
            print("dx = {}".format(dx))

        all_rel_pos_bias = []
        for i in range(rel_pos_bias.size(1)):
            z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
            # Cubic spline on a regular grid, equivalent to the former
            # ``interp2d(x, y, z, kind='cubic')(dx, dy)``. interp2d was
            # removed in SciPy 1.14; this follows SciPy's transition guide.
            spline = interpolate.RectBivariateSpline(x, y, z.T, kx=3, ky=3)
            resized = np.ascontiguousarray(spline(dx, dy).T)
            all_rel_pos_bias.append(
                torch.Tensor(resized).contiguous().view(-1, 1).to(rel_pos_bias.device))
        rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
        return torch.cat((rel_pos_bias, extra_tokens), dim=0)

    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'module' in checkpoint:
        state_dict = checkpoint['module']
    else:
        state_dict = checkpoint
    # strip prefix of state_dict (guarded: an empty dict has no first key)
    if state_dict and next(iter(state_dict)).startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # for MoBY, load model of online branch
    if state_dict and min(state_dict).startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v
                      for k, v in state_dict.items() if k.startswith('encoder.')}

    all_keys = list(state_dict.keys())
    print("origin keys:", len(all_keys), all_keys)
    if all_keys and all_keys[-1].startswith(('encoder_to_decoder', 'decoder')):
        # NOTE: remove all decoder keys
        encoder_keys = [key for key in all_keys if key.startswith('encoder.')]
        print("all keys:", encoder_keys)
        for key in encoder_keys:
            state_dict[key.replace('encoder.', '')] = state_dict.pop(key)
        for key in [k for k in state_dict if k.startswith('decoder.')]:
            state_dict.pop(key)
        # NOTE: replace norm with fc_norm
        for key in [k for k in state_dict if k.startswith('norm.')]:
            state_dict[key.replace('norm.', 'fc_norm.')] = state_dict.pop(key)
    print("new keys:", len(state_dict), state_dict.keys())

    # reshape absolute position embedding for Swin
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H * W:
            _warn("Error in loading absolute_pos_embed, pass")
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)

    rank, _ = get_dist_info()
    if "rel_pos_bias.relative_position_bias_table" in state_dict:
        if rank == 0:
            print("Expand the shared relative position embedding to each layers. ")
        num_layers = model.get_num_layers()
        rel_pos_bias = state_dict["rel_pos_bias.relative_position_bias_table"]
        for i in range(num_layers):
            state_dict["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone()
        state_dict.pop("rel_pos_bias.relative_position_bias_table")

    # Each pass below snapshots the *current* keys. Reusing a single stale
    # snapshot raised KeyError once an earlier pass had renamed keys.
    # for moco
    for key in list(state_dict.keys()):
        if 'base_encoder.' in key:
            state_dict[key.replace('base_encoder.', '')] = state_dict.pop(key)
        elif 'momentum_encoder' in key:
            state_dict.pop(key)
    # for ibot
    for key in list(state_dict.keys()):
        if 'module.backbone.' in key:
            state_dict[key.replace('module.backbone.', '')] = state_dict.pop(key)
        elif 'backbone.' in key:
            state_dict[key.replace('backbone.', '')] = state_dict.pop(key)

    # computed once; the model's own parameters do not change before loading
    model_state = model.state_dict()
    for key in list(state_dict.keys()):
        if "relative_position_index" in key:
            state_dict.pop(key)
        if "relative_position_bias_table" not in key:
            continue
        if key not in model_state:
            # no such table in the model; left for load_state_dict to flag
            continue
        rel_pos_bias = state_dict[key]
        src_num_pos, _ = rel_pos_bias.size()
        dst_num_pos, _ = model_state[key].size()
        dst_patch_shape = model.patch_embed.patch_shape
        if dst_patch_shape[0] != dst_patch_shape[1]:
            raise NotImplementedError()
        num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
        src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
        dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
        if src_size != dst_size:
            if rank == 0:
                print("Position interpolate for %s from %dx%d to %dx%d" % (
                    key, src_size, src_size, dst_size, dst_size))
            state_dict[key] = _resize_rel_pos_bias(
                rel_pos_bias, num_extra_tokens, src_size, dst_size, rank)

    if 'pos_embed' in state_dict:
        pos_embed_checkpoint = state_dict['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            if rank == 0:
                print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            state_dict['pos_embed'] = new_pos_embed

    # interpolate position bias table if needed (Swin-style square tables)
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for table_key in relative_position_bias_table_keys:
        if table_key not in model_state:
            continue
        table_pretrained = state_dict[table_key]
        table_current = model_state[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            _warn(f"Error in loading {table_key}, pass")
        elif L1 != L2:
            S1 = int(L1 ** 0.5)
            S2 = int(L2 ** 0.5)
            table_pretrained_resized = F.interpolate(
                table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                size=(S2, S2), mode='bicubic')
            state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights, possibly on GPU.

    Returns:
        OrderedDict: Model weights on CPU, in the same key order. Any
        ``_metadata`` attribute of the input is not carried over.
    """
    return OrderedDict((key, val.cpu()) for key, val in state_dict.items())
def _save_to_state_dict(module, destination, prefix, keep_vars):
"""Saves module state to `destination` dictionary.
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
Args:
module (nn.Module): The module to generate state_dict.
destination (dict): A dict where state will be stored.
prefix (str): The prefix for parameters and buffers used in this
module.
"""
for name, param in module._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.detach()
for name, buf in module._buffers.items():
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
if buf is not None:
destination[prefix + name] = buf if keep_vars else buf.detach()
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and buffers (via ``_save_to_state_dict``) are included.
    Keys are corresponding parameter and buffer names. This mirrors
    :meth:`torch.nn.Module.state_dict`, but module wrappers detected by
    ``is_module_wrapper`` are unwrapped at every level of the recursion,
    e.g. ``nn.Module(nn.Module(DDP))``.

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module. Must carry a ``_metadata`` attribute when given.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    if is_module_wrapper(module):
        module = module.module

    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    local_metadata = dict(version=module._version)
    destination._metadata[prefix[:-1]] = local_metadata

    _save_to_state_dict(module, destination, prefix, keep_vars)
    for child_name, child in module._modules.items():
        if child is None:
            continue
        get_state_dict(
            child, destination, f'{prefix}{child_name}.', keep_vars=keep_vars)

    # a hook may return a replacement destination dict
    for hook in module._state_dict_hooks.values():
        replacement = hook(module, destination, prefix, local_metadata)
        if replacement is not None:
            destination = replacement
    return destination
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.
    ``optimizer`` is only present when an optimizer (or dict of them) is given.

    Args:
        model (Module): Module whose params are to be saved. Wrappers
            detected by ``is_module_wrapper`` are unwrapped first.
        filename (str): Checkpoint filename, or a ``pavi://`` URI to upload
            to pavi modelcloud.
        optimizer (:obj:`Optimizer` | dict, optional): Optimizer to be
            saved, or a dict of named optimizers.
        meta (dict, optional): Metadata to be saved in checkpoint. It is
            shallow-copied, so the caller's dict is not modified.

    Raises:
        TypeError: If ``meta`` is neither None nor a dict.
        ImportError: If ``filename`` uses ``pavi://`` and pavi is missing.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    else:
        # copy so the updates below do not leak into the caller's dict
        meta = meta.copy()
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {
            name: optim.state_dict() for name, optim in optimizer.items()}

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            # named pavi_model so it does not shadow the nn.Module argument
            pavi_model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            pavi_model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            pavi_model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
================================================
FILE: downstream_tasks/semantic_segmentation/mmcv_custom/checkpoint_beit.py
================================================
# Copyright (c) Open-MMLab. All rights reserved.
import io
import os
import os.path as osp
import pkgutil
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo
from torch.nn import functional as F
import mmcv
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.utils import mkdir_or_exist
from mmcv.runner import get_dist_info
from scipy import interpolate
import numpy as np
import math
# Environment variable that, when set, is used as the mmcv home directory.
ENV_MMCV_HOME = 'MMCV_HOME'
# Cache-root variable consulted when MMCV_HOME is not set.
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
# Cache root used when neither variable is set (``~`` expanded on use).
DEFAULT_CACHE_DIR = '~/.cache'
def _get_mmcv_home():
    """Return the mmcv cache directory, creating it if needed.

    Resolution order: ``$MMCV_HOME``, then ``$XDG_CACHE_HOME/mmcv``, then
    ``~/.cache/mmcv``. A leading ``~`` is expanded.
    """
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(cache_root, 'mmcv')
    mmcv_home = os.path.expanduser(os.getenv(ENV_MMCV_HOME, fallback_home))
    mkdir_or_exist(mmcv_home)
    return mmcv_home
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False. Mismatches are
    only reported (or raised) on rank 0. Missing ``num_batches_tracked``
    entries are ignored.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.

    Raises:
        RuntimeError: On rank 0, if ``strict`` and keys do not match.
    """
    unexpected_keys = []
    all_missing_keys = []
    error_msgs = []

    # shallow copy so _load_from_state_dict can consume it, keeping metadata
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # Pre-order walk over the module tree with an explicit stack. This is
    # the same visit order as recursing into children in declaration order,
    # and it needs no self-referencing closure.
    pending = [(module, '')]
    while pending:
        current, prefix = pending.pop()
        # unwrap parallel wrappers, e.g. nn.Module(nn.Module(DDP))
        if is_module_wrapper(current):
            current = current.module
        if metadata is None:
            local_metadata = {}
        else:
            local_metadata = metadata.get(prefix[:-1], {})
        current._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                      all_missing_keys, unexpected_keys,
                                      error_msgs)
        children = [(child, prefix + name + '.')
                    for name, child in current._modules.items()
                    if child is not None]
        pending.extend(reversed(children))

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [k for k in all_missing_keys if 'num_batches_tracked' not in k]
    if unexpected_keys:
        error_msgs.append('unexpected key in source '
                          f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        error_msgs.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if not error_msgs or rank != 0:
        return
    error_msgs.insert(
        0, 'The model and loaded state dict do not match exactly\n')
    report = '\n'.join(error_msgs)
    if strict:
        raise RuntimeError(report)
    if logger is not None:
        logger.warning(report)
    else:
        print(report)
def load_url_dist(url, model_dir=None, map_location="cpu"):
    """Load a checkpoint from ``url``, letting local rank 0 fetch it first.

    Local rank 0 calls ``model_zoo.load_url`` before the barrier; all other
    local ranks wait on the barrier (only when ``world_size > 1``) and call
    it afterwards, so only rank 0 downloads in the first phase.
    """
    global_rank, world_size = get_dist_info()
    # LOCAL_RANK, when present in the environment, overrides the dist rank
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))

    def _fetch():
        return model_zoo.load_url(
            url, model_dir=model_dir, map_location=map_location)

    if local_rank == 0:
        checkpoint = _fetch()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _fetch()
    return checkpoint
def load_pavimodel_dist(model_path, map_location=None):
    """Load a checkpoint from pavi modelcloud, local rank 0 first.

    Every rank downloads ``model_path`` into its own temporary directory and
    ``torch.load``s it from there. Local rank 0 does this before the barrier
    (only when ``world_size > 1``), the other local ranks after it.

    Raises:
        ImportError: If ``pavi`` is not installed.
    """
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')
    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))

    def _download_and_load():
        remote_model = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            local_file = osp.join(tmp_dir, remote_model.name)
            remote_model.download(local_file)
            # load while the temporary directory still exists
            return torch.load(local_file, map_location=map_location)

    if local_rank == 0:
        checkpoint = _download_and_load()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _download_and_load()
    return checkpoint
def load_fileclient_dist(filename, backend, map_location):
    """Load a checkpoint through an mmcv ``FileClient``, local rank 0 first.

    Only the ``'ceph'`` backend is accepted. The raw bytes are wrapped in an
    in-memory buffer and passed to ``torch.load``.

    Raises:
        ValueError: If ``backend`` is not ``'ceph'``.
    """
    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))
    if backend not in ('ceph', ):
        raise ValueError(f'Load from Backend {backend} is not supported.')

    def _read():
        client = FileClient(backend=backend)
        payload = io.BytesIO(client.get(filename))
        return torch.load(payload, map_location=map_location)

    if local_rank == 0:
        checkpoint = _read()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _read()
    return checkpoint
def get_torchvision_models():
    """Merge the ``model_urls`` tables of all torchvision model modules.

    Sub-packages are skipped. Modules without a ``model_urls`` attribute
    contribute nothing.
    NOTE(review): newer torchvision releases dropped ``model_urls``, in which
    case this returns an empty dict.
    """
    urls = {}
    for module_info in pkgutil.walk_packages(torchvision.models.__path__):
        _, module_name, is_pkg = module_info
        if is_pkg:
            continue
        zoo_module = import_module(f'torchvision.models.{module_name}')
        urls.update(getattr(zoo_module, 'model_urls', {}))
    return urls
def get_external_models():
    """Return open-mmlab model URLs, merged with user overrides.

    Starts from mmcv's bundled ``model_zoo/open_mmlab.json``. Entries in
    ``open_mmlab.json`` under the mmcv home directory, if that file exists,
    take precedence.
    """
    mmcv_home = _get_mmcv_home()
    bundled = load_file(
        osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))
    assert isinstance(bundled, dict)
    user_json = osp.join(mmcv_home, 'open_mmlab.json')
    if not osp.exists(user_json):
        return bundled
    overrides = load_file(user_json)
    assert isinstance(overrides, dict)
    bundled.update(overrides)
    return bundled
def get_mmcls_models():
    """Return the mmcls model-name -> URL table bundled with mmcv."""
    return load_file(osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json'))
def get_deprecated_model_names():
    """Return mmcv's table mapping deprecated open-mmlab names to new ones."""
    json_path = osp.join(mmcv.__path__[0], 'model_zoo/deprecated.json')
    renames = load_file(json_path)
    assert isinstance(renames, dict)
    return renames
def _process_mmcls_checkpoint(checkpoint):
state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('backbone.'):
new_state_dict[k[9:]] = v
new_checkpoint = dict(state_dict=new_state_dict)
return new_checkpoint
def _load_checkpoint(filename, map_location=None):
"""Load checkpoint from somewhere (modelzoo, file, url).
Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str | None): Same as :func:`torch.load`. Default: None.
Returns:
dict | OrderedDict: The loaded checkpoint. It can be either an
OrderedDict storing model weights or a dict containing other
information, which depends on the checkpoint.
"""
if filename.startswith('modelzoo://'):
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead')
model_urls = get_torchvision_models()
model_name = filename[11:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('torchvision://'):
model_urls = get_torchvision_models()
model_name = filename[14:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('open-mmlab://'):
model_urls = get_external_models()
model_name = filename[13:]
deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls:
warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
f'of open-mmlab://{deprecated_urls[model_name]}')
model_name = deprecated_urls[model_name]
model_url = model_urls[model_name]
# check if is url
if model_url.startswith(('http://', 'https://')):
checkpoint = load_url_dist(model_url)
else:
filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
checkpoint = torch.load(filename, map_location=map_location)
elif filename.startswith('mmcls://'):
model_urls = get_mmcls_models()
model_name = filename[8:]
checkpoint = load_url_dist(model_urls[model_name])
checkpoint = _process_mmcls_checkpoint(checkpoint)
elif filename.startswith(('http://', 'https://')):
checkpoint = load_url_dist(filename)
elif filename.startswith('pavi://'):
model_path = filename[7:]
checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
elif filename.startswith('s3://'):
checkpoint = load_fileclient_dist(
filename, backend='ceph', map_location=map_location)
else:
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
checkpoint = torch.load(filename, map_location=map_location)
return checkpoint
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    """Build a per-iteration schedule: linear warmup then cosine decay.

    Args:
        base_value: Value at the end of warmup and start of the decay.
        final_value: Value the cosine decay approaches at the last iteration.
        epochs: Number of epochs covered by the schedule.
        niter_per_ep: Iterations per epoch.
        warmup_epochs: Warmup length in epochs. Ignored when
            ``warmup_steps > 0``.
        start_warmup_value: Value at the first warmup iteration.
        warmup_steps: Warmup length in iterations. Overrides
            ``warmup_epochs`` when positive.

    Returns:
        np.ndarray: ``epochs * niter_per_ep`` values, one per iteration.
    """
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps
    print("Set warmup steps = %d" % warmup_iters)

    # Gate on the resolved iteration count, not on ``warmup_epochs``.
    # Otherwise a positive ``warmup_steps`` with ``warmup_epochs == 0``
    # produced no warmup values, and the length assert below failed.
    if warmup_iters > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_schedule = np.array([])

    decay_iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = np.array(
        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(decay_iters))))
         for i in decay_iters])
    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
def load_checkpoint(model,
filename,
map_location='cpu',
strict=False,
logger=None):
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = _load_checkpoint(filename, map_location)
# OrderedDict is a subclass of dict
if not isinstance(checkpoint, dict):
raise RuntimeError(
f'No state_dict found in checkpoint file {filename}')
# get state_dict from checkpoint
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
elif 'module' in checkpoint:
state_dict = checkpoint['module']
else:
state_dict = checkpoint
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
# for MoBY, load model of online branch
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
# reshape absolute position embedding for Swin
if state_dict.get('absolute_pos_embed') is not None:
absolute_pos_embed = state_dict['absolute_pos_embed']
N1, L, C1 = absolute_pos_embed.size()
N2, C2, H, W = model.absolute_pos_embed.size()
if N1 != N2 or C1 != C2 or L != H*W:
logger.warning("Error in loading absolute_pos_embed, pass")
else:
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
rank, _ = get_dist_info()
if "rel_pos_bias.relative_position_bias_table" in state_dict:
if rank == 0:
print("Expand the shared relative position embedding to each layers. ")
num_layers = model.get_num_layers()
rel_pos_bias = state_dict["rel_pos_bias.relative_position_bias_table"]
for i in range(num_layers):
state_dict["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone()
state_dict.pop("rel_pos_bias.relative_position_bias_table")
all_keys = list(state_dict.keys())
all_keys = sorted(all_keys)
print("origin keys:", len(all_keys), all_keys)
if all_keys[-2].startswith('encoder_to_decoder'):
# NOTE: remove all decoder keys
all_keys = [key for key in all_keys if key.startswith('encoder.')]
print("all keys:", all_keys)
for key in all_keys:
new_key = key.replace('encoder.','')
# print("new_key:", new_key)
state_dict[new_key] = state_dict[key]
state_dict.pop(key)
for key in list(state_dict.keys()):
if key.startswith('decoder.'):
# print("key:", key)
state_dict.pop(key)
# NOTE: replace norm with fc_norm
for key in list(state_dict.keys()):
# print("new key:", key)
if key.startswith('norm.'):
new_key = key.replace('norm.','fc_norm.')
state_dict[new_key] = state_dict[key]
state_dict.pop(key)
print("new keys:", len(state_dict), state_dict.keys())
for key in all_keys:
if "relative_position_index" in key:
state_dict.pop(key)
if "relative_position_bias_table" in key:
rel_pos_bias = state_dict[key]
src_num_pos, num_attn_heads = rel_pos_bias.size()
dst_num_pos, _ = model.state_dict()[key].size()
dst_patch_shape = model.patch_embed.patch_shape
if dst_patch_shape[0] != dst_patch_shape[1]:
raise NotImplementedError()
num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
if src_size != dst_size:
if rank == 0:
print("Position interpolate for %s from %dx%d to %dx%d" % (
key, src_size, src_size, dst_size, dst_size))
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
def geometric_progression(a, r, n):
return a * (1.0 - r ** n) / (1.0 - r)
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size // 2:
right = q
else:
left = q
# if q > 1.13492:
# q = 1.13492
dis = []
cur = 1
for i in range(src_size // 2):
dis.append(cur)
cur += q ** (i + 1)
r_ids = [-_ for _ in reversed(dis)]
x = r_ids + [0] + dis
y = r_ids + [0] + dis
t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)
if rank == 0:
print("x = {}".format(x))
print("dx = {}".format(dx))
all_rel_pos_bias = []
for i in range(num_attn_heads):
z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
f = interpolate.interp2d(x, y, z, kind='cubic')
all_rel_pos_bias.append(
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
state_dict[key] = new_rel_pos_bias
if 'pos_embed' in state_dict:
pos_embed_checkpoint = state_dict['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
if rank == 0:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict['pos_embed'] = new_pos_embed
# interpolate position bias table if needed
relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
for table_key in relative_position_bias_table_keys:
table_pretrained = state_dict[table_key]
table_current = model.state_dict()[table_key]
L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
logger.warning(f"Error in loading {table_key}, pass")
else:
if L1 != L2:
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
table_pretrained_resized = F.interpolate(
table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
size=(S2, S2), mode='bicubic')
state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
# load state_dict
load_state_dict(model, state_dict, strict, logger)
return checkpoint
def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights, possibly on GPU.

    Returns:
        OrderedDict: A new dict with the same keys, each value moved to CPU
        via ``.cpu()``. The ``_metadata`` attribute of the input (if any) is
        carried over so per-module version info survives saving. The input
        dict itself is not modified.
    """
    state_dict_cpu = OrderedDict(
        (key, val.cpu()) for key, val in state_dict.items())
    # get_state_dict attaches ``_metadata`` to the dict; a plain copy would
    # silently drop it, so forward it explicitly (empty when absent).
    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
    return state_dict_cpu
def _save_to_state_dict(module, destination, prefix, keep_vars):
"""Saves module state to `destination` dictionary.
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
Args:
module (nn.Module): The module to generate state_dict.
destination (dict): A dict where state will be stored.
prefix (str): The prefix for parameters and buffers used in this
module.
"""
for name, param in module._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.detach()
for name, buf in module._buffers.items():
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
if buf is not None:
destination[prefix + name] = buf if keep_vars else buf.detach()
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Collect the full state (parameters and buffers) of ``module``.

    A variant of :meth:`torch.nn.Module.state_dict` that unwraps module
    wrappers (as detected by ``is_module_wrapper``) at every level of the
    recursion, so nested structures such as ``nn.Module(nn.Module(DDP))``
    are handled.

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Dict to fill. When None, a new
            ``OrderedDict`` with an empty ``_metadata`` attribute is created.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module (or the
        dict returned by the last state-dict hook that returned non-None).
    """
    if is_module_wrapper(module):
        module = module.module

    # Everything below follows torch.nn.Module.state_dict().
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    local_metadata = dict(version=module._version)
    destination._metadata[prefix[:-1]] = local_metadata

    _save_to_state_dict(module, destination, prefix, keep_vars)

    for child_name, child in module._modules.items():
        if child is None:
            continue
        get_state_dict(
            child, destination, f'{prefix}{child_name}.', keep_vars=keep_vars)

    for hook in module._state_dict_hooks.values():
        replacement = hook(module, destination, prefix, local_metadata)
        if replacement is not None:
            destination = replacement
    return destination
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved. A module wrapper
            is unwrapped first.
        filename (str): Checkpoint filename. A ``pavi://`` prefix uploads the
            checkpoint to pavi modelcloud instead of the local filesystem.
        optimizer (:obj:`Optimizer` | dict, optional): Optimizer to be saved,
            or a dict mapping names to optimizers.
        meta (dict, optional): Metadata to be saved in checkpoint. It is
            copied before being extended, so the caller's dict is not
            modified.

    Raises:
        TypeError: If ``meta`` is neither None nor a dict.
        ImportError: If ``filename`` starts with ``pavi://`` and pavi is not
            installed.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    else:
        # Copy so the version/time/CLASSES entries added below do not leak
        # back into the caller's dict.
        meta = dict(meta)
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module
    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {
            name: optim.state_dict() for name, optim in optimizer.items()
        }

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError as err:
            raise ImportError(
                'Please install pavi to save checkpoint to modelcloud.'
            ) from err
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        # A distinct name keeps the remote node from shadowing ``model``.
        try:
            remote_model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            remote_model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            remote_model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
================================================
FILE: downstream_tasks/semantic_segmentation/mmcv_custom/layer_decay_optimizer_constructor.py
================================================
import json
from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
from mmcv.runner import get_dist_info
def get_num_layer_for_vit(var_name, num_max_layer):
    """Map a parameter name to its layer id for layer-wise lr decay.

    Args:
        var_name (str): Full parameter name, e.g.
            ``backbone.blocks.3.attn.qkv.weight``.
        num_max_layer (int): Total number of layer ids.

    Returns:
        int: 0 for the backbone embeddings (cls/mask token, pos_embed,
        patch_embed), ``i + 1`` for ``backbone.blocks.i``, and
        ``num_max_layer - 1`` for everything else.
    """
    embedding_params = ("backbone.cls_token", "backbone.mask_token",
                        "backbone.pos_embed")
    if var_name in embedding_params or var_name.startswith("backbone.patch_embed"):
        return 0
    if var_name.startswith("backbone.blocks"):
        block_index = var_name.split('.')[2]
        return int(block_index) + 1
    return num_max_layer - 1
@OPTIMIZER_BUILDERS.register_module()
class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
    """Optimizer constructor with layer-wise learning-rate decay for ViT.

    Reads ``num_layers`` and ``layer_decay_rate`` from ``paramwise_cfg`` and
    assigns every trainable parameter to a group keyed by its layer id and by
    whether it receives weight decay.
    """

    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.

        Each group gets ``lr = base_lr * layer_decay_rate ** (num_layers -
        layer_id - 1)``, where ``num_layers = paramwise_cfg['num_layers'] + 2``
        and ``layer_id`` comes from ``get_num_layer_for_vit``. 1-D tensors and
        ``.bias`` parameters get zero weight decay.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module. Not used here.
            is_dcn_module (int|float|None): Not used here; kept for
                signature compatibility with the base class.
        """
        groups = {}
        print(self.paramwise_cfg)
        num_layers = self.paramwise_cfg.get('num_layers') + 2
        decay_rate = self.paramwise_cfg.get('layer_decay_rate')
        print("Build LayerDecayOptimizerConstructor %f - %d" % (decay_rate, num_layers))
        base_weight_decay = self.base_wd

        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights

            # NOTE(review): names here appear to carry the ``backbone.``
            # prefix (see get_num_layer_for_vit), so the bare
            # 'pos_embed'/'cls_token' comparison likely never matches and those
            # tensors fall into the decay group. Kept to preserve behavior.
            skip_decay = (len(param.shape) == 1 or name.endswith(".bias")
                          or name in ('pos_embed', 'cls_token'))
            if skip_decay:
                decay_tag, this_weight_decay = "no_decay", 0.
            else:
                decay_tag, this_weight_decay = "decay", base_weight_decay

            layer_id = get_num_layer_for_vit(name, num_layers)
            group_name = "layer_%d_%s" % (layer_id, decay_tag)
            group = groups.get(group_name)
            if group is None:
                scale = decay_rate ** (num_layers - layer_id - 1)
                group = groups[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "param_names": [],
                    "lr_scale": scale,
                    "group_name": group_name,
                    "lr": scale * self.base_lr,
                }
            group["params"].append(param)
            group["param_names"].append(name)

        rank, _ = get_dist_info()
        if rank == 0:
            summary = {
                key: {
                    "param_names": group["param_names"],
                    "lr_scale": group["lr_scale"],
                    "lr": group["lr"],
                    "weight_decay": group["weight_decay"],
                }
                for key, group in groups.items()
            }
            print("Param groups = %s" % json.dumps(summary, indent=2))

        params.extend(groups.values())
================================================
FILE: downstream_tasks/semantic_segmentation/mmcv_custom/resize_transform.py
================================================
import mmcv
import numpy as np
from mmseg.datasets.builder import PIPELINES
@PIPELINES.register_module()
class SETR_Resize(object):
    """Resize images & seg.

    This transform resizes the input image to some scale. If the input dict
    contains the key "scale", then the scale in the input dict is used,
    otherwise the specified scale in the init method is used.

    ``img_scale`` can either be a tuple (single-scale) or a list of tuple
    (multi-scale). There are 3 multiscale modes:

    - ``ratio_range is not None``: randomly sample a ratio from the ratio range
      and multiply it with the image scale.
    - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a
      scale from a range.
    - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a
      scale from multiple scales.

    When ``setr_multi_scale`` and ``keep_ratio`` are both True, the sampled
    scale is replaced by an aspect-preserving target whose short side is
    ``max(min(scale), crop_size[0])``.

    Args:
        img_scale (tuple or list[tuple]): Images scales for resizing.
        multiscale_mode (str): Either "range" or "value".
        ratio_range (tuple[float]): (min_ratio, max_ratio)
        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
            image.
        crop_size (tuple[int], optional): Only ``crop_size[0]`` is read, as the
            minimum short side in ``setr_multi_scale`` mode.
        setr_multi_scale (bool): Enable the SETR short-side rule described
            above. Requires ``crop_size`` when ``keep_ratio`` is True.

    Raises:
        ValueError: If ``setr_multi_scale`` and ``keep_ratio`` are set but
            ``crop_size`` is None.
    """

    def __init__(self,
                 img_scale=None,
                 multiscale_mode='range',
                 ratio_range=None,
                 keep_ratio=True,
                 crop_size=None,
                 setr_multi_scale=False):
        if img_scale is None:
            self.img_scale = None
        elif isinstance(img_scale, list):
            self.img_scale = img_scale
        else:
            self.img_scale = [img_scale]

        if ratio_range is not None:
            # mode 1: given a scale and a range of image ratio
            assert len(self.img_scale) == 1
        else:
            # mode 2: given multiple scales or a range of scales
            assert multiscale_mode in ['value', 'range']

        if setr_multi_scale and keep_ratio and crop_size is None:
            # _resize_img indexes crop_size[0] in this mode; fail when the
            # pipeline is built rather than on the first sample.
            raise ValueError('crop_size must be given when setr_multi_scale '
                             'and keep_ratio are both enabled')

        self.multiscale_mode = multiscale_mode
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio
        self.crop_size = crop_size
        self.setr_multi_scale = setr_multi_scale

    @staticmethod
    def random_select(img_scales):
        """Randomly select an img_scale from given candidates.

        Args:
            img_scales (list[tuple]): Images scales for selection.

        Returns:
            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
                where ``img_scale`` is the selected image scale and
                ``scale_idx`` is the selected index in the given candidates.
        """
        assert mmcv.is_list_of(img_scales, tuple)
        scale_idx = np.random.randint(len(img_scales))
        img_scale = img_scales[scale_idx]
        return img_scale, scale_idx

    @staticmethod
    def random_sample(img_scales):
        """Randomly sample an img_scale when ``multiscale_mode=='range'``.

        Args:
            img_scales (list[tuple]): Images scale range for sampling.
                There must be two tuples in img_scales, which specify the lower
                and upper bound of image scales.

        Returns:
            (tuple, None): Returns a tuple ``(img_scale, None)``, where
                ``img_scale`` is sampled scale and None is just a placeholder
                to be consistent with :func:`random_select`.
        """
        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
        img_scale_long = [max(s) for s in img_scales]
        img_scale_short = [min(s) for s in img_scales]
        long_edge = np.random.randint(
            min(img_scale_long),
            max(img_scale_long) + 1)
        short_edge = np.random.randint(
            min(img_scale_short),
            max(img_scale_short) + 1)
        img_scale = (long_edge, short_edge)
        return img_scale, None

    @staticmethod
    def random_sample_ratio(img_scale, ratio_range):
        """Randomly sample an img_scale when ``ratio_range`` is specified.

        A ratio will be randomly sampled from the range specified by
        ``ratio_range``. Then it would be multiplied with ``img_scale`` to
        generate sampled scale.

        Args:
            img_scale (tuple): Images scale base to multiply with ratio.
            ratio_range (tuple[float]): The minimum and maximum ratio to scale
                the ``img_scale``.

        Returns:
            (tuple, None): Returns a tuple ``(scale, None)``, where
                ``scale`` is sampled ratio multiplied with ``img_scale`` and
                None is just a placeholder to be consistent with
                :func:`random_select`.
        """
        assert isinstance(img_scale, tuple) and len(img_scale) == 2
        min_ratio, max_ratio = ratio_range
        assert min_ratio <= max_ratio
        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
        return scale, None

    def _random_scale(self, results):
        """Randomly sample an img_scale according to ``ratio_range`` and
        ``multiscale_mode``.

        If ``ratio_range`` is specified, a ratio will be sampled and be
        multiplied with ``img_scale``.
        If multiple scales are specified by ``img_scale``, a scale will be
        sampled according to ``multiscale_mode``.
        Otherwise, single scale will be used.

        Args:
            results (dict): Result dict from :obj:`dataset`.

        Returns:
            dict: Two new keys 'scale` and 'scale_idx` are added into
                ``results``, which would be used by subsequent pipelines.
        """
        if self.ratio_range is not None:
            scale, scale_idx = self.random_sample_ratio(
                self.img_scale[0], self.ratio_range)
        elif len(self.img_scale) == 1:
            scale, scale_idx = self.img_scale[0], 0
        elif self.multiscale_mode == 'range':
            scale, scale_idx = self.random_sample(self.img_scale)
        elif self.multiscale_mode == 'value':
            scale, scale_idx = self.random_select(self.img_scale)
        else:
            raise NotImplementedError
        results['scale'] = scale
        results['scale_idx'] = scale_idx

    def _resize_img(self, results):
        """Resize images with ``results['scale']``."""
        if self.keep_ratio:
            if self.setr_multi_scale:
                # Short side is the sampled short side, but never below
                # crop_size[0]; the long side follows the image aspect ratio.
                if min(results['scale']) < self.crop_size[0]:
                    new_short = self.crop_size[0]
                else:
                    new_short = min(results['scale'])

                h, w = results['img'].shape[:2]
                if h > w:
                    new_h, new_w = new_short * h / w, new_short
                else:
                    new_h, new_w = new_short, new_short * w / h
                results['scale'] = (new_h, new_w)

            img, scale_factor = mmcv.imrescale(
                results['img'], results['scale'], return_scale=True)
            # the w_scale and h_scale has minor difference
            # a real fix should be done in the mmcv.imrescale in the future
            new_h, new_w = img.shape[:2]
            h, w = results['img'].shape[:2]
            w_scale = new_w / w
            h_scale = new_h / h
        else:
            img, w_scale, h_scale = mmcv.imresize(
                results['img'], results['scale'], return_scale=True)
        scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
                                dtype=np.float32)
        results['img'] = img
        results['img_shape'] = img.shape
        results['pad_shape'] = img.shape  # in case that there is no padding
        results['scale_factor'] = scale_factor
        results['keep_ratio'] = self.keep_ratio

    def _resize_seg(self, results):
        """Resize every map listed in ``results['seg_fields']`` with
        ``results['scale']`` (nearest-neighbour), writing each back to its
        own key."""
        for key in results.get('seg_fields', []):
            if self.keep_ratio:
                gt_seg = mmcv.imrescale(
                    results[key], results['scale'], interpolation='nearest')
            else:
                gt_seg = mmcv.imresize(
                    results[key], results['scale'], interpolation='nearest')
            results[key] = gt_seg

    def __call__(self, results):
        """Call function to resize images, bounding boxes, masks, semantic
        segmentation map.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
                'keep_ratio' keys are added into result dict.
        """
        if 'scale' not in results:
            self._random_scale(results)
        self._resize_img(results)
        self._resize_seg(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(img_scale={self.img_scale}, '
                     f'multiscale_mode={self.multiscale_mode}, '
                     f'ratio_range={self.ratio_range}, '
                     f'keep_ratio={self.keep_ratio}, '
                     f'crop_size={self.crop_size}, '
                     f'setr_multi_scale={self.setr_multi_scale})')
        return repr_str
================================================
FILE: downstream_tasks/semantic_segmentation/mmcv_custom/train_api.py
================================================
import random
import warnings
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import build_optimizer, build_runner
from mmseg.core import DistEvalHook, EvalHook
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.utils import get_root_logger
try:
    import apex
except Exception:
    # apex is optional (used only for fp16 training). Catching Exception
    # keeps the best-effort behaviour for broken installs without swallowing
    # KeyboardInterrupt/SystemExit as a bare ``except:`` would. Binding the
    # name lets callers test ``apex is None`` instead of hitting NameError.
    apex = None
    print('apex is not installed')
def set_random_seed(seed, deterministic=False):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Args:
        seed (int): Seed value given to every generator.
        deterministic (bool): If True, also set
            ``torch.backends.cudnn.deterministic`` to True and
            ``torch.backends.cudnn.benchmark`` to False. Default: False.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)
    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def train_segmentor(model,
                    dataset,
                    cfg,
                    distributed=False,
                    validate=False,
                    timestamp=None,
                    meta=None):
    """Launch segmentor training.

    Builds the training dataloaders and the optimizer, and optionally applies
    apex AMP (opt_level "O1") when ``cfg.optimizer_config`` is a
    ``DistOptimizerHook`` with ``use_fp16``. It then wraps the model for
    (distributed) data parallelism, builds the runner, registers the hooks,
    and runs ``cfg.workflow``.

    Args:
        model (nn.Module): Segmentor to train.
        dataset (Dataset | list | tuple): One training dataset or several.
        cfg (Config): Full training config.
        distributed (bool): Use distributed dataloaders/wrappers/eval hook.
        validate (bool): Register an evaluation hook on ``cfg.data.val``.
        timestamp (str, optional): Assigned to ``runner.timestamp``.
        meta (dict, optional): Passed to the runner.
    """
    logger = get_root_logger(cfg.log_level)

    # One training dataloader per dataset.
    datasets = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            len(cfg.gpu_ids),  # ignored when distributed
            dist=distributed,
            seed=cfg.seed,
            drop_last=True)
        for ds in datasets
    ]

    optimizer = build_optimizer(model, cfg.optimizer)

    # apex mixed precision is only used together with DistOptimizerHook.
    optimizer_cfg = cfg.optimizer_config
    use_apex_fp16 = (optimizer_cfg.get("type") == "DistOptimizerHook"
                     and optimizer_cfg.get("use_fp16", False))
    if use_apex_fp16:
        model, optimizer = apex.amp.initialize(
            model.cuda(), optimizer, opt_level="O1")
        for submodule in model.modules():
            if hasattr(submodule, "fp16_enabled"):
                submodule.fp16_enabled = True

    # Put the model on GPU(s).
    if distributed:
        # Forwarded to torch.nn.parallel.DistributedDataParallel.
        find_unused = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    if cfg.get('runner') is None:
        cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner_defaults = dict(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    runner = build_runner(cfg.runner, default_args=runner_defaults)

    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    # Keep the .log and .log.json filenames on the same timestamp.
    runner.timestamp = timestamp

    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = 'IterBasedRunner' not in cfg.runner['type']
        hook_cls = DistEvalHook if distributed else EvalHook
        runner.register_hook(hook_cls(val_dataloader, **eval_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/__init__.py
================================================
import mmcv
from .version import __version__, version_info
# Inclusive bounds of the mmcv versions accepted by the import-time version
# check in this module.
MMCV_MIN = '1.1.4'
MMCV_MAX = '1.3.0'
def digit_version(version_str):
    """Turn a dotted version string into a list of ints for comparison.

    Purely numeric components become ints. A component containing ``rc``
    (e.g. ``0rc1``) becomes two ints: the number before ``rc`` minus one,
    then the number after it, so ``1.0rc1`` compares below ``1.0``.
    Components matching neither form are dropped.

    Example:
        >>> digit_version('1.3.0')
        [1, 3, 0]
        >>> digit_version('1.0rc1')
        [1, -1, 1]
    """
    parts = []
    for piece in version_str.split('.'):
        if piece.isdigit():
            parts.append(int(piece))
            continue
        if 'rc' in piece:
            rc_split = piece.split('rc')
            parts.extend([int(rc_split[0]) - 1, int(rc_split[1])])
    return parts
mmcv_min_version = digit_version(MMCV_MIN)
mmcv_max_version = digit_version(MMCV_MAX)
mmcv_version = digit_version(mmcv.__version__)

# Raised explicitly (same AssertionError type as before) so the check also
# runs under ``python -O``; the message shows the human-readable bounds
# rather than the parsed int lists.
if not (mmcv_min_version <= mmcv_version <= mmcv_max_version):
    raise AssertionError(
        f'MMCV=={mmcv.__version__} is used but incompatible. '
        f'Please install mmcv>={MMCV_MIN}, <={MMCV_MAX}.')

__all__ = ['__version__', 'version_info']
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/apis/__init__.py
================================================
from .inference import inference_segmentor, init_segmentor, show_result_pyplot
from .test import multi_gpu_test, single_gpu_test
from .train import get_root_logger, set_random_seed, train_segmentor
# Public API of ``mmseg.apis``: names re-exported from the inference, test
# and train submodules imported above.
__all__ = [
    'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
    'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
    'show_result_pyplot'
]
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/apis/inference.py
================================================
import matplotlib.pyplot as plt
import mmcv
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor
def init_segmentor(config, checkpoint=None, device='cuda:0'):
    """Build a segmentor from a config and optionally load its weights.

    Args:
        config (str or :obj:`mmcv.Config`): Path to a config file, or an
            already-loaded config object.
        checkpoint (str, optional): Checkpoint path. When None, no weights
            are loaded.
        device (str, optional): Device the model is moved to, e.g.
            ``'cuda:0'`` (default) or ``'cpu'``.

    Returns:
        nn.Module: The segmentor in eval mode with ``cfg`` attached; when a
        checkpoint is given, ``CLASSES`` and ``PALETTE`` are read from its
        ``meta``.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))

    # Inference only: skip backbone pretraining and the training config.
    config.model.pretrained = None
    config.model.train_cfg = None
    model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))

    if checkpoint is not None:
        loaded = load_checkpoint(model, checkpoint, map_location='cpu')
        ckpt_meta = loaded['meta']
        model.CLASSES = ckpt_meta['CLASSES']
        model.PALETTE = ckpt_meta['PALETTE']

    model.cfg = config  # kept on the model for convenience
    model.to(device)
    model.eval()
    return model
class LoadImage:
    """Minimal pipeline step that reads ``results['img']`` into an array."""

    def __call__(self, results):
        """Load the image referenced by ``results['img']``.

        Args:
            results (dict): Must contain ``'img'``, either a file name or an
                already-loaded image.

        Returns:
            dict: ``results`` with ``img`` replaced by the loaded array and
            ``filename``/``ori_filename`` (None unless ``img`` was a str),
            ``img_shape`` and ``ori_shape`` set.
        """
        source = results['img']
        file_name = source if isinstance(source, str) else None
        results['filename'] = file_name
        results['ori_filename'] = file_name

        loaded = mmcv.imread(source)
        results['img'] = loaded
        results['img_shape'] = loaded.shape
        results['ori_shape'] = loaded.shape
        return results
def inference_segmentor(model, img):
    """Run the segmentor on a single image.

    Args:
        model (nn.Module): The loaded segmentor; ``model.cfg`` must be set
            (as done by ``init_segmentor``).
        img (str or ndarray): Image file path or an already loaded image.

    Returns:
        The output of ``model(return_loss=False, rescale=True, ...)``.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    # The first step of the configured test pipeline (presumably its
    # image-loading step) is replaced by LoadImage, which accepts both
    # file paths and in-memory arrays.
    test_pipeline = Compose([LoadImage()] + cfg.data.test.pipeline[1:])

    data = test_pipeline(dict(img=img))
    data = collate([data], samples_per_gpu=1)
    if device.type == 'cuda':
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        # Without scatter, unwrap the collated img_metas containers manually.
        data['img_metas'] = [i.data[0] for i in data['img_metas']]

    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result
def show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10)):
    """Draw a segmentation result on its image and display it with pyplot.

    Args:
        model (nn.Module): The loaded segmentor; a wrapper exposing
            ``.module`` is unwrapped first.
        img (str or np.ndarray): Image filename or loaded image.
        result (list): The segmentation result.
        palette (list[list[int]] | None): Palette forwarded to
            ``model.show_result``. Default: None.
        fig_size (tuple): Figure size of the pyplot figure.
    """
    segmentor = model.module if hasattr(model, 'module') else model
    drawn = segmentor.show_result(img, result, palette=palette, show=False)
    plt.figure(figsize=fig_size)
    plt.imshow(mmcv.bgr2rgb(drawn))
    plt.show()
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/apis/test.py
================================================
import os.path as osp
import pickle
import shutil
import tempfile
import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info
def np2tmp(array, temp_file_name=None):
    """Save ndarray to local numpy file.

    Args:
        array (ndarray): Ndarray to save.
        temp_file_name (str): Numpy file name. If None, a fresh ``.npy`` file
            is created via :class:`tempfile.NamedTemporaryFile` (kept on disk).
            A str name without the ``.npy`` suffix gets it appended, matching
            what :func:`numpy.save` writes. Default: None.

    Returns:
        str: The path of the file that was actually written.
    """
    if temp_file_name is None:
        # Close the handle immediately: only the persistent path is needed,
        # and leaving it open leaks one file object per call.
        with tempfile.NamedTemporaryFile(suffix='.npy', delete=False) as tmp:
            temp_file_name = tmp.name
    elif isinstance(temp_file_name, str) and not temp_file_name.endswith('.npy'):
        # np.save appends '.npy' in this case; return the real path.
        temp_file_name += '.npy'
    np.save(temp_file_name, array)
    return temp_file_name
def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    efficient_test=False):
    """Run inference over ``data_loader`` on a single GPU.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (utils.data.Dataloader): Pytorch data loader.
        show (bool): Whether to display results during inference.
            Default: False.
        out_dir (str, optional): If given, rendered results are written under
            this directory using each image's ``ori_filename``.
        efficient_test (bool): Save every result to a local ``.npy`` file and
            keep only the file name, to reduce CPU memory. Default: False.

    Returns:
        list: The prediction results (or their ``.npy`` paths).
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))

    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, **data)

        if show or out_dir:
            img_tensor = data['img'][0]
            img_metas = data['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)

            for img, img_meta in zip(imgs, img_metas):
                # Crop away padding, then resize back to the original size.
                h, w, _ = img_meta['img_shape']
                ori_h, ori_w = img_meta['ori_shape'][:-1]
                img_show = mmcv.imresize(img[:h, :w, :], (ori_w, ori_h))
                out_file = (osp.join(out_dir, img_meta['ori_filename'])
                            if out_dir else None)
                model.module.show_result(
                    img_show,
                    result,
                    palette=dataset.PALETTE,
                    show=show,
                    out_file=out_file)

        is_batch = isinstance(result, list)
        if efficient_test:
            result = [np2tmp(r) for r in result] if is_batch else np2tmp(result)
        if is_batch:
            results.extend(result)
        else:
            results.append(result)

        for _ in range(data['img'][0].size(0)):
            prog_bar.update()
    return results
def multi_gpu_test(model,
                   data_loader,
                   tmpdir=None,
                   gpu_collect=False,
                   efficient_test=False):
    """Run inference on every rank and gather the results.

    Each rank runs inference on its share of the dataset. The results are
    then collected either through GPU communication (``gpu_collect=True``)
    or through part files written to ``tmpdir`` and merged by rank 0.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (utils.data.Dataloader): Pytorch data loader.
        tmpdir (str): Directory for the per-rank part files in cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect results.
        efficient_test (bool): Save every result to a local ``.npy`` file and
            keep only the file name, to reduce CPU memory. Default: False.

    Returns:
        list: The prediction results on rank 0 (None on other ranks).
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    is_main = rank == 0
    if is_main:
        prog_bar = mmcv.ProgressBar(len(dataset))

    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

        is_batch = isinstance(result, list)
        if efficient_test:
            result = [np2tmp(r) for r in result] if is_batch else np2tmp(result)
        if is_batch:
            results.extend(result)
        else:
            results.append(result)

        if is_main:
            # Advance by the global batch (assumes equal batch size per rank).
            for _ in range(data['img'][0].size(0) * world_size):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        return collect_results_gpu(results, len(dataset))
    return collect_results_cpu(results, len(dataset), tmpdir)
def collect_results_cpu(result_part, size, tmpdir=None):
    """Collect results with CPU.

    Every rank dumps its ``result_part`` to ``tmpdir/part_{rank}.pkl`` with
    ``mmcv.dump``. After a barrier, rank 0 loads all parts, interleaves them
    (first item of each rank, then the second, ...) and truncates the list
    to ``size``.

    Args:
        result_part (list): Results computed on this rank.
        size (int): Number of samples to keep; trailing extras (e.g. from a
            padding sampler) are dropped.
        tmpdir (str, optional): Directory for the part files; rank 0 reads
            every part, so it must be visible to all ranks. If None, rank 0
            creates one with ``tempfile.mkdtemp`` and broadcasts its path
            (this path requires CUDA, since the broadcast tensor is on 'cuda').

    Returns:
        list | None: Ordered results on rank 0, None on every other rank.

    Note:
        Rank 0 deletes ``tmpdir`` with ``shutil.rmtree`` at the end, including
        a caller-supplied directory.
        NOTE(review): ``zip(*part_list)`` stops at the shortest part, so if
        ranks return different numbers of results the extras are dropped.
    """
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace: the path is space-padded to a fixed length so all
        # ranks broadcast equally sized tensors; rstrip() removes the padding.
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            tmpdir = tempfile.mkdtemp()
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
    # Wait until every rank has written its part before rank 0 reads them.
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
            part_list.append(mmcv.load(part_file))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results
def collect_results_gpu(result_part, size):
    """Collect results with GPU.

    Each rank pickles ``result_part`` into a uint8 CUDA tensor. The tensors
    are zero-padded to a common length and all-gathered. Rank 0 then
    unpickles every part, interleaves them into one list and truncates it
    to ``size``.

    Args:
        result_part (list): Results produced by the current rank.
        size (int): Number of samples in the dataset; the merged list is cut
            to this length to drop samples padded by the dataloader.

    Returns:
        list | None: The ordered results on rank 0. Other ranks reach the end
        of the function and implicitly return ``None``.
    """
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    # all_gather needs equally sized tensors, so every part is padded to the
    # longest one; the real lengths in shape_list are used to trim later.
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    # Every rank allocates receive buffers for all parts, even though only
    # rank 0 decodes them below.
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)
    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_list.append(
                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
        # sort the results
        # zip() interleaves one result from each rank in turn (and stops at
        # the shortest part).
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/apis/train.py
================================================
import random
import warnings
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import build_optimizer, build_runner
from mmseg.core import DistEvalHook, EvalHook
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.utils import get_root_logger
def set_random_seed(seed, deterministic=False):
    """Seed every random number generator used during training.

    Args:
        seed (int): Seed given to Python's ``random``, NumPy's global RNG,
            PyTorch's CPU RNG and the RNGs of all CUDA devices.
        deterministic (bool): When True, also switch cuDNN to deterministic
            mode by setting ``torch.backends.cudnn.deterministic = True`` and
            ``torch.backends.cudnn.benchmark = False``. Default: False.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)
    if not deterministic:
        return
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
def train_segmentor(model,
                    dataset,
                    cfg,
                    distributed=False,
                    validate=False,
                    timestamp=None,
                    meta=None):
    """Launch segmentor training.

    Builds the training dataloaders and wraps ``model`` in
    ``MMDistributedDataParallel`` or ``MMDataParallel``. It then builds the
    optimizer and runner, and registers the training hooks plus, optionally,
    an evaluation hook. Finally it resumes or loads a checkpoint if
    configured and runs ``cfg.workflow``.

    Args:
        model (nn.Module): The segmentor to train.
        dataset (Dataset | list | tuple): Training dataset(s). A single
            dataset is wrapped into a one-element list.
        cfg (Config): Training config. Fields read here include
            ``log_level``, ``data``, ``gpu_ids``, ``seed``, ``optimizer``,
            ``runner`` (or ``total_iters``), ``work_dir``, ``lr_config``,
            ``optimizer_config``, ``checkpoint_config``, ``log_config``,
            ``momentum_config``, ``find_unused_parameters``, ``evaluation``,
            ``resume_from``, ``load_from`` and ``workflow``.
        distributed (bool): Use distributed data parallel training and
            ``DistEvalHook``. Default: False.
        validate (bool): Register an evaluation hook on ``cfg.data.val``.
            Default: False.
        timestamp (str | None): Assigned to ``runner.timestamp``.
        meta (dict | None): Passed to the runner as ``meta``.
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            drop_last=True) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    if cfg.get('runner') is None:
        # Configs without a `runner` section fall back to an iteration-based
        # runner driven by cfg.total_iters (with a warning).
        cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # register hooks
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))

    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # register eval hooks
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        # Validation runs one sample per GPU, without shuffling.
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        # Evaluate per epoch for every runner type except IterBasedRunner.
        # NOTE(review): when cfg.evaluation exists this presumably writes
        # 'by_epoch' into the config object itself — confirm it is intended.
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # resume_from (full training state) takes precedence over load_from.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/core/__init__.py
================================================
from .evaluation import * # noqa: F401, F403
from .seg import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/core/evaluation/__init__.py
================================================
from .class_names import get_classes, get_palette
from .eval_hooks import DistEvalHook, EvalHook
from .metrics import eval_metrics, mean_dice, mean_iou
__all__ = [
'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
'get_classes', 'get_palette'
]
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/core/evaluation/class_names.py
================================================
import mmcv
def cityscapes_classes():
    """Cityscapes class names for external use.

    Returns:
        list[str]: A new list with the 19 Cityscapes class names.
    """
    names = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
             'traffic light', 'traffic sign', 'vegetation', 'terrain',
             'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
             'motorcycle', 'bicycle')
    return list(names)
def ade_classes():
    """ADE20K class names for external use.

    Returns:
        list[str]: A new list of ADE20K class names. The list index is the
        label id used by this file's ``get_classes('ade')``.
    """
    # NOTE: 'bed ' carries a trailing space; it is kept byte-for-byte
    # because changing it would change the reported class name.
    return [
        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
        'clock', 'flag'
    ]
def voc_classes():
    """Pascal VOC class names for external use.

    Returns:
        list[str]: A new list with the 21 VOC names, 'background' first.
    """
    # None of the names contain spaces, so a whitespace split is exact.
    return ('background aeroplane bicycle bird boat bottle bus car cat chair '
            'cow diningtable dog horse motorbike person pottedplant sheep '
            'sofa train tvmonitor').split()
def cityscapes_palette():
    """Cityscapes palette for external use.

    Returns:
        list[list[int]]: A new list with one ``[R, G, B]`` colour per
        Cityscapes class, in the same order as ``cityscapes_classes()``.
    """
    return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
            [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
            [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
            [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
            [0, 0, 230], [119, 11, 32]]
def ade_palette():
    """ADE20K palette for external use.

    Returns:
        list[list[int]]: A new list with one ``[R, G, B]`` colour per
        ADE20K class, presumably aligned index-for-index with
        ``ade_classes()`` — TODO confirm if entries are ever edited.
    """
    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
            [102, 255, 0], [92, 0, 255]]
def voc_palette():
    """Pascal VOC palette for external use.

    The 21 colours follow the standard VOC colour-map construction. The
    bits of each class index are dealt three at a time to the R, G and B
    channels, starting at each channel's most significant bit. The result
    is identical to the literal VOC palette table.

    Returns:
        list[list[int]]: A new list with one ``[R, G, B]`` entry per class.
    """
    palette = []
    for class_idx in range(21):
        rgb = [0, 0, 0]
        code = class_idx
        for bit_pos in range(7, -1, -1):
            for channel in range(3):
                rgb[channel] |= ((code >> channel) & 1) << bit_pos
            code >>= 3
        palette.append(rgb)
    return palette
# Maps each canonical dataset name to the aliases accepted by get_classes()
# and get_palette(). The canonical name is also the prefix of the matching
# ``<name>_classes`` / ``<name>_palette`` functions in this module.
dataset_aliases = {
    'cityscapes': ['cityscapes'],
    'ade': ['ade', 'ade20k'],
    'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
}
def get_classes(dataset):
    """Get class names of a dataset.

    Args:
        dataset (str): A canonical dataset name or any alias listed in
            ``dataset_aliases`` (e.g. ``'ade20k'`` or ``'voc12aug'``).

    Returns:
        list[str]: The list returned by the module's ``<name>_classes()``.

    Raises:
        TypeError: If ``dataset`` is not a str.
        ValueError: If ``dataset`` is not a known name or alias.
    """
    if not isinstance(dataset, str):
        raise TypeError(f'dataset must be a str, but got {type(dataset)}')
    alias2name = {
        alias: name
        for name, aliases in dataset_aliases.items() for alias in aliases
    }
    if dataset not in alias2name:
        raise ValueError(f'Unrecognized dataset: {dataset}')
    # Look the ``<name>_classes`` function up by name instead of eval()-ing a
    # constructed source string. Any dataset added to ``dataset_aliases``
    # together with a matching function keeps working.
    return globals()[alias2name[dataset] + '_classes']()
def get_palette(dataset):
    """Get class palette (RGB) of a dataset.

    Args:
        dataset (str): A canonical dataset name or any alias listed in
            ``dataset_aliases`` (e.g. ``'ade20k'`` or ``'voc12aug'``).

    Returns:
        list[list[int]]: The list returned by the module's
        ``<name>_palette()``.

    Raises:
        TypeError: If ``dataset`` is not a str.
        ValueError: If ``dataset`` is not a known name or alias.
    """
    if not isinstance(dataset, str):
        raise TypeError(f'dataset must be a str, but got {type(dataset)}')
    alias2name = {
        alias: name
        for name, aliases in dataset_aliases.items() for alias in aliases
    }
    if dataset not in alias2name:
        raise ValueError(f'Unrecognized dataset: {dataset}')
    # Look the ``<name>_palette`` function up by name instead of eval()-ing a
    # constructed source string. Any dataset added to ``dataset_aliases``
    # together with a matching function keeps working.
    return globals()[alias2name[dataset] + '_palette']()
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/core/evaluation/eval_hooks.py
================================================
import os.path as osp
from mmcv.runner import Hook
from torch.utils.data import DataLoader
class EvalHook(Hook):
    """Evaluation hook.

    Runs ``single_gpu_test`` on ``dataloader`` every ``interval`` training
    iterations, or every ``interval`` epochs when ``by_epoch`` is True. It
    writes the metrics returned by ``dataset.evaluate`` into the runner's
    log buffer.

    Attributes:
        dataloader (DataLoader): A PyTorch dataloader.
        interval (int): Evaluation interval, counted in iterations, or in
            epochs when ``by_epoch`` is True. Default: 1.
        by_epoch (bool): Whether ``interval`` counts epochs rather than
            iterations. Default: False.
        eval_kwargs (dict): Extra keyword arguments forwarded to
            ``dataset.evaluate``.
    """

    def __init__(self, dataloader, interval=1, by_epoch=False, **eval_kwargs):
        if not isinstance(dataloader, DataLoader):
            raise TypeError('dataloader must be a pytorch DataLoader, but got '
                            f'{type(dataloader)}')
        self.dataloader, self.interval = dataloader, interval
        self.by_epoch = by_epoch
        self.eval_kwargs = eval_kwargs

    def _run_single_gpu_eval(self, runner):
        """Test the runner's model on ``self.dataloader`` and log metrics."""
        # Imported lazily, presumably to avoid a circular import between
        # mmseg.core and mmseg.apis.
        from mmseg.apis import single_gpu_test
        runner.log_buffer.clear()
        outputs = single_gpu_test(runner.model, self.dataloader, show=False)
        self.evaluate(runner, outputs)

    def after_train_iter(self, runner):
        """Evaluate after a training iteration (iteration-based mode)."""
        due = (not self.by_epoch
               and self.every_n_iters(runner, self.interval))
        if due:
            self._run_single_gpu_eval(runner)

    def after_train_epoch(self, runner):
        """Evaluate after a training epoch (epoch-based mode)."""
        due = self.by_epoch and self.every_n_epochs(runner, self.interval)
        if due:
            self._run_single_gpu_eval(runner)

    def evaluate(self, runner, results):
        """Call evaluate function of dataset and publish the metrics."""
        metrics = self.dataloader.dataset.evaluate(
            results, logger=runner.logger, **self.eval_kwargs)
        log_output = runner.log_buffer.output
        for metric_name, metric_value in metrics.items():
            log_output[metric_name] = metric_value
        runner.log_buffer.ready = True
class DistEvalHook(EvalHook):
    """Distributed evaluation hook.

    Works like ``EvalHook`` but runs ``multi_gpu_test``. Results are
    collected under ``<work_dir>/.eval_hook`` (CPU collection) or through
    GPU communication (``gpu_collect=True``). Only rank 0 calls
    ``evaluate``.

    Attributes:
        dataloader (DataLoader): A PyTorch dataloader.
        interval (int): Evaluation interval, counted in iterations, or in
            epochs when ``by_epoch`` is True. Default: 1.
        gpu_collect (bool): Whether to use gpu or cpu to collect results.
            Default: False.
        by_epoch (bool): Whether ``interval`` counts epochs rather than
            iterations. Default: False.
        eval_kwargs (dict): Extra keyword arguments forwarded to
            ``dataset.evaluate``.
    """

    def __init__(self,
                 dataloader,
                 interval=1,
                 gpu_collect=False,
                 by_epoch=False,
                 **eval_kwargs):
        # The parent validates the dataloader with the same TypeError
        # message and stores dataloader/interval/by_epoch/eval_kwargs.
        super().__init__(
            dataloader, interval=interval, by_epoch=by_epoch, **eval_kwargs)
        self.gpu_collect = gpu_collect

    def _run_multi_gpu_eval(self, runner):
        """Test on all ranks; rank 0 then evaluates the gathered results."""
        # Imported lazily, presumably to avoid a circular import between
        # mmseg.core and mmseg.apis.
        from mmseg.apis import multi_gpu_test
        runner.log_buffer.clear()
        outputs = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
            gpu_collect=self.gpu_collect)
        if runner.rank != 0:
            return
        print('\n')
        self.evaluate(runner, outputs)

    def after_train_iter(self, runner):
        """Evaluate after a training iteration (iteration-based mode)."""
        due = (not self.by_epoch
               and self.every_n_iters(runner, self.interval))
        if due:
            self._run_multi_gpu_eval(runner)

    def after_train_epoch(self, runner):
        """Evaluate after a training epoch (epoch-based mode)."""
        due = self.by_epoch and self.every_n_epochs(runner, self.interval)
        if due:
            self._run_multi_gpu_eval(runner)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/core/evaluation/metrics.py
================================================
import mmcv
import numpy as np
def intersect_and_union(pred_label,
                        label,
                        num_classes,
                        ignore_index,
                        label_map=None,
                        reduce_zero_label=False):
    """Calculate intersection and Union.

    Args:
        pred_label (ndarray | str): Prediction segmentation map, or a path
            to a ``.npy`` file holding it.
        label (ndarray | str): Ground truth segmentation map, or a path to
            an image file holding it. An array passed in is never modified.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict | None): Mapping old labels to new labels. It is
            applied to the ground truth whether ``label`` is an array or a
            path. All entries are applied at once, so chained mappings such
            as ``{0: 1, 1: 2}`` do not cascade. Default: None (no
            remapping; an empty dict behaves the same).
        reduce_zero_label (bool): Whether to ignore label 0. Label 0 is
            mapped to 255 and the other labels are shifted down by one,
            whether ``label`` is an array or a path. Default: False.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes.
        ndarray: The union of prediction and ground truth histogram on all
            classes.
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """
    if isinstance(pred_label, str):
        pred_label = np.load(pred_label)

    if isinstance(label, str):
        label = mmcv.imread(label, flag='unchanged', backend='pillow')
    # modify if custom classes
    if label_map:
        # Match against the *original* ids and write into a copy. Remapping
        # in place would let later entries see ids written by earlier ones,
        # and would also modify the caller's array.
        remapped = label.copy()
        for old_id, new_id in label_map.items():
            remapped[label == old_id] = new_id
        label = remapped
    if reduce_zero_label:
        # avoid using underflow conversion
        # Copy first so the in-place writes never touch the caller's array.
        label = label.copy()
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect, _ = np.histogram(
        intersect, bins=np.arange(num_classes + 1))
    area_pred_label, _ = np.histogram(
        pred_label, bins=np.arange(num_classes + 1))
    area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label
def total_intersect_and_union(results,
                              gt_seg_maps,
                              num_classes,
                              ignore_index,
                              label_map=None,
                              reduce_zero_label=False):
    """Calculate Total Intersection and Union.

    Sums the per-image histograms of :func:`intersect_and_union` over the
    whole dataset.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict | None): Mapping old labels to new labels.
            Default: None (no remapping).
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes.
        ndarray: The union of prediction and ground truth histogram on all
            classes.
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
        All four are float64 arrays of shape (num_classes, ).
    """
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    # ``np.float`` (an alias of the builtin float) was removed in NumPy 1.24;
    # float64 is the same dtype it resolved to.
    total_area_intersect = np.zeros((num_classes, ), dtype=np.float64)
    total_area_union = np.zeros((num_classes, ), dtype=np.float64)
    total_area_pred_label = np.zeros((num_classes, ), dtype=np.float64)
    total_area_label = np.zeros((num_classes, ), dtype=np.float64)
    for result, gt_seg_map in zip(results, gt_seg_maps):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(result, gt_seg_map, num_classes,
                                ignore_index, label_map, reduce_zero_label)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    return total_area_intersect, total_area_union, \
        total_area_pred_label, total_area_label
def mean_iou(results,
             gt_seg_maps,
             num_classes,
             ignore_index,
             nan_to_num=None,
             label_map=dict(),
             reduce_zero_label=False):
    """Calculate Mean Intersection and Union (mIoU).

    Convenience wrapper around :func:`eval_metrics` with
    ``metrics=['mIoU']``.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category IoU, shape (num_classes, ).
    """
    overall_acc, class_acc, class_iou = eval_metrics(
        results,
        gt_seg_maps,
        num_classes,
        ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    return overall_acc, class_acc, class_iou
def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None,
              label_map=dict(),
              reduce_zero_label=False):
    """Calculate Mean Dice (mDice).

    Convenience wrapper around :func:`eval_metrics` with
    ``metrics=['mDice']``.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category dice, shape (num_classes, ).
    """
    overall_acc, class_acc, class_dice = eval_metrics(
        results,
        gt_seg_maps,
        num_classes,
        ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    return overall_acc, class_acc, class_dice
def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None,
                 label_map=dict(),
                 reduce_zero_label=False):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated, any of 'mIoU'
            and 'mDice'. One result is appended per requested metric, in
            the given order.
        nan_to_num (int, optional): If specified, NaN values (e.g. for
            classes absent from both prediction and ground truth) will be
            replaced by this number. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.

    Returns:
        list: ``[overall accuracy, per category accuracy, *metrics]`` where
        each per-category entry is an ndarray of shape (num_classes, ).

    Raises:
        KeyError: If an unsupported metric is requested.
    """
    requested = [metrics] if isinstance(metrics, str) else metrics
    supported = {'mIoU', 'mDice'}
    if not set(requested) <= supported:
        raise KeyError('metrics {} is not supported'.format(requested))

    inter_area, union_area, pred_area, gt_area = total_intersect_and_union(
        results, gt_seg_maps, num_classes, ignore_index, label_map,
        reduce_zero_label)

    overall_acc = inter_area.sum() / gt_area.sum()
    per_class_acc = inter_area / gt_area
    outputs = [overall_acc, per_class_acc]
    for metric_name in requested:
        if metric_name == 'mIoU':
            outputs.append(inter_area / union_area)
        elif metric_name == 'mDice':
            outputs.append(2 * inter_area / (pred_area + gt_area))

    if nan_to_num is None:
        return outputs
    return [np.nan_to_num(value, nan=nan_to_num) for value in outputs]
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/core/seg/__init__.py
================================================
from .builder import build_pixel_sampler
from .sampler import BasePixelSampler, OHEMPixelSampler
__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/core/seg/builder.py
================================================
from mmcv.utils import Registry, build_from_cfg
# Registry of pixel sampler classes (e.g. OHEMPixelSampler registers itself
# with ``@PIXEL_SAMPLERS.register_module()``).
PIXEL_SAMPLERS = Registry('pixel sampler')


def build_pixel_sampler(cfg, **default_args):
    """Build pixel sampler for segmentation map.

    Args:
        cfg (dict): Sampler config; its ``type`` names a class registered
            in ``PIXEL_SAMPLERS``.
        **default_args: Extra arguments handed to ``build_from_cfg`` as its
            ``default_args``.

    Returns:
        object: The constructed pixel sampler.
    """
    return build_from_cfg(
        cfg, registry=PIXEL_SAMPLERS, default_args=default_args)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/core/seg/sampler/__init__.py
================================================
from .base_pixel_sampler import BasePixelSampler
from .ohem_pixel_sampler import OHEMPixelSampler
__all__ = ['BasePixelSampler', 'OHEMPixelSampler']
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/core/seg/sampler/base_pixel_sampler.py
================================================
from abc import ABCMeta, abstractmethod
class BasePixelSampler(metaclass=ABCMeta):
    """Abstract base class for pixel samplers.

    Subclasses must implement :meth:`sample`; the class itself cannot be
    instantiated.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted for subclass compatibility and are
        # deliberately ignored here.
        del kwargs

    @abstractmethod
    def sample(self, seg_logit, seg_label):
        """Placeholder for sample function; overridden by subclasses."""
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/core/seg/sampler/ohem_pixel_sampler.py
================================================
import torch
import torch.nn.functional as F
from ..builder import PIXEL_SAMPLERS
from .base_pixel_sampler import BasePixelSampler
@PIXEL_SAMPLERS.register_module()
class OHEMPixelSampler(BasePixelSampler):
    """Online Hard Example Mining Sampler for segmentation.

    Args:
        context (nn.Module): The context of sampler, subclass of
            :obj:`BaseDecodeHead`. Its ``ignore_index`` and ``loss_decode``
            attributes are used by :meth:`sample`.
        thresh (float, optional): The threshold for hard example selection.
            Pixels whose predicted probability for their ground-truth class
            is below the threshold are treated as hard (low-confidence)
            examples. If not specified, the hard examples will be the pixels
            with the top ``min_kept`` (per image) losses. Default: None.
        min_kept (int, optional): The minimum number of predictions to keep
            per image; it is multiplied by the batch size in :meth:`sample`.
            Must be greater than 1. Default: 100000.
    """

    def __init__(self, context, thresh=None, min_kept=100000):
        super(OHEMPixelSampler, self).__init__()
        self.context = context
        assert min_kept > 1
        self.thresh = thresh
        self.min_kept = min_kept

    def sample(self, seg_logit, seg_label):
        """Sample pixels that have high loss or with low prediction confidence.

        Args:
            seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
            seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)

        Returns:
            torch.Tensor: segmentation weight, shape (N, H, W). Selected
                pixels get weight 1; all other pixels, including those equal
                to ``ignore_index``, get weight 0.
        """
        with torch.no_grad():
            assert seg_logit.shape[2:] == seg_label.shape[2:]
            assert seg_label.shape[1] == 1
            seg_label = seg_label.squeeze(1).long()
            # min_kept is per image, so scale it by the batch size.
            batch_kept = self.min_kept * seg_label.size(0)
            valid_mask = seg_label != self.context.ignore_index
            seg_weight = seg_logit.new_zeros(size=seg_label.size())
            # Boolean indexing returns a copy; it is written back into
            # seg_weight after selection.
            valid_seg_weight = seg_weight[valid_mask]
            if self.thresh is not None:
                seg_prob = F.softmax(seg_logit, dim=1)

                # Ignored labels are replaced by 0 so gather() receives a
                # valid class index; those pixels are excluded by valid_mask.
                tmp_seg_label = seg_label.clone().unsqueeze(1)
                tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
                # Probability predicted for the ground-truth class per pixel.
                seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
                # Ascending order: least confident pixels first.
                # (sort_indices is not used in this branch.)
                sort_prob, sort_indices = seg_prob[valid_mask].sort()

                if sort_prob.numel() > 0:
                    # Probability of the batch_kept-th least confident pixel
                    # (clamped to the last one). Used as a lower bound on the
                    # threshold, it keeps (up to ties) at least the
                    # batch_kept least confident pixels even when
                    # self.thresh alone would select fewer.
                    min_threshold = sort_prob[min(batch_kept,
                                                  sort_prob.numel() - 1)]
                else:
                    min_threshold = 0.0
                threshold = max(min_threshold, self.thresh)
                valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
            else:
                # Unreduced per-pixel losses. NOTE(review): assumed to have
                # shape (N, H, W) so that valid_mask can index them.
                losses = self.context.loss_decode(
                    seg_logit,
                    seg_label,
                    weight=None,
                    ignore_index=self.context.ignore_index,
                    reduction_override='none')
                # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
                # Keep the batch_kept highest-loss valid pixels.
                _, sort_indices = losses[valid_mask].sort(descending=True)
                valid_seg_weight[sort_indices[:batch_kept]] = 1.

            seg_weight[valid_mask] = valid_seg_weight

            return seg_weight
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/core/utils/__init__.py
================================================
from .misc import add_prefix
__all__ = ['add_prefix']
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/core/utils/misc.py
================================================
def add_prefix(inputs, prefix):
    """Return a new dict whose keys are ``inputs``' keys prefixed by ``prefix.``.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: A new dict mapping ``f'{prefix}.{key}'`` to the original value,
        in the original key order. ``inputs`` itself is not modified.
    """
    return {f'{prefix}.{key}': value for key, value in inputs.items()}
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/__init__.py
================================================
from .ade import ADE20KDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .custom import CustomDataset
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .pascal_context import PascalContextDataset
from .stare import STAREDataset
from .voc import PascalVOCDataset
from .coco_stuff import COCOStuffDataset
__all__ = [
'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'COCOStuffDataset',
]
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/ade.py
================================================
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class ADE20KDataset(CustomDataset):
"""ADE20K dataset.
In segmentation map annotation for ADE20K, 0 stands for background, which
is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
'.png'.
"""
CLASSES = (
'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
'clock', 'flag')
PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
[102, 255, 0], [92, 0, 255]]
def __init__(self, **kwargs):
    """Create the dataset with its fixed file-naming conventions.

    Images use the ``.jpg`` suffix, segmentation maps use ``.png`` and
    ``reduce_zero_label`` is forced to True. Every other keyword argument
    is forwarded unchanged to the parent constructor.
    """
    fixed_args = dict(
        img_suffix='.jpg',
        seg_map_suffix='.png',
        reduce_zero_label=True)
    super(ADE20KDataset, self).__init__(**fixed_args, **kwargs)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/builder.py
================================================
import copy
import platform
import random
from functools import partial
import numpy as np
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
from torch.utils.data import DistributedSampler
if platform.system() != 'Windows':
    # https://github.com/pytorch/pytorch/issues/973
    # Raise the soft open-files limit to at least 4096, capped at the hard
    # limit. A soft limit that is already higher (common in containers) is
    # kept as-is instead of being lowered to 4096.
    import resource
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    base_soft_limit = rlimit[0]
    hard_limit = rlimit[1]
    soft_limit = min(max(4096, base_soft_limit), hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
# Registry of dataset classes. Classes join it via
# ``@DATASETS.register_module()`` and ``build_dataset`` instantiates them
# through ``build_from_cfg``.
DATASETS = Registry('dataset')
# Registry for data-pipeline components (presumably the transforms assembled
# by ``Compose``; confirm in ``pipelines``).
PIPELINES = Registry('pipeline')
def _concat_dataset(cfg, default_args=None):
    """Build a :obj:`ConcatDataset` from a config describing several sources.

    ``img_dir``, ``ann_dir`` and ``split`` may each be a single value or a
    list/tuple. One sub-config is produced per source. List/tuple entries are
    picked element-wise, scalar entries are shared by every sub-config, and
    each sub-config is built with :func:`build_dataset`.

    Args:
        cfg (dict): Dataset config containing ``img_dir`` and optionally
            ``ann_dir`` and ``split``.
        default_args (dict, optional): Default arguments forwarded to
            :func:`build_dataset`. Default: None.

    Returns:
        ConcatDataset: Concatenation of the per-source datasets.
    """
    from .dataset_wrappers import ConcatDataset
    seq_types = (list, tuple)

    def _num_entries(value, missing):
        # Sources described by ``value``: its length for a list/tuple,
        # ``missing`` when it is None, otherwise exactly one.
        if value is None:
            return missing
        return len(value) if isinstance(value, seq_types) else 1

    img_dir = cfg['img_dir']
    ann_dir = cfg.get('ann_dir', None)
    split = cfg.get('split', None)
    num_img_dir = _num_entries(img_dir, missing=1)
    num_ann_dir = _num_entries(ann_dir, missing=0)
    num_split = _num_entries(split, missing=0)

    # Per-source lists must agree in length (absent entries are allowed).
    if num_img_dir > 1:
        assert num_img_dir == num_ann_dir or num_ann_dir == 0
        assert num_img_dir == num_split or num_split == 0
    else:
        assert num_split == num_ann_dir or num_ann_dir <= 1

    per_source = (('img_dir', img_dir), ('ann_dir', ann_dir),
                  ('split', split))
    datasets = []
    for idx in range(max(num_split, num_img_dir)):
        sub_cfg = copy.deepcopy(cfg)
        for key, value in per_source:
            if isinstance(value, seq_types):
                sub_cfg[key] = value[idx]
        datasets.append(build_dataset(sub_cfg, default_args))
    return ConcatDataset(datasets)
def build_dataset(cfg, default_args=None):
    """Build a dataset, or a dataset wrapper, from a config.

    Args:
        cfg (dict | list[dict] | tuple[dict]): Dataset config.
            - A list/tuple of configs is built into a ``ConcatDataset``.
            - A config whose ``type`` is ``'RepeatDataset'`` wraps the
              dataset built from its ``dataset`` entry, repeated ``times``.
            - A config whose ``img_dir`` or ``split`` is a list/tuple is
              expanded by ``_concat_dataset``.
            - Any other config is built through the ``DATASETS`` registry.
        default_args (dict, optional): Default arguments passed down to
            every dataset that gets built. Default: None.

    Returns:
        Dataset: The constructed dataset.
    """
    from .dataset_wrappers import ConcatDataset, RepeatDataset
    if isinstance(cfg, (list, tuple)):
        return ConcatDataset(
            [build_dataset(sub_cfg, default_args) for sub_cfg in cfg])
    if cfg['type'] == 'RepeatDataset':
        inner = build_dataset(cfg['dataset'], default_args)
        return RepeatDataset(inner, cfg['times'])
    if (isinstance(cfg.get('img_dir'), (list, tuple))
            or isinstance(cfg.get('split', None), (list, tuple))):
        return _concat_dataset(cfg, default_args)
    return build_from_cfg(cfg, DATASETS, default_args)
def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     drop_last=False,
                     pin_memory=True,
                     dataloader_type='PoolDataLoader',
                     **kwargs):
    """Build a PyTorch DataLoader.

    In distributed mode every process builds its own loader around a
    ``DistributedSampler``. Otherwise a single loader serves all GPUs, so
    its batch size and worker count are scaled by ``num_gpus``.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Batch size of each GPU.
        workers_per_gpu (int): Data-loading subprocesses per GPU.
        num_gpus (int): Number of GPUs; only used when ``dist`` is False.
            Default: 1.
        dist (bool): Whether loading is distributed. Default: True.
        shuffle (bool): Whether to shuffle the data each epoch.
            Default: True.
        seed (int | None): Base seed for the workers' RNGs. Default: None.
        drop_last (bool): Whether to drop the last incomplete batch.
            Default: False.
        pin_memory (bool): Whether the DataLoader pins memory.
            Default: True.
        dataloader_type (str): ``'DataLoader'`` or ``'PoolDataLoader'``.
            Default: 'PoolDataLoader'.
        kwargs: Extra keyword arguments forwarded to the DataLoader.

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    if not dist:
        # One loader feeds every GPU: scale batch size and workers.
        sampler = None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu
    else:
        # The sampler shards and (optionally) shuffles the data, so the
        # loader itself is told not to shuffle.
        sampler = DistributedSampler(
            dataset, world_size, rank, shuffle=shuffle)
        shuffle = False
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu

    init_fn = None
    if seed is not None:
        init_fn = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)

    loader_classes = {
        'DataLoader': DataLoader,
        'PoolDataLoader': PoolDataLoader,
    }
    assert dataloader_type in loader_classes, \
        f'unsupported dataloader {dataloader_type}'
    loader_cls = loader_classes[dataloader_type]

    return loader_cls(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=pin_memory,
        shuffle=shuffle,
        worker_init_fn=init_fn,
        drop_last=drop_last,
        **kwargs)
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed NumPy's and Python's global RNGs inside a dataloader worker.

    The seed is ``num_workers * rank + worker_id + seed``. It is therefore
    distinct for every (rank, worker) pair as long as
    ``worker_id < num_workers``.

    Args:
        worker_id (int): Index of the worker within its dataloader.
        num_workers (int): Number of workers per dataloader.
        rank (int): Rank of the current process.
        seed (int): User-provided base seed.
    """
    worker_seed = num_workers * rank + worker_id + seed
    for reseed in (np.random.seed, random.seed):
        reseed(worker_seed)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/chase_db1.py
================================================
import os.path as osp
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class ChaseDB1Dataset(CustomDataset):
    """CHASE_DB1 dataset.

    Its annotation maps contain two categories, background (label 0) and
    vessel, so ``reduce_zero_label`` is always False. Images end in '.png'
    and segmentation maps end in '_1stHO.png'.
    """

    CLASSES = ('background', 'vessel')

    PALETTE = [[120, 120, 120], [6, 230, 230]]

    def __init__(self, **kwargs):
        super().__init__(
            reduce_zero_label=False,
            img_suffix='.png',
            seg_map_suffix='_1stHO.png',
            **kwargs)
        # Sanity check: the resolved image directory must exist.
        assert osp.exists(self.img_dir)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/cityscapes.py
================================================
import os.path as osp
import tempfile
import mmcv
import numpy as np
from mmcv.utils import print_log
from PIL import Image
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class CityscapesDataset(CustomDataset):
    """Cityscapes dataset.

    The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
    fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
    """

    CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
               'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
               'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
               'bicycle')

    PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
               [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
               [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
               [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
               [0, 80, 100], [0, 0, 230], [119, 11, 32]]

    def __init__(self, **kwargs):
        super(CityscapesDataset, self).__init__(
            img_suffix='_leftImg8bit.png',
            seg_map_suffix='_gtFine_labelTrainIds.png',
            **kwargs)

    @staticmethod
    def _convert_to_label_id(result):
        """Convert trainId to id for cityscapes.

        Args:
            result (np.ndarray | str): Prediction map of trainIds, or a path
                that ``np.load`` can read one from.

        Returns:
            np.ndarray: A copy of the map with each trainId replaced by the
                corresponding Cityscapes label id.
        """
        if isinstance(result, str):
            result = np.load(result)
        import cityscapesscripts.helpers.labels as CSLabels
        result_copy = result.copy()
        # Match against the untouched ``result`` so pixels that were already
        # remapped are never remapped a second time.
        for trainId, label in CSLabels.trainId2label.items():
            result_copy[result == trainId] = label.id

        return result_copy

    def results2img(self, results, imgfile_prefix, to_label_id):
        """Write the segmentation results to images.

        Args:
            results (list[np.ndarray | str]): Testing results of the dataset,
                either prediction maps or paths readable by ``np.load``.
            imgfile_prefix (str): Output directory. One png per dataset image
                is written into it, named after the image, e.g.
                "somepath/xxx.png".
            to_label_id (bool): Whether to convert predictions from trainIds
                to Cityscapes label ids before saving (for submission).

        Returns:
            list[str]: Paths of the written png files, in dataset order.
        """
        import cityscapesscripts.helpers.labels as CSLabels
        mmcv.mkdir_or_exist(imgfile_prefix)

        # The colour palette does not depend on the image: build it once.
        palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
        for label_id, label in CSLabels.id2label.items():
            palette[label_id] = label.color

        result_files = []
        prog_bar = mmcv.ProgressBar(len(self))
        for idx in range(len(self)):
            result = results[idx]
            if to_label_id:
                result = self._convert_to_label_id(result)
            elif isinstance(result, str):
                # Results given as file paths must be loaded before use,
                # just as ``_convert_to_label_id`` does.
                result = np.load(result)
            filename = self.img_infos[idx]['filename']
            basename = osp.splitext(osp.basename(filename))[0]

            png_filename = osp.join(imgfile_prefix, f'{basename}.png')

            output = Image.fromarray(result.astype(np.uint8)).convert('P')
            output.putpalette(palette)
            output.save(png_filename)
            result_files.append(png_filename)
            prog_bar.update()

        return result_files

    def format_results(self, results, imgfile_prefix=None, to_label_id=True):
        """Format the results into dir (standard format for Cityscapes
        evaluation).

        Args:
            results (list): Testing results of the dataset.
            imgfile_prefix (str | None): The prefix of images files. It
                includes the file path and the prefix of filename, e.g.,
                "a/b/prefix". If not specified, a temp dir will be created.
                Default: None.
            to_label_id (bool): Whether to convert output to label_id for
                submission. Default: True.

        Returns:
            tuple: (result_files, tmp_dir). result_files is a list containing
                the image paths. tmp_dir is the ``TemporaryDirectory`` created
                for the png files when ``imgfile_prefix`` is not specified,
                or None otherwise; the caller is responsible for cleaning
                it up.
        """
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            'The length of results is not equal to the dataset len: '
            f'{len(results)} != {len(self)}')

        if imgfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            imgfile_prefix = tmp_dir.name
        else:
            tmp_dir = None
        result_files = self.results2img(results, imgfile_prefix, to_label_id)

        return result_files, tmp_dir

    def evaluate(self,
                 results,
                 metric='mIoU',
                 logger=None,
                 imgfile_prefix=None,
                 efficient_test=False):
        """Evaluation in Cityscapes/default protocol.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. 'cityscapes'
                runs the official Cityscapes protocol; the remaining metrics
                are handled by the parent class.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            imgfile_prefix (str | None): The prefix of output image file,
                for cityscapes evaluation only. It includes the file path and
                the prefix of filename, e.g., "a/b/prefix".
                If results are evaluated with cityscapes protocol, it would be
                the prefix of output png files. The output files would be
                png images under folder "a/b/prefix/xxx.png", where "xxx" is
                the image name of cityscapes. If not specified, a temp dir
                will be created for evaluation.
                Default: None.
            efficient_test (bool): Forwarded to the parent ``evaluate`` for
                the non-cityscapes metrics. Default: False.

        Returns:
            dict[str, float]: Cityscapes/default metrics.
        """
        eval_results = dict()
        metrics = metric.copy() if isinstance(metric, list) else [metric]
        if 'cityscapes' in metrics:
            eval_results.update(
                self._evaluate_cityscapes(results, logger, imgfile_prefix))
            metrics.remove('cityscapes')
        if len(metrics) > 0:
            eval_results.update(
                super(CityscapesDataset,
                      self).evaluate(results, metrics, logger, efficient_test))

        return eval_results

    def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
        """Evaluation in Cityscapes protocol.

        Args:
            results (list): Testing results of the dataset.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            imgfile_prefix (str | None): The prefix of output image file.

        Returns:
            dict[str: float]: Cityscapes evaluation results.

        Raises:
            ImportError: If ``cityscapesscripts`` is not installed.
        """
        try:
            import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa
        except ImportError:
            raise ImportError('Please run "pip install cityscapesscripts" to '
                              'install cityscapesscripts first.')
        msg = 'Evaluating in Cityscapes style'
        if logger is None:
            msg = '\n' + msg
        print_log(msg, logger=logger)

        _, tmp_dir = self.format_results(results, imgfile_prefix)
        # Remove the temporary png directory even if evaluation fails.
        try:
            if tmp_dir is None:
                result_dir = imgfile_prefix
            else:
                result_dir = tmp_dir.name

            eval_results = dict()
            print_log(f'Evaluating results under {result_dir} ...',
                      logger=logger)

            CSEval.args.evalInstLevelScore = True
            CSEval.args.predictionPath = osp.abspath(result_dir)
            CSEval.args.evalPixelAccuracy = True
            CSEval.args.JSONOutput = False

            seg_map_list = []
            pred_list = []

            # when evaluating with official cityscapesscripts,
            # **_gtFine_labelIds.png is used
            for seg_map in mmcv.scandir(
                    self.ann_dir, 'gtFine_labelIds.png', recursive=True):
                seg_map_list.append(osp.join(self.ann_dir, seg_map))
                pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))

            eval_results.update(
                CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
        finally:
            if tmp_dir is not None:
                tmp_dir.cleanup()

        return eval_results
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/coco_stuff.py
================================================
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class COCOStuffDataset(CustomDataset):
    """COCO-Stuff dataset.

    In segmentation map annotation for COCO-Stuff, Train-IDs of the 10k version
    are from 1 to 171, where 0 is the ignore index, and Train-ID of COCO Stuff
    164k is from 0 to 170, where 255 is the ignore index. So, they are all 171
    semantic categories. ``reduce_zero_label`` should be set to True and False
    for the 10k and 164k versions, respectively. This class does not fix it,
    so it must be supplied through the dataset config (otherwise the
    ``CustomDataset`` default, False, applies). The ``img_suffix`` is fixed to
    '.jpg', and ``seg_map_suffix`` is fixed to '_labelTrainIds.png'.
    """
    # 171 category names; the position in the tuple is the class index.
    CLASSES = (
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
        'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
        'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
        'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
        'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
        'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
        'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
        'floor-other', 'floor-stone', 'floor-tile', 'floor-wood',
        'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass',
        'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat',
        'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',
        'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
        'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
        'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
        'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
        'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
        'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
        'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
        'window-blind', 'window-other', 'wood')

    # One RGB colour per class, in CLASSES order. NOTE(review): some colours
    # repeat (e.g. the first two entries), so a few classes are
    # indistinguishable in visualisations.
    PALETTE = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
               [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
               [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
               [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
               [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
               [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
               [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],
               [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],
               [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],
               [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160],
               [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],
               [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],
               [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],
               [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],
               [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
               [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],
               [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128],
               [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128],
               [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224],
               [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0],
               [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128],
               [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224],
               [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128],
               [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192],
               [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224],
               [0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0],
               [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192],
               [0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224],
               [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128],
               [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128],
               [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160],
               [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64],
               [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128],
               [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160],
               [0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192],
               [192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192],
               [0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160],
               [64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64],
               [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],
               [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
               [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
               [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
               [64, 192, 96], [64, 160, 64], [64, 64, 0]]

    def __init__(self, **kwargs):
        # Only the file suffixes are fixed here; ``reduce_zero_label`` and all
        # other options come from ``kwargs`` (the dataset config).
        super(COCOStuffDataset, self).__init__(
            img_suffix='.jpg', seg_map_suffix='_labelTrainIds.png', **kwargs)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/custom.py
================================================
import os
import os.path as osp
from functools import reduce
import mmcv
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable
from torch.utils.data import Dataset
from mmseg.core import eval_metrics
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose
@DATASETS.register_module()
class CustomDataset(Dataset):
    """Custom dataset for semantic segmentation. An example of file structure
    is as follows.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The img/gt_semantic_seg pair of CustomDataset should share the same name
    except for the suffix. A valid img/gt_semantic_seg filename pair should be
    like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also
    included in the suffix). If split is given, then ``xxx`` is specified in
    the txt file. Otherwise, every file ending in ``img_suffix`` found
    (recursively) under ``img_dir`` will be loaded.
    Please refer to ``docs/tutorials/new_dataset.md`` for more details.

    Args:
        pipeline (list[dict]): Processing pipeline
        img_dir (str): Path to image directory
        img_suffix (str): Suffix of images. Default: '.jpg'
        ann_dir (str, optional): Path to annotation directory. Default: None
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        split (str, optional): Split txt file. If split is specified, only
            file with suffix in the splits will be loaded. Otherwise, all
            images in img_dir/ann_dir will be loaded. Default: None
        data_root (str, optional): Data root for img_dir/ann_dir. Default:
            None.
        test_mode (bool): If test_mode=True, gt wouldn't be loaded.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default: False
        classes (str | Sequence[str], optional): Specify classes to load.
            If is None, ``cls.CLASSES`` will be used. Default: None.
        palette (Sequence[Sequence[int]] | np.ndarray | None):
            The palette of segmentation map. If None is given, and
            self.PALETTE is None, random palette will be generated.
            Default: None
    """

    CLASSES = None

    PALETTE = None

    def __init__(self,
                 pipeline,
                 img_dir,
                 img_suffix='.jpg',
                 ann_dir=None,
                 seg_map_suffix='.png',
                 split=None,
                 data_root=None,
                 test_mode=False,
                 ignore_index=255,
                 reduce_zero_label=False,
                 classes=None,
                 palette=None):
        self.pipeline = Compose(pipeline)
        self.img_dir = img_dir
        self.img_suffix = img_suffix
        self.ann_dir = ann_dir
        self.seg_map_suffix = seg_map_suffix
        self.split = split
        self.data_root = data_root
        self.test_mode = test_mode
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        # Must exist before get_classes_and_palette, which may fill it in.
        self.label_map = None
        self.CLASSES, self.PALETTE = self.get_classes_and_palette(
            classes, palette)

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.img_dir):
                self.img_dir = osp.join(self.data_root, self.img_dir)
            if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
                self.ann_dir = osp.join(self.data_root, self.ann_dir)
            if not (self.split is None or osp.isabs(self.split)):
                self.split = osp.join(self.data_root, self.split)

        # load annotations
        self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
                                               self.ann_dir,
                                               self.seg_map_suffix, self.split)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.img_infos)

    def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
                         split):
        """Load annotation from directory.

        Args:
            img_dir (str): Path to image directory
            img_suffix (str): Suffix of images.
            ann_dir (str|None): Path to annotation directory.
            seg_map_suffix (str|None): Suffix of segmentation maps.
            split (str|None): Split txt file. If split is specified, only file
                with suffix in the splits will be loaded. Otherwise, all images
                in img_dir/ann_dir will be loaded. Default: None

        Returns:
            list[dict]: All image info of dataset. Each dict holds the image
                ``filename`` and, when ``ann_dir`` is given, an ``ann`` dict
                with the ``seg_map`` filename.
        """

        img_infos = []
        if split is not None:
            with open(split) as f:
                for line in f:
                    img_name = line.strip()
                    if not img_name:
                        # A blank line would otherwise become an entry whose
                        # filename is just the suffix.
                        continue
                    img_info = dict(filename=img_name + img_suffix)
                    if ann_dir is not None:
                        seg_map = img_name + seg_map_suffix
                        img_info['ann'] = dict(seg_map=seg_map)
                    img_infos.append(img_info)
        else:
            for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
                img_info = dict(filename=img)
                if ann_dir is not None:
                    # scandir filters by img_suffix, so swap only that
                    # trailing suffix. str.replace would also rewrite earlier
                    # occurrences (e.g. inside a directory name) and, for an
                    # empty suffix, insert seg_map_suffix between characters.
                    stem = img[:len(img) - len(img_suffix)]
                    img_info['ann'] = dict(seg_map=stem + seg_map_suffix)
                img_infos.append(img_info)

        print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
        return img_infos

    def get_ann_info(self, idx):
        """Get annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """

        return self.img_infos[idx]['ann']

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline."""
        results['seg_fields'] = []
        results['img_prefix'] = self.img_dir
        results['seg_prefix'] = self.ann_dir
        if self.custom_classes:
            results['label_map'] = self.label_map

    def __getitem__(self, idx):
        """Get training/test data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training/test data (with annotation if `test_mode` is set
                False).
        """

        if self.test_mode:
            return self.prepare_test_img(idx)
        else:
            return self.prepare_train_img(idx)

    def prepare_train_img(self, idx):
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """

        img_info = self.img_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        """Get testing data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Testing data after pipeline with new keys introduced by
                pipeline.
        """

        img_info = self.img_infos[idx]
        results = dict(img_info=img_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def format_results(self, results, **kwargs):
        """Place holder to format result to dataset specific output."""
        pass

    def get_gt_seg_maps(self, efficient_test=False):
        """Get ground truth segmentation maps for evaluation.

        Args:
            efficient_test (bool): If True, return the seg-map file paths
                instead of loading the maps. Default: False.

        Returns:
            list[np.ndarray | str]: One entry per image, in dataset order.
        """
        gt_seg_maps = []
        for img_info in self.img_infos:
            seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
            if efficient_test:
                gt_seg_map = seg_map
            else:
                gt_seg_map = mmcv.imread(
                    seg_map, flag='unchanged', backend='pillow')
            gt_seg_maps.append(gt_seg_map)
        return gt_seg_maps

    def get_classes_and_palette(self, classes=None, palette=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.
            palette (Sequence[Sequence[int]] | np.ndarray | None):
                The palette of segmentation map. If None is given, random
                palette will be generated. Default: None

        Returns:
            tuple: (class_names, palette) to use for this dataset.

        Raises:
            ValueError: If ``classes`` has an unsupported type, or names a
                class that is not in ``self.CLASSES``.
        """
        if classes is None:
            self.custom_classes = False
            return self.CLASSES, self.PALETTE

        self.custom_classes = True
        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        if self.CLASSES:
            # Always work with the parsed ``class_names``: when ``classes`` is
            # a file path, ``set(classes)`` is a set of characters and
            # ``classes.index`` is a substring search.
            if not set(class_names).issubset(self.CLASSES):
                raise ValueError('classes is not a subset of CLASSES.')

            # dictionary, its keys are the old label ids and its values
            # are the new label ids (-1 for dropped classes). It is handed to
            # the pipeline in pre_pipeline and to eval_metrics in evaluate.
            self.label_map = {}
            for i, c in enumerate(self.CLASSES):
                if c not in class_names:
                    self.label_map[i] = -1
                else:
                    self.label_map[i] = class_names.index(c)

        palette = self.get_palette_for_custom_classes(class_names, palette)

        return class_names, palette

    def get_palette_for_custom_classes(self, class_names, palette=None):
        """Select the palette matching a custom class list.

        With a label map, the colours of the kept classes are taken from
        ``self.PALETTE`` in new-label order (``palette`` is ignored).
        Otherwise ``palette`` is used if given, then ``self.PALETTE``, and
        finally a random palette with one colour per class.
        """

        if self.label_map is not None:
            # return subset of palette
            palette = []
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != -1:
                    palette.append(self.PALETTE[old_id])
            palette = type(self.PALETTE)(palette)

        elif palette is None:
            if self.PALETTE is None:
                palette = np.random.randint(0, 255, size=(len(class_names), 3))
            else:
                palette = self.PALETTE

        return palette

    def evaluate(self,
                 results,
                 metric='mIoU',
                 logger=None,
                 efficient_test=False,
                 **kwargs):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. 'mIoU' and
                'mDice' are supported.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            efficient_test (bool): Load ground truth lazily by path (see
                ``get_gt_seg_maps``). Default: False.

        Returns:
            dict[str, float]: Default metrics.

        Raises:
            KeyError: If an unsupported metric is requested.
        """

        if isinstance(metric, str):
            metric = [metric]
        allowed_metrics = ['mIoU', 'mDice']
        if not set(metric).issubset(set(allowed_metrics)):
            raise KeyError('metric {} is not supported'.format(metric))
        eval_results = {}
        gt_seg_maps = self.get_gt_seg_maps(efficient_test)
        if self.CLASSES is None:
            num_classes = len(
                reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
        else:
            num_classes = len(self.CLASSES)
        # Layout used below: [0] overall accuracy (aAcc), [1] per-class
        # accuracy, [2:] one per-class array per requested metric.
        ret_metrics = eval_metrics(
            results,
            gt_seg_maps,
            num_classes,
            self.ignore_index,
            metric,
            label_map=self.label_map,
            reduce_zero_label=self.reduce_zero_label)
        class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]
        if self.CLASSES is None:
            class_names = tuple(range(num_classes))
        else:
            class_names = self.CLASSES
        ret_metrics_round = [
            np.round(ret_metric * 100, 2) for ret_metric in ret_metrics
        ]
        for i in range(num_classes):
            class_table_data.append([class_names[i]] +
                                    [m[i] for m in ret_metrics_round[2:]] +
                                    [ret_metrics_round[1][i]])
        summary_table_data = [['Scope'] +
                              ['m' + head
                               for head in class_table_data[0][1:]] + ['aAcc']]
        ret_metrics_mean = [
            np.round(np.nanmean(ret_metric) * 100, 2)
            for ret_metric in ret_metrics
        ]
        summary_table_data.append(['global'] + ret_metrics_mean[2:] +
                                  [ret_metrics_mean[1]] +
                                  [ret_metrics_mean[0]])
        print_log('per class results:', logger)
        table = AsciiTable(class_table_data)
        print_log('\n' + table.table, logger=logger)
        print_log('Summary:', logger)
        table = AsciiTable(summary_table_data)
        print_log('\n' + table.table, logger=logger)

        for i in range(1, len(summary_table_data[0])):
            eval_results[summary_table_data[0]
                         [i]] = summary_table_data[1][i] / 100.0
        # Results given as file paths are deleted once evaluated.
        if mmcv.is_list_of(results, str):
            for file_name in results:
                os.remove(file_name)
        return eval_results
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/dataset_wrappers.py
================================================
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from .builder import DATASETS
@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
    """Concatenation of several datasets.

    Behaves like :obj:`torch.utils.data.dataset.ConcatDataset`. In addition,
    it exposes the ``CLASSES`` and ``PALETTE`` of the first dataset.

    Args:
        datasets (list[:obj:`Dataset`]): The datasets to concatenate.
    """

    def __init__(self, datasets):
        super().__init__(datasets)
        first = datasets[0]
        self.CLASSES = first.CLASSES
        self.PALETTE = first.PALETTE
@DATASETS.register_module()
class RepeatDataset(object):
    """Present a dataset as if it were repeated ``times`` times.

    The wrapper's length is ``times`` times that of the wrapped dataset, and
    indices wrap around modulo the original length. Repeating a small dataset
    whose loading is slow reduces the per-epoch restart overhead.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.CLASSES = dataset.CLASSES
        self.PALETTE = dataset.PALETTE
        # Cached once; the wrapped dataset is assumed not to change size.
        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        """Return the wrapped sample at ``idx`` modulo the original length."""
        wrapped_idx = idx % self._ori_len
        return self.dataset[wrapped_idx]

    def __len__(self):
        """Original length multiplied by ``times``."""
        ori_len = self._ori_len
        return self.times * ori_len
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/drive.py
================================================
import os.path as osp
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class DRIVEDataset(CustomDataset):
    """DRIVE dataset.

    Annotation maps use two categories, background (label 0) and vessel, so
    ``reduce_zero_label`` is always False. Images end in '.png' and
    segmentation maps end in '_manual1.png'.
    """

    CLASSES = ('background', 'vessel')

    PALETTE = [[120, 120, 120], [6, 230, 230]]

    def __init__(self, **kwargs):
        super().__init__(
            seg_map_suffix='_manual1.png',
            img_suffix='.png',
            reduce_zero_label=False,
            **kwargs)
        # Sanity check: the resolved image directory must exist.
        assert osp.exists(self.img_dir)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/hrf.py
================================================
import os.path as osp
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class HRFDataset(CustomDataset):
    """HRF dataset.

    Annotation maps use two categories, background (label 0) and vessel, so
    ``reduce_zero_label`` is always False. Both images and segmentation maps
    end in '.png'.
    """

    CLASSES = ('background', 'vessel')

    PALETTE = [[120, 120, 120], [6, 230, 230]]

    def __init__(self, **kwargs):
        png = '.png'
        super().__init__(
            img_suffix=png,
            seg_map_suffix=png,
            reduce_zero_label=False,
            **kwargs)
        # Sanity check: the resolved image directory must exist.
        assert osp.exists(self.img_dir)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pascal_context.py
================================================
import os.path as osp
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class PascalContextDataset(CustomDataset):
    """PASCAL-Context segmentation dataset with 60 categories.

    Label 0 is background and is counted among the 60 categories, so
    ``reduce_zero_label`` is always False. Images are matched with the
    ``'.jpg'`` suffix and annotation maps with the ``'.png'`` suffix.

    Args:
        split (str): Split txt file for PascalContext. Must not be None.
        **kwargs: Forwarded to :class:`CustomDataset`.
    """

    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
               'bus', 'car', 'cat', 'chair', 'cow', 'table', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor', 'bag', 'bed', 'bench', 'book', 'building',
               'cabinet', 'ceiling', 'cloth', 'computer', 'cup', 'door',
               'fence', 'floor', 'flower', 'food', 'grass', 'ground',
               'keyboard', 'light', 'mountain', 'mouse', 'curtain', 'platform',
               'sign', 'plate', 'road', 'rock', 'shelves', 'sidewalk', 'sky',
               'snow', 'bedclothes', 'track', 'tree', 'truck', 'wall', 'water',
               'window', 'wood')

    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
               [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
               [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
               [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
               [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
               [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
               [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
               [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
               [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
               [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
               [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
               [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
               [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
               [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
               [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]

    def __init__(self, split, **kwargs):
        super().__init__(
            img_suffix='.jpg',
            seg_map_suffix='.png',
            split=split,
            reduce_zero_label=False,
            **kwargs)
        # Both an existing image directory and a split file are mandatory.
        assert osp.exists(self.img_dir)
        assert self.split is not None
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/__init__.py
================================================
from .compose import Compose
from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
Transpose, to_tensor)
from .loading import LoadAnnotations, LoadImageFromFile
from .test_time_aug import MultiScaleFlipAug
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
PhotoMetricDistortion, RandomCrop, RandomFlip,
RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
# Public names re-exported by the pipelines package. Importing the modules
# above also registers their transforms in PIPELINES; e.g. DefaultFormatBundle
# (formating.py) is registered that way but is not listed here.
__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
]
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/compose.py
================================================
import collections
from mmcv.utils import build_from_cfg
from ..builder import PIPELINES
@PIPELINES.register_module()
class Compose(object):
    """Chain several pipeline transforms and run them one after another.

    Args:
        transforms (Sequence[dict | callable]): Transforms to chain. A dict
            is treated as a config and built through the ``PIPELINES``
            registry; a callable is used as-is.

    Raises:
        TypeError: If an element is neither a dict nor callable.
    """

    def __init__(self, transforms):
        assert isinstance(transforms, collections.abc.Sequence)
        self.transforms = []
        for item in transforms:
            if isinstance(item, dict):
                step = build_from_cfg(item, PIPELINES)
            elif callable(item):
                step = item
            else:
                raise TypeError('transform must be callable or a dict')
            self.transforms.append(step)

    def __call__(self, data):
        """Feed ``data`` through every transform in order.

        Args:
            data (dict): A result dict holding the data to transform.

        Returns:
            dict | None: The final result, or None as soon as any transform
            returns None (the remaining transforms are then skipped).
        """
        for step in self.transforms:
            data = step(data)
            if data is None:
                return None
        return data

    def __repr__(self):
        # One line per transform, each prefixed by a newline and a space.
        body = ''.join(f'\n {step}' for step in self.transforms)
        return f'{self.__class__.__name__}({body}\n)'
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/formating.py
================================================
from collections.abc import Sequence
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from ..builder import PIPELINES
def to_tensor(data):
    """Convert ``data`` to a :obj:`torch.Tensor`.

    Supported inputs, checked in this order:

    * :class:`torch.Tensor`: returned unchanged (the same object).
    * :class:`numpy.ndarray`: wrapped with :func:`torch.from_numpy`, so the
      tensor shares memory with the array.
    * a non-string :class:`Sequence`: built with :func:`torch.tensor`.
    * :class:`int`: a one-element ``LongTensor``.
    * :class:`float`: a one-element ``FloatTensor``.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Value
            to convert.

    Returns:
        torch.Tensor: The converted tensor.

    Raises:
        TypeError: If ``data`` is of any other type.
    """
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    if isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    if isinstance(data, int):
        return torch.LongTensor([data])
    if isinstance(data, float):
        return torch.FloatTensor([data])
    raise TypeError(f'type {type(data)} cannot be converted to tensor.')
@PIPELINES.register_module()
class ToTensor(object):
    """Turn selected entries of the results dict into tensors.

    Args:
        keys (Sequence[str]): Names of the entries to convert with
            :func:`to_tensor`.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Replace each ``results[key]`` with ``to_tensor(results[key])``.

        Args:
            results (dict): Pipeline results; modified in place.

        Returns:
            dict: The same ``results`` object.
        """
        for name in self.keys:
            results[name] = to_tensor(results[name])
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'
@PIPELINES.register_module()
class ImageToTensor(object):
    """Convert images to tensors in channel-first layout.

    Each selected entry is read as (H, W, C) and emitted as a (C, H, W)
    tensor; a 2-D (H, W) entry becomes (1, H, W).

    Args:
        keys (Sequence[str]): Names of the image entries to convert.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Transpose each selected image to (C, H, W) and make it a tensor.

        Args:
            results (dict): Pipeline results; modified in place.

        Returns:
            dict: The same ``results`` object with the images replaced by
            :obj:`torch.Tensor` values of shape (C, H, W).
        """
        for name in self.keys:
            image = results[name]
            if len(image.shape) < 3:
                # Promote (H, W) to (H, W, 1) so the 3-axis transpose works.
                image = np.expand_dims(image, -1)
            results[name] = to_tensor(image.transpose(2, 0, 1))
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'
@PIPELINES.register_module()
class Transpose(object):
    """Permute the axes of selected entries in the results dict.

    Args:
        keys (Sequence[str]): Names of the entries to transpose.
        order (Sequence[int]): Axis permutation passed to each entry's own
            ``transpose`` method.
    """

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def __call__(self, results):
        """Call ``transpose(self.order)`` on every entry named in ``keys``.

        No conversion to :obj:`torch.Tensor` happens here; the result is
        whatever the entry's ``transpose`` returns.

        Args:
            results (dict): Pipeline results; modified in place.

        Returns:
            dict: The same ``results`` object with the entries transposed.
        """
        for key in self.keys:
            results[key] = results[key].transpose(self.order)
        return results

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(keys={self.keys}, order={self.order})'
@PIPELINES.register_module()
class ToDataContainer(object):
    """Wrap selected results entries in :obj:`mmcv.DataContainer`.

    Args:
        fields (Sequence[dict]): One spec per entry, shaped like
            ``dict(key='xxx', **kwargs)``. ``results[key]`` is replaced by
            ``DataContainer(results[key], **kwargs)``.
            Default: ``(dict(key='img', stack=True),
            dict(key='gt_semantic_seg'))``.
    """

    def __init__(self,
                 fields=(dict(key='img', stack=True),
                         dict(key='gt_semantic_seg'))):
        self.fields = fields

    def __call__(self, results):
        """Wrap every entry described in ``self.fields``.

        Args:
            results (dict): Pipeline results; modified in place.

        Returns:
            dict: The same ``results`` object with the entries wrapped.
        """
        for spec in self.fields:
            # Work on a copy so popping 'key' never mutates the stored spec.
            options = spec.copy()
            name = options.pop('key')
            results[name] = DC(results[name], **options)
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(fields={self.fields})'
@PIPELINES.register_module()
class DefaultFormatBundle(object):
    """Format the common ``img`` and ``gt_semantic_seg`` entries.

    - ``img``: give a 2-D image a trailing channel axis, transpose
      (H, W, C) to (C, H, W), make it contiguous, convert it to a tensor and
      wrap it in a ``DataContainer`` with ``stack=True``.
    - ``gt_semantic_seg``: add a leading axis, cast to ``int64``, convert it
      to a tensor and wrap it in a ``DataContainer`` with ``stack=True``.

    Entries that are absent from ``results`` are left alone.
    """

    def __call__(self, results):
        """Format whichever of ``img`` / ``gt_semantic_seg`` is present.

        Args:
            results (dict): Pipeline results; modified in place.

        Returns:
            dict: The same ``results`` object with the entries formatted.
        """
        if 'img' in results:
            image = results['img']
            if len(image.shape) < 3:
                image = np.expand_dims(image, -1)
            chw = np.ascontiguousarray(image.transpose(2, 0, 1))
            results['img'] = DC(to_tensor(chw), stack=True)
        if 'gt_semantic_seg' in results:
            # Leading axis plus a cast to int64 (long) labels.
            seg = results['gt_semantic_seg'][None, ...].astype(np.int64)
            results['gt_semantic_seg'] = DC(to_tensor(seg), stack=True)
        return results

    def __repr__(self):
        return self.__class__.__name__
@PIPELINES.register_module()
class Collect(object):
    """Pick the entries the model consumes and bundle the image meta info.

    This is typically the last step of a pipeline, with ``keys`` a subset of
    ``"img"`` and ``"gt_semantic_seg"``. The returned dict always contains an
    ``"img_metas"`` entry: a ``DataContainer`` (``cpu_only=True``) wrapping
    ``{k: results[k] for k in meta_keys}``. The default ``meta_keys`` are:

    - ``"filename"``: path of the image file, including any prefix
    - ``"ori_filename"``: image path as given in ``img_info``
    - ``"ori_shape"``: shape of the image as loaded, e.g. ``(h, w, c)``
    - ``"img_shape"``: shape of the image fed to the network, e.g.
      ``(h, w, c)``. The batch tensor may be zero padded on the
      bottom/right beyond this shape.
    - ``"pad_shape"``: image shape after padding
    - ``"scale_factor"``: preprocessing scale (``1.0`` right after loading;
      a ``[w, h, w, h]`` float32 array after :class:`Resize`)
    - ``"flip"`` / ``"flip_direction"``: whether and how the image was
      flipped
    - ``"img_norm_cfg"``: dict with the ``mean``, ``std`` and ``to_rgb``
      used for normalization

    Args:
        keys (Sequence[str]): Entries of ``results`` copied into the output.
        meta_keys (Sequence[str], optional): Entries gathered into
            ``img_metas``. Default: ``('filename', 'ori_filename',
            'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip',
            'flip_direction', 'img_norm_cfg')``.
    """

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'ori_filename', 'ori_shape',
                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
                            'flip_direction', 'img_norm_cfg')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Build the output dict from ``results``.

        Args:
            results (dict): Pipeline results to collect from.

        Returns:
            dict: ``img_metas`` followed by every entry named in
            ``self.keys``.

        Raises:
            KeyError: If a requested key or meta key is missing.
        """
        meta = {name: results[name] for name in self.meta_keys}
        collected = {'img_metas': DC(meta, cpu_only=True)}
        for name in self.keys:
            collected[name] = results[name]
        return collected

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(keys={self.keys}, meta_keys={self.meta_keys})')
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/loading.py
================================================
import os.path as osp
import mmcv
import numpy as np
from ..builder import PIPELINES
@PIPELINES.register_module()
class LoadImageFromFile(object):
    """Load an image through an mmcv FileClient.

    The file path is ``results['img_info']['filename']``, joined onto
    ``results['img_prefix']`` when that prefix is not None. This step sets
    the following entries:

    - ``filename`` and ``ori_filename``
    - ``img``
    - ``img_shape``, ``ori_shape`` and ``pad_shape``, all equal to the
      decoded image's shape
    - ``scale_factor`` (1.0)
    - a neutral ``img_norm_cfg`` (zero mean, unit std, ``to_rgb=False``)

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If False, the image keeps the dtype produced by
            :func:`mmcv.imfrombytes`. Defaults to False.
        color_type (str): The ``flag`` argument for
            :func:`mmcv.imfrombytes`. Defaults to 'color'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
        imdecode_backend (str): Backend for :func:`mmcv.imdecode`.
            Default: 'cv2'.
    """

    def __init__(self,
                 to_float32=False,
                 color_type='color',
                 file_client_args=dict(backend='disk'),
                 imdecode_backend='cv2'):
        self.to_float32 = to_float32
        self.color_type = color_type
        # Copied so the shared default dict is never aliased.
        self.file_client_args = file_client_args.copy()
        # Created on first use in __call__.
        self.file_client = None
        self.imdecode_backend = imdecode_backend

    def __call__(self, results):
        """Read and decode the image, then record its meta information.

        Args:
            results (dict): Result dict from :obj:`mmseg.CustomDataset`.

        Returns:
            dict: ``results`` updated with the image and meta entries.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        rel_path = results['img_info']['filename']
        prefix = results.get('img_prefix')
        filename = rel_path if prefix is None else osp.join(prefix, rel_path)

        raw = self.file_client.get(filename)
        img = mmcv.imfrombytes(
            raw, flag=self.color_type, backend=self.imdecode_backend)
        if self.to_float32:
            img = img.astype(np.float32)

        shape = img.shape
        results['filename'] = filename
        results['ori_filename'] = rel_path
        results['img'] = img
        results['img_shape'] = shape
        results['ori_shape'] = shape
        # Defaults for meta keys that later transforms may overwrite.
        results['pad_shape'] = shape
        results['scale_factor'] = 1.0
        channels = img.shape[2] if len(shape) >= 3 else 1
        results['img_norm_cfg'] = dict(
            mean=np.zeros(channels, dtype=np.float32),
            std=np.ones(channels, dtype=np.float32),
            to_rgb=False)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}(to_float32={self.to_float32},'
                f"color_type='{self.color_type}',"
                f"imdecode_backend='{self.imdecode_backend}')")
@PIPELINES.register_module()
class LoadAnnotations(object):
    """Load the semantic segmentation map for one sample.

    Args:
        reduce_zero_label (bool): Whether to reduce all label values by 1.
            Usually used for datasets where 0 is the background label.
            Label 0 becomes 255, every other label is decremented by one,
            and 255 stays 255. Default: False.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
        imdecode_backend (str): Backend for :func:`mmcv.imdecode`.
            Default: 'pillow'.
    """

    def __init__(self,
                 reduce_zero_label=False,
                 file_client_args=dict(backend='disk'),
                 imdecode_backend='pillow'):
        self.reduce_zero_label = reduce_zero_label
        # Copied so the shared default dict is never aliased.
        self.file_client_args = file_client_args.copy()
        # Created on first use in __call__.
        self.file_client = None
        self.imdecode_backend = imdecode_backend

    def __call__(self, results):
        """Read the map named by ``results['ann_info']['seg_map']``.

        The map path is joined onto ``results['seg_prefix']`` when that
        prefix is not None. The map is decoded unchanged, squeezed and cast
        to uint8. If ``results['label_map']`` is set, it is applied as an
        ``{old_id: new_id}`` remapping.

        Args:
            results (dict): Result dict from :obj:`mmseg.CustomDataset`.
                It must contain a ``seg_fields`` list.

        Returns:
            dict: ``results`` with ``gt_semantic_seg`` set and its name
            appended to ``seg_fields``.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)
        if results.get('seg_prefix', None) is not None:
            filename = osp.join(results['seg_prefix'],
                                results['ann_info']['seg_map'])
        else:
            filename = results['ann_info']['seg_map']
        img_bytes = self.file_client.get(filename)
        gt_semantic_seg = mmcv.imfrombytes(
            img_bytes, flag='unchanged',
            backend=self.imdecode_backend).squeeze().astype(np.uint8)
        # Remap custom class ids. Matching against an untouched copy makes
        # the remapping simultaneous. Otherwise a pixel rewritten by one
        # entry (e.g. 1 -> 2) would be rewritten again by a later entry
        # (e.g. 2 -> 3).
        label_map = results.get('label_map', None)
        if label_map is not None:
            original_ids = gt_semantic_seg.copy()
            for old_id, new_id in label_map.items():
                gt_semantic_seg[original_ids == old_id] = new_id
        if self.reduce_zero_label:
            # Avoid uint8 underflow: park 0 at 255, shift everything down
            # by one, then send the parked value (now 254) back to 255.
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        results['gt_semantic_seg'] = gt_semantic_seg
        results['seg_fields'].append('gt_semantic_seg')
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label},'
        repr_str += f"imdecode_backend='{self.imdecode_backend}')"
        return repr_str
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/test_time_aug.py
================================================
import warnings
import mmcv
from ..builder import PIPELINES
from .compose import Compose
@PIPELINES.register_module()
class MultiScaleFlipAug(object):
    """Test-time augmentation with multiple scales and flipping.

    An example configuration is as follows:

    .. code-block::

        img_scale=(2048, 1024),
        img_ratios=[0.5, 1.0],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]

    With the above configuration, MultiScaleFlipAug wraps the results into
    lists of the same length, as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...],
            scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)],
            flip=[False, True, False, True],
            ...
        )

    Args:
        transforms (list[dict | callable]): Transforms to apply in each
            augmentation. They are composed with :class:`Compose`.
        img_scale (None | tuple | list[tuple]): Image scales for resizing.
            If None, ``img_ratios`` is applied to each input image's own
            size, taken as ``(w, h)``.
        img_ratios (float | list[float]): Image ratios for resizing. They
            are combined with a single tuple ``img_scale``, or with
            ``img_scale=None``.
        flip (bool): Whether to apply flip augmentation. Default: False.
        flip_direction (str | list[str]): Flip augmentation directions.
            Options are "horizontal" and "vertical". If flip_direction is a
            list, multiple flip augmentations are applied. It has no effect
            when flip is False. Default: "horizontal".
    """

    def __init__(self,
                 transforms,
                 img_scale,
                 img_ratios=None,
                 flip=False,
                 flip_direction='horizontal'):
        self.transforms = Compose(transforms)
        if img_ratios is not None:
            img_ratios = img_ratios if isinstance(img_ratios,
                                                  list) else [img_ratios]
            assert mmcv.is_list_of(img_ratios, float)
        if img_scale is None:
            # mode 1: given img_scale=None and a range of image ratio
            self.img_scale = None
            assert mmcv.is_list_of(img_ratios, float)
        elif isinstance(img_scale, tuple) and mmcv.is_list_of(
                img_ratios, float):
            assert len(img_scale) == 2
            # mode 2: given a scale and a range of image ratio
            self.img_scale = [(int(img_scale[0] * ratio),
                               int(img_scale[1] * ratio))
                              for ratio in img_ratios]
        else:
            # mode 3: given multiple scales
            self.img_scale = img_scale if isinstance(img_scale,
                                                     list) else [img_scale]
        assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None
        self.flip = flip
        self.img_ratios = img_ratios
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if (self.flip and not any(
                self._is_random_flip(t) for t in transforms)):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    @staticmethod
    def _is_random_flip(transform):
        """Tell whether a transform (config dict or callable) is RandomFlip."""
        # Compose accepts plain callables too; indexing them with ['type']
        # would raise, so fall back to the class name for non-dicts.
        if isinstance(transform, dict):
            return transform.get('type') == 'RandomFlip'
        return type(transform).__name__ == 'RandomFlip'

    def __call__(self, results):
        """Apply the test-time augmentation transforms to ``results``.

        Every (scale, flip, direction) combination runs the composed
        transforms on a shallow copy of ``results``.

        Args:
            results (dict): Result dict holding the data to transform.

        Returns:
            dict[str, list]: The augmented data, where each value is wrapped
            into a list with one item per augmentation.
        """
        aug_data = []
        if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
            h, w = results['img'].shape[:2]
            img_scale = [(int(w * ratio), int(h * ratio))
                         for ratio in self.img_ratios]
        else:
            img_scale = self.img_scale
        flip_aug = [False, True] if self.flip else [False]
        for scale in img_scale:
            for flip in flip_aug:
                for direction in self.flip_direction:
                    _results = results.copy()
                    _results['scale'] = scale
                    _results['flip'] = flip
                    _results['flip_direction'] = direction
                    data = self.transforms(_results)
                    aug_data.append(data)
        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        # Keep flip_direction inside the parentheses (it used to be
        # appended after the closing bracket).
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        repr_str += f'flip_direction={self.flip_direction})'
        return repr_str
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/transforms.py
================================================
import mmcv
import numpy as np
from mmcv.utils import deprecated_api_warning, is_tuple_of
from numpy import random
from ..builder import PIPELINES
@PIPELINES.register_module()
class Resize(object):
    """Resize the image and every segmentation map listed in ``seg_fields``.

    If ``results`` already holds a ``'scale'`` entry, that scale is used.
    Otherwise one is chosen by :meth:`_random_scale`. ``img_scale`` may be
    None, a single tuple, or a list of tuples. The scale is chosen in one of
    four ways:

    - ``ratio_range`` is not None:

      1. ``img_scale`` is None: the image's own size, taken as ``(w, h)``,
         is multiplied by a ratio sampled from ``ratio_range``. (mode 1)
      2. ``img_scale`` holds one tuple: that tuple is multiplied by a ratio
         sampled from ``ratio_range``. (mode 2)

    - ``ratio_range`` is None and ``multiscale_mode == "range"``: the long
      and short edges are each sampled from the interval spanned by two
      given scales. (mode 3)
    - ``ratio_range`` is None and ``multiscale_mode == "value"``: one of the
      given scales is picked at random. (mode 4)

    Without ``ratio_range``, a single-tuple ``img_scale`` is used as is.
    With ``img_scale=None`` and no ``ratio_range``, the input must already
    carry ``'scale'`` (as set by :class:`MultiScaleFlipAug`).

    Args:
        img_scale (tuple | list[tuple], optional): Image scale(s) for
            resizing. Default: None.
        multiscale_mode (str): Either "range" or "value". Default: 'range'.
        ratio_range (tuple[float], optional): ``(min_ratio, max_ratio)``.
            Default: None.
        keep_ratio (bool): Whether to keep the aspect ratio when resizing
            the image. Default: True.
    """

    def __init__(self,
                 img_scale=None,
                 multiscale_mode='range',
                 ratio_range=None,
                 keep_ratio=True):
        if img_scale is None:
            self.img_scale = None
        else:
            self.img_scale = (img_scale
                              if isinstance(img_scale, list) else [img_scale])
            assert mmcv.is_list_of(self.img_scale, tuple)

        if ratio_range is not None:
            # modes 1 and 2 accept at most one base scale
            assert self.img_scale is None or len(self.img_scale) == 1
        else:
            # modes 3 and 4 sample among / between the given scales
            assert multiscale_mode in ['value', 'range']

        self.multiscale_mode = multiscale_mode
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio

    @staticmethod
    def random_select(img_scales):
        """Pick one candidate scale uniformly at random.

        Args:
            img_scales (list[tuple]): Candidate image scales.

        Returns:
            (tuple, int): ``(img_scale, scale_idx)``, the chosen scale and
            its index in ``img_scales``.
        """
        assert mmcv.is_list_of(img_scales, tuple)
        idx = np.random.randint(len(img_scales))
        return img_scales[idx], idx

    @staticmethod
    def random_sample(img_scales):
        """Sample a scale between two bounds (``multiscale_mode='range'``).

        The long edge is drawn from the range of the two scales' long
        edges, and the short edge from the range of their short edges (both
        inclusive).

        Args:
            img_scales (list[tuple]): Exactly two scales giving the lower
                and upper bounds.

        Returns:
            (tuple, None): ``((long_edge, short_edge), None)``. The None is
            a placeholder matching :func:`random_select`.
        """
        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
        long_edges = [max(s) for s in img_scales]
        short_edges = [min(s) for s in img_scales]
        # The long edge is drawn before the short one; keep this order so
        # seeded runs stay reproducible.
        long_edge = np.random.randint(min(long_edges), max(long_edges) + 1)
        short_edge = np.random.randint(min(short_edges), max(short_edges) + 1)
        return (long_edge, short_edge), None

    @staticmethod
    def random_sample_ratio(img_scale, ratio_range):
        """Scale ``img_scale`` by a ratio drawn from ``ratio_range``.

        Args:
            img_scale (tuple): Base ``(a, b)`` scale to multiply.
            ratio_range (tuple[float]): ``(min_ratio, max_ratio)``.

        Returns:
            (tuple, None): ``((int(a * r), int(b * r)), None)``. The None is
            a placeholder matching :func:`random_select`.
        """
        assert isinstance(img_scale, tuple) and len(img_scale) == 2
        low, high = ratio_range
        assert low <= high
        ratio = np.random.random_sample() * (high - low) + low
        return (int(img_scale[0] * ratio), int(img_scale[1] * ratio)), None

    def _random_scale(self, results):
        """Choose a target scale and store it in ``results``.

        Sets ``results['scale']`` and ``results['scale_idx']``. Which mode
        is used is described in the class docstring.

        Args:
            results (dict): Result dict from :obj:`dataset`.
        """
        if self.ratio_range is not None:
            if self.img_scale is None:
                h, w = results['img'].shape[:2]
                base = (w, h)
            else:
                base = self.img_scale[0]
            scale, scale_idx = self.random_sample_ratio(base, self.ratio_range)
        elif len(self.img_scale) == 1:
            scale, scale_idx = self.img_scale[0], 0
        else:
            sampler = {
                'range': self.random_sample,
                'value': self.random_select,
            }.get(self.multiscale_mode)
            if sampler is None:
                raise NotImplementedError
            scale, scale_idx = sampler(self.img_scale)
        results['scale'] = scale
        results['scale_idx'] = scale_idx

    def _resize_img(self, results):
        """Resize ``results['img']`` to ``results['scale']``."""
        src = results['img']
        if self.keep_ratio:
            img, _ = mmcv.imrescale(src, results['scale'], return_scale=True)
            # imrescale's own factor can differ slightly per axis, so
            # recompute the factors from the actual output size.
            new_h, new_w = img.shape[:2]
            h, w = src.shape[:2]
            w_scale = new_w / w
            h_scale = new_h / h
        else:
            img, w_scale, h_scale = mmcv.imresize(
                src, results['scale'], return_scale=True)
        results['img'] = img
        results['img_shape'] = img.shape
        results['pad_shape'] = img.shape  # in case that there is no padding
        results['scale_factor'] = np.array(
            [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
        results['keep_ratio'] = self.keep_ratio

    def _resize_seg(self, results):
        """Resize each map in ``seg_fields`` with nearest interpolation."""
        resize_fn = mmcv.imrescale if self.keep_ratio else mmcv.imresize
        for key in results.get('seg_fields', []):
            results[key] = resize_fn(
                results[key], results['scale'], interpolation='nearest')

    def __call__(self, results):
        """Resize the image and segmentation maps.

        Args:
            results (dict): Result dict from the loading pipeline.

        Returns:
            dict: Resized results with 'img_shape', 'pad_shape',
            'scale_factor' and 'keep_ratio' added.
        """
        if 'scale' not in results:
            self._random_scale(results)
        self._resize_img(results)
        self._resize_seg(results)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(img_scale={self.img_scale}, '
                f'multiscale_mode={self.multiscale_mode}, '
                f'ratio_range={self.ratio_range}, '
                f'keep_ratio={self.keep_ratio})')
@PIPELINES.register_module()
class RandomFlip(object):
    """Flip the image and segmentation maps.

    A ``'flip'`` entry already present in ``results`` is honoured.
    Otherwise the flip is decided by comparing a uniform sample against
    ``prob``.

    Args:
        prob (float, optional): The flipping probability, in [0, 1]. It is
            only consulted when ``results`` lacks ``'flip'``; leaving it as
            None in that case makes the comparison raise ``TypeError``.
            Default: None.
        direction (str, optional): The flipping direction, 'horizontal' or
            'vertical'. Default: 'horizontal'.
    """

    @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
    def __init__(self, prob=None, direction='horizontal'):
        self.prob = prob
        self.direction = direction
        if prob is not None:
            assert 0 <= prob <= 1
        assert direction in ['horizontal', 'vertical']

    def __call__(self, results):
        """Flip the image and every map in ``seg_fields`` when requested.

        Args:
            results (dict): Result dict from the loading pipeline.

        Returns:
            dict: Results with 'flip' and 'flip_direction' set.
        """
        if 'flip' not in results:
            results['flip'] = bool(np.random.rand() < self.prob)
        results.setdefault('flip_direction', self.direction)
        if not results['flip']:
            return results

        direction = results['flip_direction']
        results['img'] = mmcv.imflip(results['img'], direction=direction)
        for key in results.get('seg_fields', []):
            # copy() gives the flipped view positive strides
            results[key] = mmcv.imflip(
                results[key], direction=direction).copy()
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(prob={self.prob})'
@PIPELINES.register_module()
class Pad(object):
    """Pad the image and segmentation maps.

    Exactly one of two modes is used:

    1. Pad to a fixed ``size``.
    2. Pad up to the smallest shape divisible by ``size_divisor``.

    Segmentation maps are then padded to the image's padded height and
    width. Adds ``pad_shape``, ``pad_fixed_size`` and ``pad_size_divisor``
    to the results.

    Args:
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of the padded size.
        pad_val (float, optional): Padding value for the image. Default: 0.
        seg_pad_val (float, optional): Padding value for segmentation maps.
            Default: 255.
    """

    def __init__(self,
                 size=None,
                 size_divisor=None,
                 pad_val=0,
                 seg_pad_val=255):
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val
        # exactly one of size / size_divisor must be given
        assert size is not None or size_divisor is not None
        assert size is None or size_divisor is None

    def _pad_img(self, results):
        """Pad ``results['img']`` and record the padding meta entries."""
        source = results['img']
        if self.size is not None:
            padded = mmcv.impad(source, shape=self.size, pad_val=self.pad_val)
        elif self.size_divisor is not None:
            padded = mmcv.impad_to_multiple(
                source, self.size_divisor, pad_val=self.pad_val)
        results['img'] = padded
        results['pad_shape'] = padded.shape
        results['pad_fixed_size'] = self.size
        results['pad_size_divisor'] = self.size_divisor

    def _pad_seg(self, results):
        """Pad each map in ``seg_fields`` to ``results['pad_shape'][:2]``."""
        for key in results.get('seg_fields', []):
            results[key] = mmcv.impad(
                results[key],
                shape=results['pad_shape'][:2],
                pad_val=self.seg_pad_val)

    def __call__(self, results):
        """Pad the image first, then the segmentation maps to match it.

        Args:
            results (dict): Result dict from the loading pipeline.

        Returns:
            dict: Updated result dict.
        """
        self._pad_img(results)
        self._pad_seg(results)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(size={self.size}, size_divisor={self.size_divisor}, '
                f'pad_val={self.pad_val})')
@PIPELINES.register_module()
class Normalize(object):
    """Normalize the image with per-channel mean and std.

    Also records the values used under ``img_norm_cfg``.

    Args:
        mean (sequence): Per-channel mean values.
        std (sequence): Per-channel std values.
        to_rgb (bool): Whether to convert the image from BGR to RGB.
            Default: True.
    """

    def __init__(self, mean, std, to_rgb=True):
        # Stored as float32 arrays; the same arrays go into img_norm_cfg.
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Normalize ``results['img']`` via :func:`mmcv.imnormalize`.

        Args:
            results (dict): Result dict from the loading pipeline.

        Returns:
            dict: Normalized results with 'img_norm_cfg' set.
        """
        normalized = mmcv.imnormalize(
            results['img'], self.mean, self.std, self.to_rgb)
        results['img'] = normalized
        results['img_norm_cfg'] = {
            'mean': self.mean,
            'std': self.std,
            'to_rgb': self.to_rgb,
        }
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})')
@PIPELINES.register_module()
class Rerange(object):
    """Linearly rescale image pixel values into ``[min_value, max_value]``.

    Args:
        min_value (float or int): Minimum value of the reranged image.
            Default: 0.
        max_value (float or int): Maximum value of the reranged image.
            Default: 255.
    """

    def __init__(self, min_value=0, max_value=255):
        assert isinstance(min_value, (float, int))
        assert isinstance(max_value, (float, int))
        assert min_value < max_value
        self.min_value = min_value
        self.max_value = max_value

    def __call__(self, results):
        """Rerange ``results['img']``.

        The image's own minimum maps to ``min_value`` and its maximum to
        ``max_value``. A constant image trips the assertion.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Reranged results.
        """
        img = results['img']
        lo, hi = np.min(img), np.max(img)
        assert lo < hi
        unit = (img - lo) / (hi - lo)  # now spans [0, 1]
        span = self.max_value - self.min_value
        results['img'] = unit * span + self.min_value
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}(min_value={self.min_value}, '
                f'max_value={self.max_value})')
@PIPELINES.register_module()
class CLAHE(object):
    """Apply CLAHE (contrast limited adaptive histogram equalization).

    See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
    Graphics Gems, 1994:474-485.` for more information.

    Args:
        clip_limit (float): Threshold for contrast limiting. Default: 40.0.
        tile_grid_size (tuple[int]): Size of grid for histogram equalization.
            Input image will be divided into equally sized rectangular tiles.
            It defines the number of tiles in row and column. Default: (8, 8).
    """

    def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
        assert isinstance(clip_limit, (float, int))
        self.clip_limit = clip_limit
        assert is_tuple_of(tile_grid_size, int)
        assert len(tile_grid_size) == 2
        self.tile_grid_size = tile_grid_size

    def __call__(self, results):
        """Equalize each channel of ``results['img']`` independently.

        Each channel is cast to uint8 before equalization and written back
        into the existing image array in place.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Processed results.
        """
        img = results['img']
        for channel in range(img.shape[2]):
            plane = np.array(img[:, :, channel], dtype=np.uint8)
            img[:, :, channel] = mmcv.clahe(plane, self.clip_limit,
                                            self.tile_grid_size)
        return results

    def __repr__(self):
        return (f'{type(self).__name__}(clip_limit={self.clip_limit}, '
                f'tile_grid_size={self.tile_grid_size})')
@PIPELINES.register_module()
class RandomCrop(object):
    """Randomly crop the image and its segmentation maps.

    Args:
        crop_size (tuple): Expected size after cropping, (h, w).
        cat_max_ratio (float): The maximum ratio that single category could
            occupy.
        ignore_index (int): Label value excluded when computing the category
            ratio. Default: 255.
    """

    def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):
        assert crop_size[0] > 0 and crop_size[1] > 0
        self.crop_size = crop_size
        self.cat_max_ratio = cat_max_ratio
        self.ignore_index = ignore_index

    def get_crop_bbox(self, img):
        """Draw a random (y1, y2, x1, x2) window of size ``crop_size``.

        When the image is smaller than the crop along an axis, the offset on
        that axis is 0.
        """
        crop_h, crop_w = self.crop_size[0], self.crop_size[1]
        margin_h = max(img.shape[0] - crop_h, 0)
        margin_w = max(img.shape[1] - crop_w, 0)
        # row offset is drawn before the column offset
        top = np.random.randint(0, margin_h + 1)
        left = np.random.randint(0, margin_w + 1)
        return top, top + crop_h, left, left + crop_w

    def crop(self, img, crop_bbox):
        """Slice ``img`` to ``crop_bbox``; trailing axes are kept."""
        y1, y2, x1, x2 = crop_bbox
        return img[y1:y2, x1:x2, ...]

    def _is_balanced(self, seg):
        """Whether no single non-ignored label dominates ``seg``."""
        labels, counts = np.unique(seg, return_counts=True)
        counts = counts[labels != self.ignore_index]
        return len(counts) > 1 and \
            np.max(counts) / np.sum(counts) < self.cat_max_ratio

    def __call__(self, results):
        """Crop the image and every map in ``seg_fields`` with one window.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Randomly cropped results, 'img_shape' key in result dict is
                updated according to crop size.
        """
        img = results['img']
        crop_bbox = self.get_crop_bbox(img)
        if self.cat_max_ratio < 1.:
            # up to 10 attempts to find a window without a dominant category
            for _ in range(10):
                window = self.crop(results['gt_semantic_seg'], crop_bbox)
                if self._is_balanced(window):
                    break
                crop_bbox = self.get_crop_bbox(img)

        cropped = self.crop(img, crop_bbox)
        results['img'] = cropped
        results['img_shape'] = cropped.shape

        for key in results.get('seg_fields', []):
            results[key] = self.crop(results[key], crop_bbox)
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(crop_size={self.crop_size})'
@PIPELINES.register_module()
class RandomRotate(object):
    """Rotate the image & seg.

    Args:
        prob (float): The rotation probability, in [0, 1].
        degree (float, tuple[float]): Range of degrees to select from. If
            degree is a number instead of tuple like (min, max),
            the range of degree will be (``-degree``, ``+degree``)
        pad_val (float, optional): Padding value of image. Default: 0.
        seg_pad_val (float, optional): Padding value of segmentation map.
            Default: 255.
        center (tuple[float], optional): Center point (w, h) of the rotation in
            the source image. If not specified, the center of the image will be
            used. Default: None.
        auto_bound (bool): Whether to adjust the image size to cover the whole
            rotated image. Default: False
    """

    def __init__(self,
                 prob,
                 degree,
                 pad_val=0,
                 seg_pad_val=255,
                 center=None,
                 auto_bound=False):
        assert 0 <= prob <= 1, f'prob {prob} should be in [0, 1]'
        self.prob = prob
        if isinstance(degree, (float, int)):
            assert degree > 0, f'degree {degree} should be positive'
            self.degree = (-degree, degree)
        else:
            self.degree = degree
        assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
                                      f'tuple of (min, max)'
        # Stored under the correct name; the historical misspelling
        # ``pal_val`` remains available as an alias property below.
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val
        self.center = center
        self.auto_bound = auto_bound

    @property
    def pal_val(self):
        """Deprecated misspelled alias of ``pad_val`` (kept for compat)."""
        return self.pad_val

    @pal_val.setter
    def pal_val(self, value):
        self.pad_val = value

    def __call__(self, results):
        """Call function to rotate image, semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Rotated results.
        """
        rotate = np.random.rand() < self.prob
        # The angle is drawn even when no rotation happens, so every call
        # consumes the global RNG in the same way regardless of the outcome.
        degree = np.random.uniform(min(*self.degree), max(*self.degree))
        if rotate:
            # rotate image
            results['img'] = mmcv.imrotate(
                results['img'],
                angle=degree,
                border_value=self.pad_val,
                center=self.center,
                auto_bound=self.auto_bound)

            # rotate segs; nearest interpolation keeps labels discrete
            for key in results.get('seg_fields', []):
                results[key] = mmcv.imrotate(
                    results[key],
                    angle=degree,
                    border_value=self.seg_pad_val,
                    center=self.center,
                    auto_bound=self.auto_bound,
                    interpolation='nearest')
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, ' \
                    f'degree={self.degree}, ' \
                    f'pad_val={self.pad_val}, ' \
                    f'seg_pad_val={self.seg_pad_val}, ' \
                    f'center={self.center}, ' \
                    f'auto_bound={self.auto_bound})'
        return repr_str
@PIPELINES.register_module()
class RGB2Gray(object):
    """Convert RGB image to grayscale image.

    The image channels are combined as a weighted sum using ``weights`` and
    the single resulting channel is repeated ``out_channels`` times. When
    ``out_channels`` is None, the output has as many channels as the input.

    Args:
        out_channels (int): Expected number of output channels after
            transforming. Default: None.
        weights (tuple[float]): The weights to calculate the weighted mean.
            Default: (0.299, 0.587, 0.114).
    """

    def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
        assert out_channels is None or out_channels > 0
        self.out_channels = out_channels
        assert isinstance(weights, tuple)
        assert all(isinstance(w, (float, int)) for w in weights)
        self.weights = weights

    def __call__(self, results):
        """Replace ``results['img']`` with its grayscale version.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with grayscale image and updated 'img_shape'.
        """
        img = results['img']
        assert len(img.shape) == 3
        assert img.shape[2] == len(self.weights)
        w = np.array(self.weights).reshape((1, 1, -1))
        gray = (img * w).sum(2, keepdims=True)
        n_out = w.shape[2] if self.out_channels is None else self.out_channels
        gray = gray.repeat(n_out, axis=2)
        results['img'] = gray
        results['img_shape'] = gray.shape
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}(out_channels={self.out_channels}, '
                f'weights={self.weights})')
@PIPELINES.register_module()
class AdjustGamma(object):
    """Apply gamma correction through a 256-entry uint8 lookup table.

    Args:
        gamma (float or int): Gamma value used in gamma correction.
            Default: 1.0.
    """

    def __init__(self, gamma=1.0):
        assert isinstance(gamma, (float, int))
        assert gamma > 0
        self.gamma = gamma
        exponent = 1.0 / gamma
        lut = []
        for level in np.arange(256):
            lut.append((level / 255.0)**exponent * 255)
        # astype truncates toward zero (no rounding)
        self.table = np.array(lut).astype('uint8')

    def __call__(self, results):
        """Map ``results['img']`` (cast to uint8) through the gamma table.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Processed results.
        """
        img_u8 = np.array(results['img'], dtype=np.uint8)
        results['img'] = mmcv.lut_transform(img_u8, self.table)
        return results

    def __repr__(self):
        return f'{type(self).__name__}(gamma={self.gamma})'
@PIPELINES.register_module()
class SegRescale(object):
    """Rescale the semantic segmentation maps listed in ``seg_fields``.

    Args:
        scale_factor (float): The scale factor of the final output.
    """

    def __init__(self, scale_factor=1):
        self.scale_factor = scale_factor

    def __call__(self, results):
        """Rescale each segmentation map with nearest interpolation.

        Nothing is touched when ``scale_factor`` equals 1.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with semantic segmentation map scaled.
        """
        if self.scale_factor != 1:
            for key in results.get('seg_fields', []):
                results[key] = mmcv.imrescale(
                    results[key], self.scale_factor, interpolation='nearest')
        return results

    def __repr__(self):
        return f'{type(self).__name__}(scale_factor={self.scale_factor})'
@PIPELINES.register_module()
class PhotoMetricDistortion(object):
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The position of random contrast is in
    second or second to last.

    1. random brightness
    2. random contrast (mode 1)
    3. random saturation (converts BGR -> HSV, scales S, converts back)
    4. random hue (converts BGR -> HSV, shifts H, converts back)
    5. random contrast (mode 0)

    No random channel swap is performed by this implementation.

    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue. The integer hue shift is drawn from
            ``[-hue_delta, hue_delta]``, both ends inclusive.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def convert(self, img, alpha=1, beta=0):
        """Multiply with alpha and add beta, then clip to [0, 255] uint8."""
        img = img.astype(np.float32) * alpha + beta
        img = np.clip(img, 0, 255)
        return img.astype(np.uint8)

    def brightness(self, img):
        """Brightness distortion."""
        # ``random`` is numpy's module here (the one-argument randint(2)
        # only exists there), so randint(2) yields 0 or 1.
        if random.randint(2):
            return self.convert(
                img,
                beta=random.uniform(-self.brightness_delta,
                                    self.brightness_delta))
        return img

    def contrast(self, img):
        """Contrast distortion."""
        if random.randint(2):
            return self.convert(
                img,
                alpha=random.uniform(self.contrast_lower, self.contrast_upper))
        return img

    def saturation(self, img):
        """Saturation distortion."""
        if random.randint(2):
            img = mmcv.bgr2hsv(img)
            img[:, :, 1] = self.convert(
                img[:, :, 1],
                alpha=random.uniform(self.saturation_lower,
                                     self.saturation_upper))
            img = mmcv.hsv2bgr(img)
        return img

    def hue(self, img):
        """Hue distortion."""
        if random.randint(2):
            img = mmcv.bgr2hsv(img)
            # numpy's randint excludes ``high``: add 1 so that +hue_delta is
            # reachable and the shift is symmetric (and hue_delta=0 works).
            shift = random.randint(-self.hue_delta, self.hue_delta + 1)
            img[:, :, 0] = (img[:, :, 0].astype(int) + shift) % 180
            img = mmcv.hsv2bgr(img)
        return img

    def __call__(self, results):
        """Call function to perform photometric distortion on images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images distorted.
        """
        img = results['img']
        # random brightness
        img = self.brightness(img)

        # mode == 1 --> do random contrast first
        # mode == 0 --> do random contrast last
        mode = random.randint(2)
        if mode == 1:
            img = self.contrast(img)

        # random saturation
        img = self.saturation(img)

        # random hue
        img = self.hue(img)

        # random contrast
        if mode == 0:
            img = self.contrast(img)

        results['img'] = img
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(brightness_delta={self.brightness_delta}, '
                     f'contrast_range=({self.contrast_lower}, '
                     f'{self.contrast_upper}), '
                     f'saturation_range=({self.saturation_lower}, '
                     f'{self.saturation_upper}), '
                     f'hue_delta={self.hue_delta})')
        return repr_str
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/stare.py
================================================
import os.path as osp
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class STAREDataset(CustomDataset):
    """STARE retinal vessel dataset.

    Two categories are used, and label 0 (background) is one of them, so
    ``reduce_zero_label`` is fixed to False. The image suffix is fixed to
    '.png' and the segmentation map suffix to '.ah.png'.
    """

    CLASSES = ('background', 'vessel')

    PALETTE = [[120, 120, 120], [6, 230, 230]]

    def __init__(self, **kwargs):
        super().__init__(
            reduce_zero_label=False,
            img_suffix='.png',
            seg_map_suffix='.ah.png',
            **kwargs)
        assert osp.exists(self.img_dir)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/voc.py
================================================
import os.path as osp
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class PascalVOCDataset(CustomDataset):
    """Pascal VOC dataset ('.jpg' images, '.png' segmentation maps).

    Args:
        split (str): Split txt file for Pascal VOC. It must not be None; the
            constructor asserts ``self.split is not None``.
    """

    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
               'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
               'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
               'train', 'tvmonitor')

    PALETTE = [
        [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128],
    ]

    def __init__(self, split, **kwargs):
        super().__init__(
            split=split, img_suffix='.jpg', seg_map_suffix='.png', **kwargs)
        assert osp.exists(self.img_dir) and self.split is not None
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/__init__.py
================================================
from .backbones import * # noqa: F401,F403
from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
build_head, build_loss, build_segmentor)
from .decode_heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .segmentors import * # noqa: F401,F403
# Names exported by ``from <this package> import *``: the registries and
# their builder helpers. Names pulled in by the star imports above are not
# listed here.
__all__ = [
'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
'build_head', 'build_loss', 'build_segmentor'
]
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/__init__.py
================================================
from .cgnet import CGNet
from .fast_scnn import FastSCNN
from .hrnet import HRNet
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt
from .unet import UNet
# Backbone classes exported by ``from .backbones import *``; every name
# imported at the top of this module is listed.
__all__ = [
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3'
]
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/cgnet.py
================================================
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
class GlobalContextExtractor(nn.Module):
    """Global Context Extractor for CGNet.

    Refines the joint feature of both local feature and surrounding context
    by rescaling each channel with a gate computed from its global average.

    Args:
        channel (int): Number of input feature channels.
        reduction (int): Reductions for global context extractor. Default: 16.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self, channel, reduction=16, with_cp=False):
        super().__init__()
        self.channel = channel
        self.reduction = reduction
        assert reduction >= 1 and channel >= reduction
        self.with_cp = with_cp
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # squeeze -> ReLU -> expand -> sigmoid gate
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel),
            nn.Sigmoid())

    def forward(self, x):

        def _inner_forward(feat):
            n, c = feat.size()[:2]
            gate = self.fc(self.avg_pool(feat).view(n, c))
            return feat * gate.view(n, c, 1, 1)

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)
class ContextGuidedBlock(nn.Module):
    """Context Guided Block for CGNet.

    This class consists of four components: local feature extractor,
    surrounding feature extractor, joint feature extractor and global
    context extractor.

    Args:
        in_channels (int): Number of input feature channels.
        out_channels (int): Number of output feature channels.
        dilation (int): Dilation rate for surrounding context extractor.
            Default: 2.
        reduction (int): Reduction for global context extractor. Default: 16.
        skip_connect (bool): Add input to output or not. Only effective when
            ``downsample`` is False. Default: True.
        downsample (bool): Downsample the input to 1/2 or not. Default: False.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for the activation of the first
            ConvModule. It is copied, never modified in place; for PReLU the
            copy's ``num_parameters`` is set to this block's branch width.
            Default: dict(type='PReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 dilation=2,
                 reduction=16,
                 skip_connect=True,
                 downsample=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='PReLU'),
                 with_cp=False):
        super(ContextGuidedBlock, self).__init__()
        self.with_cp = with_cp
        self.downsample = downsample

        # Width of the local and surrounding branches; they are concatenated
        # into ``2 * channels`` features below.
        channels = out_channels if downsample else out_channels // 2
        # Work on a copy: writing the block-specific PReLU width into the
        # caller's dict (or into the shared default argument) would leak it
        # to every other module built from the same config object.
        if act_cfg is not None:
            act_cfg = dict(act_cfg)
            if act_cfg.get('type') == 'PReLU':
                act_cfg['num_parameters'] = channels

        # 3x3 stride-2 projection when downsampling, otherwise 1x1 stride 1.
        kernel_size = 3 if downsample else 1
        stride = 2 if downsample else 1
        padding = (kernel_size - 1) // 2

        self.conv1x1 = ConvModule(
            in_channels,
            channels,
            kernel_size,
            stride,
            padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # depthwise (groups=channels) local feature extractor
        self.f_loc = build_conv_layer(
            conv_cfg,
            channels,
            channels,
            kernel_size=3,
            padding=1,
            groups=channels,
            bias=False)
        # depthwise dilated surrounding context extractor
        self.f_sur = build_conv_layer(
            conv_cfg,
            channels,
            channels,
            kernel_size=3,
            padding=dilation,
            groups=channels,
            dilation=dilation,
            bias=False)

        self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
        # NOTE: always PReLU here, independent of ``act_cfg``.
        self.activate = nn.PReLU(2 * channels)

        if downsample:
            self.bottleneck = build_conv_layer(
                conv_cfg,
                2 * channels,
                out_channels,
                kernel_size=1,
                bias=False)

        self.skip_connect = skip_connect and not downsample
        self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)

    def forward(self, x):

        def _inner_forward(x):
            out = self.conv1x1(x)
            loc = self.f_loc(out)
            sur = self.f_sur(out)

            joi_feat = torch.cat([loc, sur], 1)  # the joint feature
            joi_feat = self.bn(joi_feat)
            joi_feat = self.activate(joi_feat)
            if self.downsample:
                joi_feat = self.bottleneck(joi_feat)  # channel = out_channels
            # f_glo is employed to refine the joint feature
            out = self.f_glo(joi_feat)

            if self.skip_connect:
                return x + out
            else:
                return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
class InputInjection(nn.Module):
    """Downsampling module for CGNet.

    Applies ``num_downsampling`` successive 3x3 average poolings with
    stride 2 and padding 1, each halving the spatial size (rounded up).
    """

    def __init__(self, num_downsampling):
        super().__init__()
        self.pool = nn.ModuleList(
            nn.AvgPool2d(3, stride=2, padding=1)
            for _ in range(num_downsampling))

    def forward(self, x):
        out = x
        for pool_layer in self.pool:
            out = pool_layer(out)
        return out
@BACKBONES.register_module()
class CGNet(nn.Module):
    """CGNet backbone.

    A Light-weight Context Guided Network for Semantic Segmentation
    arXiv: https://arxiv.org/abs/1811.08201

    Args:
        in_channels (int): Number of input image channels. Normally 3.
        num_channels (tuple[int]): Numbers of feature channels at each stage.
            Default: (32, 64, 128).
        num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
            Default: (3, 21).
        dilations (tuple[int]): Dilation rate for surrounding context
            extractors at stage 1 and stage 2. Default: (2, 4).
        reductions (tuple[int]): Reductions for global context extractors at
            stage 1 and stage 2. Default: (8, 16).
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer. It is copied before
            the stem's PReLU width is filled in, so the caller's dict (and
            the default) is never modified. Default: dict(type='PReLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels=3,
                 num_channels=(32, 64, 128),
                 num_blocks=(3, 21),
                 dilations=(2, 4),
                 reductions=(8, 16),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='PReLU'),
                 norm_eval=False,
                 with_cp=False):

        super(CGNet, self).__init__()
        self.in_channels = in_channels
        self.num_channels = num_channels
        assert isinstance(self.num_channels, tuple) and len(
            self.num_channels) == 3
        self.num_blocks = num_blocks
        assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
        self.dilations = dilations
        assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
        self.reductions = reductions
        assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        # Copy before filling in the stem's PReLU width so that neither the
        # caller's config nor the shared default argument is mutated.
        if act_cfg is not None:
            act_cfg = dict(act_cfg)
            if act_cfg.get('type') == 'PReLU':
                act_cfg['num_parameters'] = num_channels[0]
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # stage 0 stem: three 3x3 ConvModules, only the first has stride 2
        cur_channels = in_channels
        self.stem = nn.ModuleList()
        for i in range(3):
            self.stem.append(
                ConvModule(
                    cur_channels,
                    num_channels[0],
                    3,
                    2 if i == 0 else 1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            cur_channels = num_channels[0]

        self.inject_2x = InputInjection(1)  # down-sample for Input, factor=2
        self.inject_4x = InputInjection(2)  # down-sample for Input, factor=4

        # stem output is concatenated with the 2x-downsampled input
        cur_channels += in_channels
        self.norm_prelu_0 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

        # stage 1
        self.level1 = nn.ModuleList()
        for i in range(num_blocks[0]):
            self.level1.append(
                ContextGuidedBlock(
                    cur_channels if i == 0 else num_channels[1],
                    num_channels[1],
                    dilations[0],
                    reductions[0],
                    downsample=(i == 0),
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp))  # CG block

        # last-block output, first-block output and 4x-downsampled input
        cur_channels = 2 * num_channels[1] + in_channels
        self.norm_prelu_1 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

        # stage 2
        self.level2 = nn.ModuleList()
        for i in range(num_blocks[1]):
            self.level2.append(
                ContextGuidedBlock(
                    cur_channels if i == 0 else num_channels[2],
                    num_channels[2],
                    dilations[1],
                    reductions[1],
                    downsample=(i == 0),
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp))  # CG block

        # first-block output and last-block output
        cur_channels = 2 * num_channels[2]
        self.norm_prelu_2 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

    def forward(self, x):
        """Run the backbone.

        Returns:
            list[Tensor]: the outputs of stage 0, stage 1 and stage 2.
        """
        output = []

        # stage 0
        inp_2x = self.inject_2x(x)
        inp_4x = self.inject_4x(x)
        for layer in self.stem:
            x = layer(x)
        x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
        output.append(x)

        # stage 1
        for i, layer in enumerate(self.level1):
            x = layer(x)
            if i == 0:
                down1 = x
        x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
        output.append(x)

        # stage 2
        for i, layer in enumerate(self.level2):
            x = layer(x)
            if i == 0:
                down2 = x
        x = self.norm_prelu_2(torch.cat([down2, x], 1))
        output.append(x)

        return output

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
                elif isinstance(m, nn.PReLU):
                    constant_init(m, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen when ``norm_eval`` is True.

        Returns:
            CGNet: ``self``, matching :meth:`torch.nn.Module.train`.
        """
        super(CGNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/fast_scnn.py
================================================
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
kaiming_init)
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.models.decode_heads.psp_head import PPM
from mmseg.ops import resize
from ..builder import BACKBONES
from ..utils.inverted_residual import InvertedResidual
class LearningToDownsample(nn.Module):
    """Learning to downsample module.

    A stride-2 3x3 ConvModule followed by two stride-2 depthwise separable
    convolutions.

    Args:
        in_channels (int): Number of input channels.
        dw_channels (tuple[int]): Number of output channels of the first and
            the second depthwise conv (dwconv) layers.
        out_channels (int): Number of output channels of the whole
            'learning to downsample' module.
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
    """

    def __init__(self,
                 in_channels,
                 dw_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU')):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        mid1, mid2 = dw_channels[0], dw_channels[1]

        # NOTE: no explicit padding is passed here, unlike the dsconv layers
        self.conv = ConvModule(
            in_channels,
            mid1,
            3,
            stride=2,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # only norm_cfg is forwarded to the depthwise separable layers
        self.dsconv1 = DepthwiseSeparableConvModule(
            mid1,
            mid2,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg)
        self.dsconv2 = DepthwiseSeparableConvModule(
            mid2,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg)

    def forward(self, x):
        for layer in (self.conv, self.dsconv1, self.dsconv2):
            x = layer(x)
        return x
class GlobalFeatureExtractor(nn.Module):
    """Global feature extractor module.

    Three groups of Inverted Residual blocks, followed by a pyramid pooling
    module (PPM) whose outputs are concatenated with its input and fused by
    a 1x1 ConvModule.

    Args:
        in_channels (int): Number of input channels of the GFE module.
            Default: 64
        block_channels (tuple[int]): Tuple of ints. Each int specifies the
            number of output channels of each Inverted Residual module.
            Default: (64, 96, 128)
        out_channels(int): Number of output channels of the GFE module.
            Default: 128
        expand_ratio (int): Adjusts number of channels of the hidden layer
            in InvertedResidual by this amount.
            Default: 6
        num_blocks (tuple[int]): Tuple of ints. Each int specifies the
            number of times each Inverted Residual module is repeated.
            The repeated Inverted Residual modules are called a 'group'.
            Default: (3, 3, 3)
        strides (tuple[int]): Tuple of ints. Each int specifies
            the downsampling factor of each 'group'.
            Default: (2, 2, 1)
        pool_scales (tuple[int]): Tuple of ints. Each int specifies
            the parameter required in 'global average pooling' within PPM.
            Default: (1, 2, 3, 6)
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False
    """

    def __init__(self,
                 in_channels=64,
                 block_channels=(64, 96, 128),
                 out_channels=128,
                 expand_ratio=6,
                 num_blocks=(3, 3, 3),
                 strides=(2, 2, 1),
                 pool_scales=(1, 2, 3, 6),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        assert len(block_channels) == len(num_blocks) == 3

        # each group's input width is the previous group's output width
        self.bottleneck1 = self._make_layer(
            in_channels, block_channels[0], num_blocks[0],
            stride=strides[0], expand_ratio=expand_ratio)
        self.bottleneck2 = self._make_layer(
            block_channels[0], block_channels[1], num_blocks[1],
            stride=strides[1], expand_ratio=expand_ratio)
        self.bottleneck3 = self._make_layer(
            block_channels[1], block_channels[2], num_blocks[2],
            stride=strides[2], expand_ratio=expand_ratio)

        last = block_channels[2]
        self.ppm = PPM(
            pool_scales,
            last,
            last // 4,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=align_corners)
        # NOTE(review): ``last * 2`` presumably assumes the PPM branches add
        # ``last`` channels in total (four scales x ``last // 4``); confirm
        # for non-default ``pool_scales``.
        self.out = ConvModule(
            last * 2,
            out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _make_layer(self,
                    in_channels,
                    out_channels,
                    blocks,
                    stride=1,
                    expand_ratio=6):
        """Build one group: a strided block, then ``blocks - 1`` stride-1."""
        layers = [
            InvertedResidual(
                in_channels,
                out_channels,
                stride,
                expand_ratio,
                norm_cfg=self.norm_cfg)
        ]
        layers.extend(
            InvertedResidual(
                out_channels,
                out_channels,
                1,
                expand_ratio,
                norm_cfg=self.norm_cfg) for _ in range(blocks - 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        for group in (self.bottleneck1, self.bottleneck2, self.bottleneck3):
            x = group(x)
        pooled = self.ppm(x)
        x = torch.cat([x, *pooled], dim=1)
        return self.out(x)
class FeatureFusionModule(nn.Module):
    """Feature fusion module.

    The lower-resolution feature is bilinearly resized to the
    higher-resolution feature's spatial size. Both are projected with 1x1
    ConvModules, then summed and passed through ReLU.

    Args:
        higher_in_channels (int): Number of input channels of the
            higher-resolution branch.
        lower_in_channels (int): Number of input channels of the
            lower-resolution branch.
        out_channels (int): Number of output channels.
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False
    """

    def __init__(self,
                 higher_in_channels,
                 lower_in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners
        # NOTE: despite its name this is a plain 1x1 ConvModule (no groups).
        self.dwconv = ConvModule(
            lower_in_channels,
            out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # the two projections before the sum carry no activation
        self.conv_lower_res = ConvModule(
            out_channels,
            out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.conv_higher_res = ConvModule(
            higher_in_channels,
            out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.relu = nn.ReLU(True)

    def forward(self, higher_res_feature, lower_res_feature):
        lower = resize(
            lower_res_feature,
            size=higher_res_feature.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        lower = self.conv_lower_res(self.dwconv(lower))
        higher = self.conv_higher_res(higher_res_feature)
        return self.relu(higher + lower)
@BACKBONES.register_module()
class FastSCNN(nn.Module):
    """Fast-SCNN Backbone.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        downsample_dw_channels (tuple[int]): Number of output channels after
            the first conv layer & the second conv layer in
            Learning-To-Downsample (LTD) module.
            Default: (32, 48).
        global_in_channels (int): Number of input channels of
            Global Feature Extractor(GFE).
            Equal to number of output channels of LTD.
            Default: 64.
        global_block_channels (tuple[int]): Tuple of integers that describe
            the output channels for each of the MobileNet-v2 bottleneck
            residual blocks in GFE.
            Default: (64, 96, 128).
        global_block_strides (tuple[int]): Tuple of integers
            that describe the strides (downsampling factors) for each of the
            MobileNet-v2 bottleneck residual blocks in GFE.
            Default: (2, 2, 1).
        global_out_channels (int): Number of output channels of GFE.
            Default: 128.
        higher_in_channels (int): Number of input channels of the higher
            resolution branch in FFM.
            Equal to global_in_channels.
            Default: 64.
        lower_in_channels (int): Number of input channels of the lower
            resolution branch in FFM.
            Equal to global_out_channels.
            Default: 128.
        fusion_out_channels (int): Number of output channels of FFM.
            Default: 128.
        out_indices (tuple): Tuple of indices of list
            [higher_res_features, lower_res_features, fusion_output].
            Often set to (0,1,2) to enable aux. heads.
            Default: (0, 1, 2).
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False

    Raises:
        AssertionError: If ``global_in_channels != higher_in_channels`` or
            ``global_out_channels != lower_in_channels``.
    """

    def __init__(self,
                 in_channels=3,
                 downsample_dw_channels=(32, 48),
                 global_in_channels=64,
                 global_block_channels=(64, 96, 128),
                 global_block_strides=(2, 2, 1),
                 global_out_channels=128,
                 higher_in_channels=64,
                 lower_in_channels=128,
                 fusion_out_channels=128,
                 out_indices=(0, 1, 2),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super(FastSCNN, self).__init__()
        # The FFM's higher-res branch consumes the LTD output and its
        # lower-res branch consumes the GFE output, so channels must agree.
        # Implicit string concatenation keeps the message free of the
        # source indentation a backslash continuation would embed.
        if global_in_channels != higher_in_channels:
            raise AssertionError('Global Input Channels must be the same '
                                 'with Higher Input Channels!')
        elif global_out_channels != lower_in_channels:
            raise AssertionError('Global Output Channels must be the same '
                                 'with Lower Input Channels!')
        self.in_channels = in_channels
        self.downsample_dw_channels1 = downsample_dw_channels[0]
        self.downsample_dw_channels2 = downsample_dw_channels[1]
        self.global_in_channels = global_in_channels
        self.global_block_channels = global_block_channels
        self.global_block_strides = global_block_strides
        self.global_out_channels = global_out_channels
        self.higher_in_channels = higher_in_channels
        self.lower_in_channels = lower_in_channels
        self.fusion_out_channels = fusion_out_channels
        self.out_indices = out_indices
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners
        self.learning_to_downsample = LearningToDownsample(
            in_channels,
            downsample_dw_channels,
            global_in_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.global_feature_extractor = GlobalFeatureExtractor(
            global_in_channels,
            global_block_channels,
            global_out_channels,
            strides=self.global_block_strides,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.feature_fusion = FeatureFusionModule(
            higher_in_channels,
            lower_in_channels,
            fusion_out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)

    def init_weights(self, pretrained=None):
        """Kaiming-init every Conv2d and set BN/GN weights to 1.

        Args:
            pretrained: Accepted for interface compatibility with other
                backbones but not used; no checkpoint is loaded here.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)

    def forward(self, x):
        """Return the features selected by ``out_indices`` as a tuple.

        Candidates are [higher_res_features, lower_res_features,
        fusion_output].
        """
        higher_res_features = self.learning_to_downsample(x)
        lower_res_features = self.global_feature_extractor(higher_res_features)
        fusion_output = self.feature_fusion(higher_res_features,
                                            lower_res_features)
        outs = [higher_res_features, lower_res_features, fusion_output]
        outs = [outs[i] for i in self.out_indices]
        return tuple(outs)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/hrnet.py
================================================
import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.ops import Upsample, resize
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from .resnet import BasicBlock, Bottleneck
class HRModule(nn.Module):
    """High-Resolution Module for HRNet.

    Each of the ``num_branches`` parallel branches is a stack of
    ``num_blocks[i]`` residual blocks of class ``blocks``. After the branches
    run, a fusion (exchange) step builds each output branch by summing the
    features of all branches, converted to that branch's channel count and
    resolution.

    Args:
        num_branches (int): Number of parallel branches.
        blocks (type): Residual block class used in every branch; it must
            expose an ``expansion`` attribute.
        num_blocks (Sequence[int]): Number of blocks in each branch.
        in_channels (list[int]): Input channels of each branch. NOTE: this
            list object is stored and modified in place while the branches
            are built; afterwards entry ``i`` equals
            ``num_channels[i] * blocks.expansion``. HRNet._make_stage
            returns this same list as the next stage's channels.
        num_channels (Sequence[int]): Base channels of each branch.
        multiscale_output (bool): If False, only the fused output for branch
            0 is built. Default: True.
        with_cp (bool): Passed to every block. Default: False.
        conv_cfg (dict | None): Config of conv layers. Default: None.
        norm_cfg (dict): Config of norm layers.
            Default: dict(type='BN', requires_grad=True).
    """

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 in_channels,
                 num_channels,
                 multiscale_output=True,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super(HRModule, self).__init__()
        self._check_branches(num_branches, num_blocks, in_channels,
                             num_channels)
        self.in_channels = in_channels
        self.num_branches = num_branches
        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp
        # Building the branches updates self.in_channels in place; the fuse
        # layers below are then sized from the updated channel counts.
        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, num_blocks, in_channels,
                        num_channels):
        """Check branches configuration.

        Raises:
            ValueError: If ``num_blocks``, ``num_channels`` or
                ``in_channels`` does not have ``num_branches`` entries.
        """
        if num_branches != len(num_blocks):
            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \
                        f'{len(num_blocks)})'
            raise ValueError(error_msg)
        if num_branches != len(num_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \
                        f'{len(num_channels)})'
            raise ValueError(error_msg)
        if num_branches != len(in_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \
                        f'{len(in_channels)})'
            raise ValueError(error_msg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Build one branch.

        The first block gets a 1x1 conv + norm shortcut when the stride or
        channel count changes. Side effect: sets
        ``self.in_channels[branch_index]`` to the branch's output channels.
        """
        downsample = None
        if stride != 1 or \
                self.in_channels[branch_index] != \
                num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    self.in_channels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, num_channels[branch_index] *
                                 block.expansion)[1])
        layers = []
        layers.append(
            block(
                self.in_channels[branch_index],
                num_channels[branch_index],
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg))
        # Record the branch's new width; later blocks and the fuse layers
        # read it from here.
        self.in_channels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.in_channels[branch_index],
                    num_channels[branch_index],
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg))
        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build multiple branch.

        Returns:
            nn.ModuleList: one nn.Sequential per branch, in branch order.
        """
        branches = []
        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))
        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build fuse layer.

        Returns None for a single branch. Otherwise returns a ModuleList of
        rows; row ``i`` holds, for every source branch ``j``:
          - ``j > i``: 1x1 conv to ``in_channels[i]`` + norm + bilinear
            upsample by ``2**(j - i)``;
          - ``j == i``: None (identity, handled in forward);
          - ``j < i``: ``i - j`` 3x3 stride-2 convs; only the last changes
            channels to ``in_channels[i]`` and it has no ReLU.
        """
        if self.num_branches == 1:
            return None
        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
                            # we set align_corners=False for HRNet
                            Upsample(
                                scale_factor=2**(j - i),
                                mode='bilinear',
                                align_corners=False)))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            # Last downsampling step: switch to the target
                            # branch's channels; no activation before the sum.
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))
        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function.

        Args:
            x (list): One input per branch. NOTE: the list is overwritten in
                place with the branch outputs.

        Returns:
            list: One fused output per row of ``self.fuse_layers``, or
            ``[branch0_output]`` when there is a single branch.
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]
        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = 0
            for j in range(self.num_branches):
                if i == j:
                    y += x[j]
                elif j > i:
                    # Resize to the exact target size after the fixed-factor
                    # upsample; presumably guards against off-by-one sizes
                    # from odd inputs -- NOTE(review): confirm intent.
                    y = y + resize(
                        self.fuse_layers[i][j](x[j]),
                        size=x[i].shape[2:],
                        mode='bilinear',
                        align_corners=False)
                else:
                    y += self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))
        return x_fuse
@BACKBONES.register_module()
class HRNet(nn.Module):
    """HRNet backbone.

    High-Resolution Representations for Labeling Pixels and Regions
    arXiv: https://arxiv.org/abs/1904.04514

    Args:
        extra (dict): detailed configuration for each stage of HRNet.
        in_channels (int): Number of input image channels. Normally 3.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.

    Example:
        >>> from mmseg.models import HRNet
        >>> import torch
        >>> extra = dict(
        >>>     stage1=dict(
        >>>         num_modules=1,
        >>>         num_branches=1,
        >>>         block='BOTTLENECK',
        >>>         num_blocks=(4, ),
        >>>         num_channels=(64, )),
        >>>     stage2=dict(
        >>>         num_modules=1,
        >>>         num_branches=2,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4),
        >>>         num_channels=(32, 64)),
        >>>     stage3=dict(
        >>>         num_modules=4,
        >>>         num_branches=3,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4),
        >>>         num_channels=(32, 64, 128)),
        >>>     stage4=dict(
        >>>         num_modules=3,
        >>>         num_branches=4,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4, 4),
        >>>         num_channels=(32, 64, 128, 256)))
        >>> self = HRNet(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 32, 8, 8)
        (1, 64, 4, 4)
        (1, 128, 2, 2)
        (1, 256, 1, 1)
    """

    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=False):
        super(HRNet, self).__init__()
        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        # stem net: two stride-2 3x3 convs to 64 channels
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            self.conv_cfg,
            64,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)
        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        # stage 1
        self.stage1_cfg = self.extra['stage1']
        num_channels = self.stage1_cfg['num_channels'][0]
        block_type = self.stage1_cfg['block']
        num_blocks = self.stage1_cfg['num_blocks'][0]
        block = self.blocks_dict[block_type]
        stage1_out_channels = num_channels * block.expansion
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)

        # stage 2
        self.stage2_cfg = self.extra['stage2']
        num_channels = self.stage2_cfg['num_channels']
        block_type = self.stage2_cfg['block']
        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition1 = self._make_transition_layer([stage1_out_channels],
                                                       num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        # stage 3
        self.stage3_cfg = self.extra['stage3']
        num_channels = self.stage3_cfg['num_channels']
        block_type = self.stage3_cfg['block']
        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        # stage 4
        self.stage4_cfg = self.extra['stage4']
        num_channels = self.stage4_cfg['num_channels']
        block_type = self.stage4_cfg['block']
        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Make transition layer.

        Entry ``i`` is None when branch ``i`` exists in both stages with the
        same width. It is a 3x3 stride-1 conv when only the width changes.
        For a new branch it is a chain of 3x3 stride-2 convs applied to the
        last previous branch.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)
        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    @staticmethod
    def _apply_transition(transition, inputs, num_branches):
        """Build the input list of the next stage from previous outputs.

        Args:
            transition (nn.ModuleList): Output of ``_make_transition_layer``.
            inputs (list): Outputs of the previous stage, one per branch.
            num_branches (int): Number of branches of the next stage.

        Returns:
            list: One input tensor per branch of the next stage.
        """
        outputs = []
        for i in range(num_branches):
            if transition[i] is None:
                outputs.append(inputs[i])
            elif i < len(inputs):
                # Existing branch whose width changed: it must be fed from
                # its own previous output (same resolution and channels), not
                # from the last branch.
                outputs.append(transition[i](inputs[i]))
            else:
                # New branch: downsample the last previous branch.
                outputs.append(transition[i](inputs[-1]))
        return outputs

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Make each layer."""
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        layers = []
        layers.append(
            block(
                inplanes,
                planes,
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg))
        inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    inplanes,
                    planes,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
        """Make each stage.

        Returns:
            tuple: (nn.Sequential of HRModules, in_channels). ``in_channels``
            is the same list, updated in place by the HRModules to the
            stage's output channels.
        """
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]

        hr_modules = []
        for i in range(num_modules):
            # multi_scale_output is only used for the last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            hr_modules.append(
                HRModule(
                    num_branches,
                    block,
                    num_blocks,
                    in_channels,
                    num_channels,
                    reset_multiscale_output,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg))

        return nn.Sequential(*hr_modules), in_channels

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function.

        Returns:
            list: Outputs of all branches of stage 4, in branch order.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = self._apply_transition(self.transition1, [x],
                                        self.stage2_cfg['num_branches'])
        y_list = self.stage2(x_list)

        x_list = self._apply_transition(self.transition2, y_list,
                                        self.stage3_cfg['num_branches'])
        y_list = self.stage3(x_list)

        x_list = self._apply_transition(self.transition3, y_list,
                                        self.stage4_cfg['num_branches'])
        y_list = self.stage4(x_list)

        return y_list

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen when ``norm_eval`` is set."""
        super(HRNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/mobilenet_v2.py
================================================
import logging
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
from ..utils import InvertedResidual, make_divisible
@BACKBONES.register_module()
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone.

    Args:
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
        strides (Sequence[int], optional): Strides of the first block of each
            layer. Default: (1, 2, 2, 2, 1, 2, 1).
        dilations (Sequence[int]): Dilation of the first block of each layer
            (later blocks in a layer use dilation 1).
            Default: (1, 1, 1, 1, 1, 1, 1).
        out_indices (None or Sequence[int]): Output from which stages; each
            index must be in range(0, 7). Default: (1, 2, 4, 6).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Raises:
        ValueError: If an ``out_indices`` entry is outside range(0, 7) or
            ``frozen_stages`` is outside range(-1, 7).
    """

    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: expand_ratio, channel, num_blocks.
    arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
                     [6, 96, 3], [6, 160, 3], [6, 320, 1]]

    def __init__(self,
                 widen_factor=1.,
                 strides=(1, 2, 2, 2, 1, 2, 1),
                 dilations=(1, 1, 1, 1, 1, 1, 1),
                 out_indices=(1, 2, 4, 6),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 norm_eval=False,
                 with_cp=False):
        super(MobileNetV2, self).__init__()
        self.widen_factor = widen_factor
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == len(self.arch_settings)
        # There are 7 InvertedResidual layers, indexed 0..6.
        for index in out_indices:
            if index not in range(0, 7):
                raise ValueError('the item in out_indices must in '
                                 f'range(0, 7). But received {index}')

        if frozen_stages not in range(-1, 7):
            raise ValueError('frozen_stages must be in range(-1, 7). '
                             f'But received {frozen_stages}')
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        self.in_channels = make_divisible(32 * widen_factor, 8)

        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.layers = []

        for i, layer_cfg in enumerate(self.arch_settings):
            expand_ratio, channel, num_blocks = layer_cfg
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = make_divisible(channel * widen_factor, 8)
            inverted_res_layer = self.make_layer(
                out_channels=out_channels,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                expand_ratio=expand_ratio)
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            self.layers.append(layer_name)

    def make_layer(self, out_channels, num_blocks, stride, dilation,
                   expand_ratio):
        """Stack InvertedResidual blocks to build a layer for MobileNetV2.

        Side effect: ``self.in_channels`` is advanced to ``out_channels``.

        Args:
            out_channels (int): out_channels of block.
            num_blocks (int): Number of blocks.
            stride (int): Stride of the first block.
            dilation (int): Dilation of the first block.
            expand_ratio (int): Expand the number of channels of the
                hidden layer in InvertedResidual by this ratio.
        """
        layers = []
        for i in range(num_blocks):
            layers.append(
                InvertedResidual(
                    self.in_channels,
                    out_channels,
                    stride if i == 0 else 1,
                    expand_ratio=expand_ratio,
                    dilation=dilation if i == 0 else 1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            self.in_channels = out_channels

        return nn.Sequential(*layers)

    def init_weights(self, pretrained=None):
        """Load ``pretrained`` if it is a path, otherwise initialize.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Return the selected layer outputs.

        A single tensor is returned when exactly one output is selected,
        otherwise a tuple.
        """
        x = self.conv1(x)

        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def _freeze_stages(self):
        # Freeze the stem plus layer1..layer{frozen_stages}.
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
        for i in range(1, self.frozen_stages + 1):
            layer = getattr(self, f'layer{i}')
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super(MobileNetV2, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/mobilenet_v3.py
================================================
import logging
import mmcv
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.cnn.bricks import Conv2dAdaptivePadding
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
from ..utils import InvertedResidualV3 as InvertedResidual
@BACKBONES.register_module()
class MobileNetV3(nn.Module):
    """MobileNetV3 backbone.

    This backbone is the improved implementation of `Searching for MobileNetV3
    <https://arxiv.org/abs/1905.02244>`_.

    Args:
        arch (str): Architecture of MobileNetV3, from {'small', 'large'}.
            Default: 'small'.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        out_indices (tuple[int]): Output from which layer.
            Default: (0, 1, 12).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        reduction_factor (int): Positive integer dividing the mid/out
            channels of the last blocks (index >= 12 for 'large', >= 8 for
            'small'). Default: 1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """
    # Parameters to build each block:
    #     [kernel size, mid channels, out channels, with_se, act type, stride]
    arch_settings = {
        'small': [[3, 16, 16, True, 'ReLU', 2],  # block0 layer1 os=4
                  [3, 72, 24, False, 'ReLU', 2],  # block1 layer2 os=8
                  [3, 88, 24, False, 'ReLU', 1],
                  [5, 96, 40, True, 'HSwish', 2],  # block2 layer4 os=16
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 120, 48, True, 'HSwish', 1],  # block3 layer7 os=16
                  [5, 144, 48, True, 'HSwish', 1],
                  [5, 288, 96, True, 'HSwish', 2],  # block4 layer9 os=32
                  [5, 576, 96, True, 'HSwish', 1],
                  [5, 576, 96, True, 'HSwish', 1]],
        'large': [[3, 16, 16, False, 'ReLU', 1],  # block0 layer1 os=2
                  [3, 64, 24, False, 'ReLU', 2],  # block1 layer2 os=4
                  [3, 72, 24, False, 'ReLU', 1],
                  [5, 72, 40, True, 'ReLU', 2],  # block2 layer4 os=8
                  [5, 120, 40, True, 'ReLU', 1],
                  [5, 120, 40, True, 'ReLU', 1],
                  [3, 240, 80, False, 'HSwish', 2],  # block3 layer7 os=16
                  [3, 200, 80, False, 'HSwish', 1],
                  [3, 184, 80, False, 'HSwish', 1],
                  [3, 184, 80, False, 'HSwish', 1],
                  [3, 480, 112, True, 'HSwish', 1],  # block4 layer11 os=16
                  [3, 672, 112, True, 'HSwish', 1],
                  [5, 672, 160, True, 'HSwish', 2],  # block5 layer13 os=32
                  [5, 960, 160, True, 'HSwish', 1],
                  [5, 960, 160, True, 'HSwish', 1]]
    }  # yapf: disable

    def __init__(self,
                 arch='small',
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 out_indices=(0, 1, 12),
                 frozen_stages=-1,
                 reduction_factor=1,
                 norm_eval=False,
                 with_cp=False):
        super(MobileNetV3, self).__init__()
        assert arch in self.arch_settings
        assert isinstance(reduction_factor, int) and reduction_factor > 0
        assert mmcv.is_tuple_of(out_indices, int)
        for index in out_indices:
            if index not in range(0, len(self.arch_settings[arch]) + 2):
                raise ValueError(
                    'the item in out_indices must in '
                    f'range(0, {len(self.arch_settings[arch])+2}). '
                    f'But received {index}')

        if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
            raise ValueError('frozen_stages must be in range(-1, '
                             f'{len(self.arch_settings[arch])+2}). '
                             f'But received {frozen_stages}')
        self.arch = arch
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.reduction_factor = reduction_factor
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.layers = self._make_layer()

    def _make_layer(self):
        """Register layer0..layerN and return their names in order."""
        layers = []

        # build the first layer (layer0)
        in_channels = 16
        layer = ConvModule(
            in_channels=3,
            out_channels=in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=dict(type='Conv2dAdaptivePadding'),
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='HSwish'))
        self.add_module('layer0', layer)
        layers.append('layer0')

        layer_setting = self.arch_settings[self.arch]
        for i, params in enumerate(layer_setting):
            (kernel_size, mid_channels, out_channels, with_se, act,
             stride) = params

            if (self.arch == 'large' and i >= 12) or \
                    (self.arch == 'small' and i >= 8):
                mid_channels = mid_channels // self.reduction_factor
                out_channels = out_channels // self.reduction_factor

            if with_se:
                se_cfg = dict(
                    channels=mid_channels,
                    ratio=4,
                    act_cfg=(dict(type='ReLU'),
                             dict(type='HSigmoid', bias=3.0, divisor=6.0)))
            else:
                se_cfg = None

            layer = InvertedResidual(
                in_channels=in_channels,
                out_channels=out_channels,
                mid_channels=mid_channels,
                kernel_size=kernel_size,
                stride=stride,
                se_cfg=se_cfg,
                with_expand_conv=(in_channels != mid_channels),
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type=act),
                with_cp=self.with_cp)
            in_channels = out_channels
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, layer)
            layers.append(layer_name)

        # build the last layer
        # block5 layer12 os=32 for small model
        # block6 layer16 os=32 for large model
        layer = ConvModule(
            in_channels=in_channels,
            out_channels=576 if self.arch == 'small' else 960,
            kernel_size=1,
            stride=1,
            dilation=4,
            padding=0,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='HSwish'))
        layer_name = 'layer{}'.format(len(layer_setting) + 1)
        self.add_module(layer_name, layer)
        layers.append(layer_name)

        # next, convert backbone MobileNetV3 to a semantic segmentation version
        if self.arch == 'small':
            self._dilate_layers(layers, first_dilated=4, second_dilated=9)
        else:
            self._dilate_layers(layers, first_dilated=7, second_dilated=13)

        return layers

    def _dilate_layers(self, layers, first_dilated, second_dilated):
        """Replace the last two stride-2 stages with dilated convolutions.

        ``layer{first_dilated}`` and ``layer{second_dilated}`` get stride 1.
        Every layer from ``first_dilated`` on is given dilation 2, or 4 from
        ``second_dilated`` on. Padding is rescaled to match unless the conv
        is a Conv2dAdaptivePadding, which pads itself.

        Args:
            layers (list[str]): Layer names in order (``layer{i}`` at i).
            first_dilated (int): Index of the first stride-2 layer to dilate.
            second_dilated (int): Index of the second stride-2 layer.
        """
        for idx in (first_dilated, second_dilated):
            getattr(self, f'layer{idx}').depthwise_conv.conv.stride = (1, 1)

        for i in range(first_dilated, len(layers)):
            layer = getattr(self, layers[i])
            if isinstance(layer, InvertedResidual):
                modified_module = layer.depthwise_conv.conv
            else:
                modified_module = layer.conv

            if i < second_dilated:
                modified_module.dilation = (2, 2)
                pad = 2
            else:
                modified_module.dilation = (4, 4)
                pad = 4

            if not isinstance(modified_module, Conv2dAdaptivePadding):
                # Adjust padding
                pad *= (modified_module.kernel_size[0] - 1) // 2
                modified_module.padding = (pad, pad)

    def init_weights(self, pretrained=None):
        """Load ``pretrained`` if it is a path, otherwise initialize.

        Uses ``_BatchNorm`` (like ``train`` and the sibling backbones) so
        SyncBN and other BN variants are initialized too.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Return a list with the outputs of the layers in ``out_indices``."""
        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)
        return outs

    def _freeze_stages(self):
        # Freeze layer0..layer{frozen_stages}; a no-op for frozen_stages=-1.
        for i in range(self.frozen_stages + 1):
            layer = getattr(self, f'layer{i}')
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super(MobileNetV3, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/resnest.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNetV1d
class RSoftmax(nn.Module):
    """Radix Softmax module in ``SplitAttentionConv2d``.

    With ``radix > 1`` the input is viewed as (batch, groups, radix, -1),
    softmax-normalized across the radix axis and flattened back to
    (batch, -1). With ``radix <= 1`` a plain sigmoid is applied and the
    input shape is kept.

    Args:
        radix (int): Radix of input.
        groups (int): Groups of input.
    """

    def __init__(self, radix, groups):
        super().__init__()
        self.radix = radix
        self.groups = groups

    def forward(self, x):
        if self.radix <= 1:
            return torch.sigmoid(x)
        n = x.size(0)
        split = x.view(n, self.groups, self.radix, -1)
        # Move the radix axis to dim 1 so softmax normalizes across splits.
        attn = F.softmax(split.transpose(1, 2), dim=1)
        return attn.reshape(n, -1)
class SplitAttentionConv2d(nn.Module):
    """Split-Attention Conv2d in ResNeSt.

    Args:
        in_channels (int): Same as nn.Conv2d.
        channels (int): Number of output channels (``out_channels`` of
            nn.Conv2d).
        kernel_size (int | tuple[int]): Same as nn.Conv2d.
        stride (int | tuple[int]): Same as nn.Conv2d.
        padding (int | tuple[int]): Same as nn.Conv2d.
        dilation (int | tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels. Default: 4.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        dcn (dict): Config dict for DCN. Default: None. The optional key
            ``fallback_on_stride`` selects a plain conv instead of DCN when
            true. The dict passed in is never modified.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 radix=2,
                 reduction_factor=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None):
        super(SplitAttentionConv2d, self).__init__()
        inter_channels = max(in_channels * radix // reduction_factor, 32)
        self.radix = radix
        self.groups = groups
        self.channels = channels
        self.with_dcn = dcn is not None
        self.dcn = dcn
        fallback_on_stride = False
        dcn_cfg = None
        if self.with_dcn:
            # Work on a copy: the same dcn dict is normally shared by every
            # block of the backbone. Popping in place would hide
            # ``fallback_on_stride`` from all blocks built after this one.
            dcn_cfg = dict(self.dcn)
            fallback_on_stride = dcn_cfg.pop('fallback_on_stride', False)
        if self.with_dcn and not fallback_on_stride:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            conv_cfg = dcn_cfg
        # One conv produces all ``radix`` splits at once.
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            channels * radix,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups * radix,
            bias=False)
        self.norm0_name, norm0 = build_norm_layer(
            norm_cfg, channels * radix, postfix=0)
        self.add_module(self.norm0_name, norm0)
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = build_conv_layer(
            None, channels, inter_channels, 1, groups=self.groups)
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, inter_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.fc2 = build_conv_layer(
            None, inter_channels, channels * radix, 1, groups=self.groups)
        self.rsoftmax = RSoftmax(radix, groups)

    @property
    def norm0(self):
        """nn.Module: the normalization layer named "norm0" """
        return getattr(self, self.norm0_name)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm0(x)
        x = self.relu(x)
        batch = x.size(0)
        if self.radix > 1:
            # (batch, radix * channels, H, W) -> (batch, radix, channels, H, W)
            splits = x.view(batch, self.radix, -1, *x.shape[2:])
            gap = splits.sum(dim=1)
        else:
            gap = x
        # Global context -> per-split attention weights.
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)
        gap = self.norm1(gap)
        gap = self.relu(gap)
        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
        if self.radix > 1:
            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
            out = torch.sum(attens * splits, dim=1)
        else:
            out = atten * x
        return out.contiguous()
class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeSt.

    The 3x3 conv of the parent ResNet ``Bottleneck`` is replaced by a
    ``SplitAttentionConv2d``.

    Args:
        inplanes (int): Input planes of this block.
        planes (int): Middle planes of this block.
        groups (int): Groups of conv2.
        base_width (int): Width per group of conv2. 64x4d indicates
            ``groups=64, width_per_group=4`` and 32x8d indicates
            ``groups=32, width_per_group=8``.
        base_channels (int): Base channels used to scale the grouped width
            (``floor(planes * base_width / base_channels) * groups``).
            Default: 64.
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Key word arguments for base class.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        """Bottleneck block for ResNeSt."""
        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
        if groups == 1:
            width = self.planes
        else:
            width = math.floor(self.planes *
                               (base_width / base_channels)) * groups
        # The avg-pool stride is only used when this block actually strides.
        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
        # Rebuild norm1/norm3/conv1/conv3 at the (possibly grouped) width;
        # the parent constructor built them at ``planes``.
        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width, postfix=1)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.planes * self.expansion, postfix=3)
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.with_modulated_dcn = False
        # When avg_down_stride is set, conv2 runs at stride 1 and the
        # spatial reduction is done by ``avd_layer`` after it.
        self.conv2 = SplitAttentionConv2d(
            width,
            width,
            kernel_size=3,
            stride=1 if self.avg_down_stride else self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            radix=radix,
            reduction_factor=reduction_factor,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dcn=self.dcn)
        # SplitAttentionConv2d applies its own norm + ReLU internally, so the
        # inherited norm2 is removed (forward() never calls it).
        delattr(self, self.norm2_name)
        if self.avg_down_stride:
            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width,
            self.planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

    def forward(self, x):
        """Forward function; unlike the parent, there is no norm2/ReLU after
        conv2."""

        def _inner_forward(x):
            identity = x
            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)
            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)
            out = self.conv2(out)
            if self.avg_down_stride:
                out = self.avd_layer(out)
            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)
            out = self.conv3(out)
            out = self.norm3(out)
            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)
            if self.downsample is not None:
                identity = self.downsample(x)
            out += identity
            return out

        # Gradient checkpointing only when requested and gradients flow.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)
        out = self.relu(out)
        return out
@BACKBONES.register_module()
class ResNeSt(ResNetV1d):
    """ResNeSt backbone: a ResNetV1d whose blocks are split-attention
    bottlenecks.

    Args:
        groups (int): Number of groups of Bottleneck. Default: 1
        base_width (int): Base width of Bottleneck. Default: 4
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Keyword arguments for ResNet.
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
        200: (Bottleneck, (3, 24, 36, 3))
    }

    def __init__(self,
                 groups=1,
                 base_width=4,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        # Stored before the parent constructor runs, because that
        # constructor calls ``make_res_layer``, which reads these fields.
        self.groups = groups
        self.base_width = base_width
        self.radix = radix
        self.reduction_factor = reduction_factor
        self.avg_down_stride = avg_down_stride
        super(ResNeSt, self).__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        block_opts = dict(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels,
            radix=self.radix,
            reduction_factor=self.reduction_factor,
            avg_down_stride=self.avg_down_stride)
        return ResLayer(**block_opts, **kwargs)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/resnet.py
================================================
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import ResLayer
class BasicBlock(nn.Module):
    """Basic residual block for ResNet.

    Two 3x3 convs (each followed by a norm layer, ReLU after the first) plus
    an identity or ``downsample`` shortcut; ReLU is applied after the sum.

    Args:
        inplanes (int): Input channels.
        planes (int): Output channels of both convs.
        stride (int): Stride of the first conv. Default: 1.
        dilation (int): Dilation (and padding) of the first conv; the second
            conv always uses padding 1. Default: 1.
        downsample (nn.Module | None): Shortcut transform. Default: None.
        style (str): Accepted for interface compatibility; not used here.
        with_cp (bool): Use gradient checkpointing. Default: False.
        conv_cfg (dict | None): Config for conv layers. Default: None.
        norm_cfg (dict): Config for norm layers. Default: dict(type='BN').
        dcn: Must be None (not implemented).
        plugins: Must be None (not implemented).
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None):
        super(BasicBlock, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.norm1_name, first_norm = build_norm_layer(
            norm_cfg, planes, postfix=1)
        self.norm2_name, second_norm = build_norm_layer(
            norm_cfg, planes, postfix=2)

        self.conv1 = build_conv_layer(
            conv_cfg, inplanes, planes, 3,
            stride=stride, padding=dilation, dilation=dilation, bias=False)
        self.add_module(self.norm1_name, first_norm)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, second_norm)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: normalization layer after the second convolution layer"""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Forward function."""

        def _block(inp):
            y = self.relu(self.norm1(self.conv1(inp)))
            y = self.norm2(self.conv2(y))
            shortcut = inp if self.downsample is None else self.downsample(inp)
            y += shortcut
            return y

        checkpointed = self.with_cp and x.requires_grad
        y = cp.checkpoint(_block, x) if checkpointed else _block(x)
        return self.relu(y)
class Bottleneck(nn.Module):
    """Bottleneck block for ResNet.

    If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
    "caffe", the stride-two layer is the first 1x1 conv layer.

    Args:
        inplanes (int): Input channels of the block.
        planes (int): Channels of the middle convs; the block outputs
            ``planes * expansion`` channels.
        stride (int): Stride of the block. Default: 1.
        dilation (int): Dilation (and padding) of the 3x3 conv. Default: 1.
        downsample (nn.Module | None): Shortcut transform. Default: None.
        style (str): 'pytorch' or 'caffe'. Default: 'pytorch'.
        with_cp (bool): Use gradient checkpointing. Default: False.
        conv_cfg (dict | None): Config for conv layers. Default: None.
        norm_cfg (dict): Config for norm layers. Default: dict(type='BN').
        dcn (dict | None): Config for a deformable 3x3 conv. If its optional
            ``fallback_on_stride`` key is true, a plain conv is built
            instead. The dict passed in is never modified. Default: None.
        plugins (list[dict] | None): Plugins to insert; each dict has ``cfg``
            and ``position`` in {'after_conv1', 'after_conv2',
            'after_conv3'}. Default: None.
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None):
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        assert plugins is None or isinstance(plugins, list)
        if plugins is not None:
            allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
            assert all(p['position'] in allowed_position for p in plugins)

        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
        self.plugins = plugins
        self.with_plugins = plugins is not None

        if self.with_plugins:
            # collect plugins for conv1/conv2/conv3
            self.after_conv1_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv1'
            ]
            self.after_conv2_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv2'
            ]
            self.after_conv3_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv3'
            ]

        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        dcn_cfg = None
        if self.with_dcn:
            # Work on a copy: the same dcn dict is normally shared by every
            # block of the backbone. Popping in place would hide
            # ``fallback_on_stride`` from all blocks built after this one.
            dcn_cfg = dict(dcn)
            fallback_on_stride = dcn_cfg.pop('fallback_on_stride', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            self.conv2 = build_conv_layer(
                dcn_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)

        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        if self.with_plugins:
            self.after_conv1_plugin_names = self.make_block_plugins(
                planes, self.after_conv1_plugins)
            self.after_conv2_plugin_names = self.make_block_plugins(
                planes, self.after_conv2_plugins)
            self.after_conv3_plugin_names = self.make_block_plugins(
                planes * self.expansion, self.after_conv3_plugins)

    def make_block_plugins(self, in_channels, plugins):
        """make plugins for block.

        Args:
            in_channels (int): Input channels of plugin.
            plugins (list[dict]): List of plugins cfg to build.

        Returns:
            list[str]: List of the names of plugin.
        """
        assert isinstance(plugins, list)
        plugin_names = []
        for plugin in plugins:
            plugin = plugin.copy()
            name, layer = build_plugin_layer(
                plugin,
                in_channels=in_channels,
                postfix=plugin.pop('postfix', ''))
            assert not hasattr(self, name), f'duplicate plugin {name}'
            self.add_module(name, layer)
            plugin_names.append(name)
        return plugin_names

    def forward_plugin(self, x, plugin_names):
        """Apply the named plugins to ``x`` in sequence.

        Each plugin consumes the previous plugin's output, so the plugins
        form a chain (e.g. conv3 -> yyy -> zzz1 -> zzz2). An empty
        ``plugin_names`` returns ``x`` unchanged.
        """
        out = x
        for name in plugin_names:
            # Chain on ``out``. Feeding ``x`` to every plugin would discard
            # all but the last plugin's result.
            out = getattr(self, name)(out)
        return out

    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: normalization layer after the second convolution layer"""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """nn.Module: normalization layer after the third convolution layer"""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            identity = x
            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)
            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)
            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)
            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)
            out = self.conv3(out)
            out = self.norm3(out)
            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)
            if self.downsample is not None:
                identity = self.downsample(x)
            out += identity
            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)
        out = self.relu(out)
        return out
@BACKBONES.register_module()
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Number of stem channels. Default: 64.
        base_channels (int): Number of base channels of res layer. Default: 64.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        conv_cfg (dict | None): Config dict for convolution layers.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        dcn (dict | None): Config dict for deformable convolution, used by
            the stages enabled in ``stage_with_dcn``. Default: None.
        stage_with_dcn (Sequence[bool]): Per-stage switch for ``dcn``; must
            have ``num_stages`` entries when ``dcn`` is set.
        plugins (list[dict]): List of plugins for stages, each dict contains:

            - cfg (dict, required): Cfg dict to build plugin.

            - position (str, required): Position inside block to insert plugin,
            options: 'after_conv1', 'after_conv2', 'after_conv3'.

            - stages (tuple[bool], optional): Stages to apply plugin, length
            should be same as 'num_stages'
        multi_grid (Sequence[int]|None): Multi grid dilation rates of last
            stage. Default: None
        contract_dilation (bool): Whether contract first dilation of each layer
            Default: False
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity.

    Example:
        >>> from mmseg.models import ResNet
        >>> import torch
        >>> self = ResNet(depth=18)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 8, 8)
        (1, 128, 4, 4)
        (1, 256, 2, 2)
        (1, 512, 1, 1)
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 stem_channels=64,
                 base_channels=64,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 deep_stem=False,
                 avg_down=False,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 plugins=None,
                 multi_grid=None,
                 contract_dilation=False,
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        self.depth = depth
        self.stem_channels = stem_channels
        self.base_channels = base_channels
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.deep_stem = deep_stem
        self.avg_down = avg_down
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.plugins = plugins
        self.multi_grid = multi_grid
        self.contract_dilation = contract_dilation
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = stem_channels

        self._make_stem_layer(in_channels, stem_channels)

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            if plugins is not None:
                stage_plugins = self.make_stage_plugins(plugins, i)
            else:
                stage_plugins = None
            # multi grid is applied to last layer only
            stage_multi_grid = multi_grid if i == len(
                self.stage_blocks) - 1 else None
            planes = base_channels * 2**i
            res_layer = self.make_res_layer(
                block=self.block,
                inplanes=self.inplanes,
                planes=planes,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                avg_down=self.avg_down,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                plugins=stage_plugins,
                multi_grid=stage_multi_grid,
                contract_dilation=contract_dilation)
            self.inplanes = planes * self.block.expansion
            layer_name = f'layer{i+1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = self.block.expansion * base_channels * 2**(
            len(self.stage_blocks) - 1)

    def make_stage_plugins(self, plugins, stage_idx):
        """make plugins for ResNet 'stage_idx'th stage .

        Currently we support to insert 'context_block',
        'empirical_attention_block', 'nonlocal_block' into the backbone like
        ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
        Bottleneck.

        An example of plugins format could be :
        >>> plugins=[
        ...     dict(cfg=dict(type='xxx', arg1='xxx'),
        ...          stages=(False, True, True, True),
        ...          position='after_conv2'),
        ...     dict(cfg=dict(type='yyy'),
        ...          stages=(True, True, True, True),
        ...          position='after_conv3'),
        ...     dict(cfg=dict(type='zzz', postfix='1'),
        ...          stages=(True, True, True, True),
        ...          position='after_conv3'),
        ...     dict(cfg=dict(type='zzz', postfix='2'),
        ...          stages=(True, True, True, True),
        ...          position='after_conv3')
        ... ]
        >>> self = ResNet(depth=18)
        >>> stage_plugins = self.make_stage_plugins(plugins, 0)
        >>> assert len(stage_plugins) == 3

        Suppose 'stage_idx=0', the structure of blocks in the stage would be:
            conv1-> conv2->conv3->yyy->zzz1->zzz2
        Suppose 'stage_idx=1', the structure of blocks in the stage would be:
            conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2

        If stages is missing, the plugin would be applied to all stages.

        Args:
            plugins (list[dict]): List of plugins cfg to build. The postfix is
                required if multiple same type plugins are inserted.
            stage_idx (int): Index of stage to build

        Returns:
            list[dict]: Plugins for current stage
        """
        stage_plugins = []
        for plugin in plugins:
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            assert stages is None or len(stages) == self.num_stages
            # whether to insert plugin into current stage
            if stages is None or stages[stage_idx]:
                stage_plugins.append(plugin)

        return stage_plugins

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        return ResLayer(**kwargs)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self, in_channels, stem_channels):
        """Make stem layer for ResNet."""
        if self.deep_stem:
            self.stem = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    in_channels,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
                nn.ReLU(inplace=True),
                build_conv_layer(
                    self.conv_cfg,
                    stem_channels // 2,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
                nn.ReLU(inplace=True),
                build_conv_layer(
                    self.conv_cfg,
                    stem_channels // 2,
                    stem_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels)[1],
                nn.ReLU(inplace=True))
        else:
            self.conv1 = build_conv_layer(
                self.conv_cfg,
                in_channels,
                stem_channels,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False)
            self.norm1_name, norm1 = build_norm_layer(
                self.norm_cfg, stem_channels, postfix=1)
            self.add_module(self.norm1_name, norm1)
            self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Freeze stages param and norm stats."""
        if self.frozen_stages >= 0:
            if self.deep_stem:
                self.stem.eval()
                for param in self.stem.parameters():
                    param.requires_grad = False
            else:
                self.norm1.eval()
                for m in [self.conv1, self.norm1]:
                    for param in m.parameters():
                        param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""
        if self.deep_stem:
            x = self.stem(x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keeping normalization
        layers frozen.

        Args:
            mode (bool): True for training mode, False for evaluation mode.
                Default: True.

        Returns:
            ResNet: ``self``. This matches the ``nn.Module.train`` contract,
            so calls such as ``model.train().cuda()`` keep working.
        """
        super(ResNet, self).train(mode)
        # The parent call sets ``mode`` recursively, so frozen stages must
        # be put back into eval mode.
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self
@BACKBONES.register_module()
class ResNetV1c(ResNet):
    """ResNetV1c variant described in [1]_.

    Identical to the default ResNet (ResNetV1b) except that the 7x7 stem conv
    is replaced by three 3x3 convs (``deep_stem=True``); downsampling still
    uses strided convs (``avg_down=False``).

    References:
        .. [1] https://arxiv.org/pdf/1812.01187.pdf
    """

    def __init__(self, **kwargs):
        variant = dict(deep_stem=True, avg_down=False)
        super(ResNetV1c, self).__init__(**variant, **kwargs)
@BACKBONES.register_module()
class ResNetV1d(ResNet):
    """ResNetV1d variant described in [1]_.

    Compared with the default ResNet (ResNetV1b), ResNetV1d replaces the 7x7
    stem conv with three 3x3 convs (``deep_stem=True``). In the downsampling
    block, a 2x2 avg_pool with stride 2 is added before the conv, whose
    stride is changed to 1 (``avg_down=True``).
    """

    def __init__(self, **kwargs):
        variant = dict(deep_stem=True, avg_down=True)
        super(ResNetV1d, self).__init__(**variant, **kwargs)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/resnext.py
================================================
import math
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeXt.

    If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
    "caffe", the stride-two layer is the first 1x1 conv layer.

    Args:
        inplanes (int): Input channels of the block.
        planes (int): Base planes; the block outputs ``planes * expansion``
            channels.
        groups (int): Groups of the 3x3 conv. Default: 1.
        base_width (int): Width per group of the 3x3 conv. Default: 4.
        base_channels (int): Base channels used to scale the grouped width.
            Default: 64.
        kwargs (dict): Keyword arguments for the ResNet ``Bottleneck``.
    """

    def __init__(self,
                 inplanes,
                 planes,
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 **kwargs):
        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)

        if groups == 1:
            width = self.planes
        else:
            # floor(planes * base_width / base_channels) channels per group.
            width = math.floor(self.planes *
                               (base_width / base_channels)) * groups

        # Rebuild the layers the parent built at ``planes`` width.
        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            self.norm_cfg, width, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        self.with_modulated_dcn = False
        dcn_cfg = None
        if self.with_dcn:
            # Work on a copy: ``self.dcn`` is normally one dict shared by
            # every block. Popping in place would hide
            # ``fallback_on_stride`` from all blocks built after this one.
            dcn_cfg = dict(self.dcn)
            fallback_on_stride = dcn_cfg.pop('fallback_on_stride', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                self.conv_cfg,
                width,
                width,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation,
                groups=groups,
                bias=False)
        else:
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            self.conv2 = build_conv_layer(
                dcn_cfg,
                width,
                width,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation,
                groups=groups,
                bias=False)

        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width,
            self.planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)
@BACKBONES.register_module()
class ResNeXt(ResNet):
    """ResNeXt backbone.

    Args:
        depth (int): Depth of resnext, from {50, 101, 152}.
        in_channels (int): Number of input image channels. Normally 3.
        num_stages (int): Resnet stages, normally 4.
        groups (int): Group of resnext (groups of each block's 3x3 conv).
            Default: 1.
        base_width (int): Base width of resnext (width per group).
            Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.

    Example:
        >>> from mmseg.models import ResNeXt
        >>> import torch
        >>> self = ResNeXt(depth=50)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 256, 8, 8)
        (1, 512, 4, 4)
        (1, 1024, 2, 2)
        (1, 2048, 1, 1)
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self, groups=1, base_width=4, **kwargs):
        # Set before the parent constructor, which calls make_res_layer().
        self.groups = groups
        self.base_width = base_width
        super(ResNeXt, self).__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``"""
        return ResLayer(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels,
            **kwargs)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/unet.py
================================================
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
build_norm_layer, constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import UpConvBlock
class BasicConvBlock(nn.Module):
    """Basic convolutional block for UNet: a stack of plain 3x3 ConvModules.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers. Default: 2.
        stride (int): Stride of the first convolutional layer only; every
            later layer uses stride 1. Options are 1 or 2. Default: 1.
        dilation (int): Dilation (and padding) of every layer except the
            first, which always uses dilation 1 and padding 1. Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Must be None (not implemented). Default: None.
        plugins (dict): Plugins for convolutional layers. Must be None (not
            implemented). Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dcn=None,
                 plugins=None):
        super(BasicConvBlock, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.with_cp = with_cp
        layers = [
            ConvModule(
                in_channels=out_channels if idx else in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1 if idx else stride,
                dilation=dilation if idx else 1,
                padding=dilation if idx else 1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for idx in range(num_convs)
        ]
        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Forward function."""
        if not (self.with_cp and x.requires_grad):
            return self.convs(x)
        return cp.checkpoint(self.convs, x)
@UPSAMPLE_LAYERS.register_module()
class DeconvModule(nn.Module):
    """Deconvolution upsample module in decoder for UNet (2X upsample).

    Upsamples with ``ConvTranspose2d -> norm -> activation``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of the convolutional layer. Default: 4.
        scale_factor (int): Upsampling factor, used as the deconvolution
            stride. ``kernel_size - scale_factor`` must be a non-negative
            even number. Default: 2.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 kernel_size=4,
                 scale_factor=2):
        super(DeconvModule, self).__init__()

        surplus = kernel_size - scale_factor
        assert surplus >= 0 and surplus % 2 == 0, (
            'kernel_size should be greater than or equal to scale_factor '
            'and (kernel_size - scale_factor) should be even numbers, '
            f'while the kernel size is {kernel_size} and scale_factor is '
            f'{scale_factor}.')

        self.with_cp = with_cp
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=scale_factor,
            padding=surplus // 2)
        _, norm = build_norm_layer(norm_cfg, out_channels)
        activate = build_activation_layer(act_cfg)
        # Attribute name kept as-is: it is part of the state_dict keys.
        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)

    def forward(self, x):
        """Forward function."""
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(self.deconv_upsamping, x)
        return self.deconv_upsamping(x)
@UPSAMPLE_LAYERS.register_module()
class InterpConv(nn.Module):
    """Interpolation upsample module in decoder for UNet.

    This module uses interpolation to upsample feature map in the decoder
    of UNet. It consists of one interpolation upsample layer and one
    convolutional layer. It can be one interpolation upsample layer followed
    by one convolutional layer (conv_first=False) or one convolutional layer
    followed by one interpolation upsample layer (conv_first=True).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        conv_first (bool): Whether convolutional layer or interpolation
            upsample layer first. Default: False. It means interpolation
            upsample layer followed by one convolutional layer.
        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
        stride (int): Stride of the convolutional layer. Default: 1.
        padding (int): Padding of the convolutional layer. Default: 0.
        upsampe_cfg (dict): Interpolation config of the upsample layer; its
            items are passed as keyword arguments to ``nn.Upsample``. (The
            parameter name is misspelled but kept for backward
            compatibility.) Default: dict(
            scale_factor=2, mode='bilinear', align_corners=False).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 conv_cfg=None,
                 conv_first=False,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 upsampe_cfg=dict(
                     scale_factor=2, mode='bilinear', align_corners=False)):
        super(InterpConv, self).__init__()
        self.with_cp = with_cp
        conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # upsampe_cfg is only unpacked here, never mutated, so sharing the
        # default dict across instances is harmless.
        upsample = nn.Upsample(**upsampe_cfg)
        if conv_first:
            self.interp_upsample = nn.Sequential(conv, upsample)
        else:
            self.interp_upsample = nn.Sequential(upsample, conv)

    def forward(self, x):
        """Forward function.

        Uses gradient checkpointing only when ``with_cp`` is set and ``x``
        requires grad.
        """
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.interp_upsample, x)
        else:
            out = self.interp_upsample(x)
        return out
@BACKBONES.register_module()
class UNet(nn.Module):
    """UNet backbone.

    U-Net: Convolutional Networks for Biomedical Image Segmentation.
    https://arxiv.org/pdf/1505.04597.pdf

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Number of base channels of each stage.
            The output channels of the first stage. Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
            len(strides) is equal to num_stages. Normally the stride of the
            first stage in encoder is 1. If strides[i]=2, it uses stride
            convolution to downsample in the corresponding encoder stage.
            Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the corresponding encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the corresponding decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[bool]): Whether to use MaxPool to downsample
            the feature map after the first stage of encoder
            (stages: [1, num_stages)). If the corresponding encoder stage
            uses stride convolution (strides[i]=2), it will never use MaxPool
            to downsample, even if downsamples[i-1]=True.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (dict | None): Deformable convolution config. Not implemented
            yet; must be None. Default: None.
        plugins (dict | None): Plugins for convolutional layers. Not
            implemented yet; must be None. Default: None.

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in UNet._check_input_devisible.
    """

    def __init__(self,
                 in_channels=3,
                 base_channels=64,
                 num_stages=5,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 norm_eval=False,
                 dcn=None,
                 plugins=None):
        super(UNet, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert len(strides) == num_stages, \
            'The length of strides should be equal to num_stages, '\
            f'while the strides is {strides}, the length of '\
            f'strides is {len(strides)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_num_convs) == num_stages, \
            'The length of enc_num_convs should be equal to num_stages, '\
            f'while the enc_num_convs is {enc_num_convs}, the length of '\
            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_num_convs) == (num_stages-1), \
            'The length of dec_num_convs should be equal to (num_stages-1), '\
            f'while the dec_num_convs is {dec_num_convs}, the length of '\
            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(downsamples) == (num_stages-1), \
            'The length of downsamples should be equal to (num_stages-1), '\
            f'while the downsamples is {downsamples}, the length of '\
            f'downsamples is {len(downsamples)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_dilations) == num_stages, \
            'The length of enc_dilations should be equal to num_stages, '\
            f'while the enc_dilations is {enc_dilations}, the length of '\
            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_dilations) == (num_stages-1), \
            'The length of dec_dilations should be equal to (num_stages-1), '\
            f'while the dec_dilations is {dec_dilations}, the length of '\
            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(num_stages):
            enc_conv_block = []
            if i != 0:
                # MaxPool only when this stage does not already downsample
                # with a strided conv.
                if strides[i] == 1 and downsamples[i - 1]:
                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
                # The matching decoder stage upsamples iff this encoder
                # stage downsampled (by stride or MaxPool).
                upsample = (strides[i] != 1 or downsamples[i - 1])
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=base_channels * 2**i,
                        skip_channels=base_channels * 2**(i - 1),
                        out_channels=base_channels * 2**(i - 1),
                        num_convs=dec_num_convs[i - 1],
                        stride=1,
                        dilation=dec_dilations[i - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if upsample else None,
                        dcn=None,
                        plugins=None))

            enc_conv_block.append(
                BasicConvBlock(
                    in_channels=in_channels,
                    out_channels=base_channels * 2**i,
                    num_convs=enc_num_convs[i],
                    stride=strides[i],
                    dilation=enc_dilations[i],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append((nn.Sequential(*enc_conv_block)))
            in_channels = base_channels * 2**i

    def forward(self, x):
        """Run the encoder, then the decoder from the deepest stage upwards.

        Returns:
            list[Tensor]: The deepest encoder output followed by the output
                of each decoder stage, from deepest to shallowest.
        """
        self._check_input_devisible(x)
        enc_outs = []
        for enc in self.encoder:
            x = enc(x)
            enc_outs.append(x)
        dec_outs = [x]
        for i in reversed(range(len(self.decoder))):
            # decoder[i] fuses the skip feature enc_outs[i] with the
            # current (deeper) feature x.
            x = self.decoder[i](enc_outs[i], x)
            dec_outs.append(x)

        return dec_outs

    def train(self, mode=True):
        """Set train/eval mode, keeping BatchNorm layers frozen if requested.

        When ``mode`` is True and ``norm_eval`` is set, every ``_BatchNorm``
        submodule is put back into eval mode.

        Returns:
            UNet: ``self``, matching the ``nn.Module.train`` contract so
                calls such as ``model = model.train()`` keep working.
        """
        super(UNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    def _check_input_devisible(self, x):
        """Assert the spatial size of ``x`` is divisible by the encoder's
        total downsample rate (2 per stage that strides or max-pools)."""
        h, w = x.shape[-2:]
        whole_downsample_rate = 1
        for i in range(1, self.num_stages):
            if self.strides[i] == 2 or self.downsamples[i - 1]:
                whole_downsample_rate *= 2
        assert (h % whole_downsample_rate == 0) \
            and (w % whole_downsample_rate == 0),\
            f'The input image size {(h, w)} should be devisible by the whole '\
            f'downsample rate {whole_downsample_rate}, when num_stages is '\
            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
            f'is {self.downsamples}.'

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/builder.py
================================================
import warnings
from mmcv.utils import Registry, build_from_cfg
from torch import nn
# Module registries. Classes add themselves with a decorator such as
# ``@BACKBONES.register_module()`` or ``@HEADS.register_module()`` and are
# later instantiated from config dicts by the ``build_*`` helpers below.
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
HEADS = Registry('head')
LOSSES = Registry('loss')
SEGMENTORS = Registry('segmentor')
def build(cfg, registry, default_args=None):
    """Instantiate module(s) described by ``cfg`` from ``registry``.

    Args:
        cfg (dict, list[dict]): A single module config, or a list of configs.
        registry (:obj:`Registry`): The registry to look the type(s) up in.
        default_args (dict, optional): Default arguments forwarded to every
            built module. Defaults to None.

    Returns:
        nn.Module: The built module; for a list of configs, an
            ``nn.Sequential`` of the built modules in list order.
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)
    built = [build_from_cfg(item, registry, default_args) for item in cfg]
    return nn.Sequential(*built)
def build_backbone(cfg):
    """Build a backbone from ``cfg`` via the ``BACKBONES`` registry."""
    return build(cfg, registry=BACKBONES)
def build_neck(cfg):
    """Build a neck from ``cfg`` via the ``NECKS`` registry."""
    return build(cfg, registry=NECKS)
def build_head(cfg):
    """Build a head from ``cfg`` via the ``HEADS`` registry."""
    return build(cfg, registry=HEADS)
def build_loss(cfg):
    """Build a loss from ``cfg`` via the ``LOSSES`` registry."""
    return build(cfg, registry=LOSSES)
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
    """Build a segmentor from ``cfg`` via the ``SEGMENTORS`` registry.

    Passing ``train_cfg``/``test_cfg`` here is deprecated and emits a
    ``UserWarning``. Each may be given either here or inside ``cfg``, not
    both (checked with ``assert``). Both values are forwarded to the
    segmentor as default arguments.
    """
    outer_given = not (train_cfg is None and test_cfg is None)
    if outer_given:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '
    defaults = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return build(cfg, SEGMENTORS, defaults)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/__init__.py
================================================
from .ann_head import ANNHead
from .apc_head import APCHead
from .aspp_head import ASPPHead
from .cc_head import CCHead
from .da_head import DAHead
from .dm_head import DMHead
from .dnl_head import DNLHead
from .ema_head import EMAHead
from .enc_head import EncHead
from .fcn_head import FCNHead
from .fpn_head import FPNHead
from .gc_head import GCHead
from .lraspp_head import LRASPPHead
from .nl_head import NLHead
from .ocr_head import OCRHead
from .point_head import PointHead
from .psa_head import PSAHead
from .psp_head import PSPHead
from .sep_aspp_head import DepthwiseSeparableASPPHead
from .sep_fcn_head import DepthwiseSeparableFCNHead
from .uper_head import UPerHead
# Public API of this package: every decode-head class imported above.
__all__ = [
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead'
]
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/ann_head.py
================================================
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from ..builder import HEADS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead
class PPMConcat(nn.ModuleList):
    """Pyramid Pooling Module that only concatenates the pooled features.

    Each scale ``s`` adaptively average-pools the input to ``s x s``; the
    pooled maps are flattened spatially and concatenated on the last axis.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
    """

    def __init__(self, pool_scales=(1, 3, 6, 8)):
        pools = [nn.AdaptiveAvgPool2d(scale) for scale in pool_scales]
        super(PPMConcat, self).__init__(pools)

    def forward(self, feats):
        """Pool ``feats`` at every scale and concatenate.

        Returns:
            Tensor: Shape ``(feats.shape[0], feats.shape[1], sum(s * s))``.
        """
        batch, chans = feats.shape[:2]
        flattened = [pool(feats).view(batch, chans, -1) for pool in self]
        return torch.cat(flattened, dim=2)
class SelfAttentionBlock(_SelfAttentionBlock):
    """Self-attention block configured as used by ANN.

    Keys come from the low-level feature, pooled by a ``PPMConcat`` pyramid.
    Queries come from the high-level feature, max-pooled by ``query_scale``
    when it is greater than 1.

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether to share projection weights between
            key and query projection.
        query_scale (int): Downsample factor of the query feature map.
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, share_key_query, query_scale, key_pool_scales,
                 conv_cfg, norm_cfg, act_cfg):
        key_pyramid = PPMConcat(key_pool_scales)
        query_pool = (nn.MaxPool2d(kernel_size=query_scale)
                      if query_scale > 1 else None)
        super(SelfAttentionBlock, self).__init__(
            key_in_channels=low_in_channels,
            query_in_channels=high_in_channels,
            channels=channels,
            out_channels=out_channels,
            share_key_query=share_key_query,
            query_downsample=query_pool,
            key_downsample=key_pyramid,
            key_query_num_convs=1,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
class AFNB(nn.Module):
    """Asymmetric Fusion Non-local Block (AFNB).

    One ``SelfAttentionBlock`` per query scale attends from the high-level
    feature (query) to the low-level feature (key). The results are summed,
    concatenated with the high-level feature and fused by a 1x1 conv.

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, query_scales, key_pool_scales, conv_cfg,
                 norm_cfg, act_cfg):
        super(AFNB, self).__init__()
        self.stages = nn.ModuleList()
        for query_scale in query_scales:
            self.stages.append(
                SelfAttentionBlock(
                    low_in_channels=low_in_channels,
                    high_in_channels=high_in_channels,
                    channels=channels,
                    out_channels=out_channels,
                    share_key_query=False,
                    query_scale=query_scale,
                    key_pool_scales=key_pool_scales,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        # Input is cat([context (out_channels), high_feats (high_in_channels)])
        self.bottleneck = ConvModule(
            out_channels + high_in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, low_feats, high_feats):
        """Fuse ``low_feats`` into ``high_feats``.

        Each stage is called as ``stage(query, key)`` with the high-level
        feature as query and the low-level feature as key.
        """
        priors = [stage(high_feats, low_feats) for stage in self.stages]
        context = torch.stack(priors, dim=0).sum(dim=0)
        output = self.bottleneck(torch.cat([context, high_feats], 1))
        return output
class APNB(nn.Module):
    """Asymmetric Pyramid Non-local Block (APNB).

    One ``SelfAttentionBlock`` (shared key/query projection) per query scale
    attends over the same feature. The results are summed, concatenated with
    the input and fused by a 1x1 conv.

    Args:
        in_channels (int): Input channels of key/query feature,
            which is the key feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, in_channels, channels, out_channels, query_scales,
                 key_pool_scales, conv_cfg, norm_cfg, act_cfg):
        super(APNB, self).__init__()
        self.stages = nn.ModuleList()
        for query_scale in query_scales:
            self.stages.append(
                SelfAttentionBlock(
                    low_in_channels=in_channels,
                    high_in_channels=in_channels,
                    channels=channels,
                    out_channels=out_channels,
                    share_key_query=True,
                    query_scale=query_scale,
                    key_pool_scales=key_pool_scales,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        # The bottleneck sees cat([context, feats]): the context has
        # out_channels channels and feats has in_channels. The previous
        # `2 * in_channels` was only correct when the two are equal; the
        # result is unchanged for that case.
        self.bottleneck = ConvModule(
            out_channels + in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, feats):
        """Apply pyramid self-attention to ``feats`` and fuse with it."""
        priors = [stage(feats, feats) for stage in self.stages]
        context = torch.stack(priors, dim=0).sum(dim=0)
        output = self.bottleneck(torch.cat([context, feats], 1))
        return output
@HEADS.register_module()
class ANNHead(BaseDecodeHead):
    """Asymmetric Non-local Neural Networks for Semantic Segmentation.

    This head is the implementation of `ANNNet`. It takes exactly two
    inputs (low- and high-level features).
    1. It fuses them with an ``AFNB``.
    2. It projects the result to ``self.channels`` with a 3x3 conv.
    3. It refines that with an ``APNB`` before classification.

    Args:
        project_channels (int): Projection channels for Nonlocal.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): The pooling scales of key feature map.
            Default: (1, 3, 6, 8).
    """

    def __init__(self,
                 project_channels,
                 query_scales=(1, ),
                 key_pool_scales=(1, 3, 6, 8),
                 **kwargs):
        super(ANNHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        assert len(self.in_channels) == 2
        low_in_channels, high_in_channels = self.in_channels
        self.project_channels = project_channels
        self.fusion = AFNB(
            low_in_channels=low_in_channels,
            high_in_channels=high_in_channels,
            out_channels=high_in_channels,
            channels=project_channels,
            query_scales=query_scales,
            key_pool_scales=key_pool_scales,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            high_in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.context = APNB(
            in_channels=self.channels,
            out_channels=self.channels,
            channels=project_channels,
            query_scales=query_scales,
            key_pool_scales=key_pool_scales,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        low_feats, high_feats = self._transform_inputs(inputs)
        output = self.fusion(low_feats, high_feats)
        # ``self.dropout`` can be None (other heads guard it with
        # ``is not None``), so calling it unconditionally would raise
        # TypeError.
        if self.dropout is not None:
            output = self.dropout(output)
        output = self.bottleneck(output)
        output = self.context(output)
        output = self.cls_seg(output)
        return output
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/apc_head.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
class ACM(nn.Module):
    """Adaptive Context Module used in APCNet.

    Args:
        pool_scale (int): Pooling scale used in Adaptive Context
            Module to extract region features.
        fusion (bool): Add one conv to fuse residual feature.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
                 norm_cfg, act_cfg):
        super(ACM, self).__init__()
        self.pool_scale = pool_scale
        self.fusion = fusion
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.pooled_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.input_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.global_info = ConvModule(
            self.channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Predicts one affinity per pooled region (pool_scale**2 regions).
        self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)
        self.residual_conv = ConvModule(
            self.channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.fusion:
            self.fusion_conv = ConvModule(
                self.channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, x):
        """Forward function.

        The shape comments below assume ``x`` is a 4-D
        ``[batch_size, in_channels, h, w]`` map, as implied by the
        ``x.size(2)``/``x.size(3)`` reshape at the end.
        """
        pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
        # [batch_size, channels, h, w]
        x = self.input_redu_conv(x)
        # [batch_size, channels, pool_scale, pool_scale]
        pooled_x = self.pooled_redu_conv(pooled_x)
        batch_size = x.size(0)
        # [batch_size, pool_scale * pool_scale, channels]
        pooled_x = pooled_x.view(batch_size, self.channels,
                                 -1).permute(0, 2, 1).contiguous()
        # [batch_size, h * w, pool_scale * pool_scale]
        affinity_matrix = self.gla(x + resize(
            self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
                                   ).permute(0, 2, 3, 1).reshape(
                                       batch_size, -1, self.pool_scale**2)
        # torch.sigmoid is the canonical op; F.sigmoid is the legacy alias.
        affinity_matrix = torch.sigmoid(affinity_matrix)
        # [batch_size, h * w, channels]
        z_out = torch.matmul(affinity_matrix, pooled_x)
        # [batch_size, channels, h * w]
        z_out = z_out.permute(0, 2, 1).contiguous()
        # [batch_size, channels, h, w]
        z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
        z_out = self.residual_conv(z_out)
        z_out = F.relu(z_out + x)
        if self.fusion:
            z_out = self.fusion_conv(z_out)

        return z_out
@HEADS.register_module()
class APCHead(BaseDecodeHead):
    """Adaptive Pyramid Context Network for Semantic Segmentation.

    This head is the implementation of `APCNet`. One ``ACM`` runs per pool
    scale. Their outputs are concatenated with the input and fused by a
    3x3 bottleneck conv.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Adaptive Context
            Module. Default: (1, 2, 3, 6).
        fusion (bool): Add one conv to fuse residual feature.
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
        super(APCHead, self).__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        self.fusion = fusion
        layer_cfgs = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.acm_modules = nn.ModuleList([
            ACM(scale, self.fusion, self.in_channels, self.channels,
                **layer_cfgs) for scale in self.pool_scales
        ])
        fused_channels = self.in_channels + len(pool_scales) * self.channels
        self.bottleneck = ConvModule(
            fused_channels, self.channels, 3, padding=1, **layer_cfgs)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        branches = [x] + [acm(x) for acm in self.acm_modules]
        fused = self.bottleneck(torch.cat(branches, dim=1))
        return self.cls_seg(fused)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/aspp_head.py
================================================
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling (ASPP) Module.

    Holds one ``ConvModule`` per dilation rate. Rate 1 uses a 1x1 conv with
    no padding; any other rate ``d`` uses a 3x3 conv with dilation and
    padding ``d``.

    Args:
        dilations (tuple[int]): Dilation rate of each layer.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg):
        super(ASPPModule, self).__init__()
        self.dilations = dilations
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for rate in dilations:
            pointwise = rate == 1
            self.append(
                ConvModule(
                    self.in_channels,
                    self.channels,
                    1 if pointwise else 3,
                    dilation=rate,
                    padding=0 if pointwise else rate,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

    def forward(self, x):
        """Return a list with every branch applied to ``x``, in order."""
        return [branch(x) for branch in self]
@HEADS.register_module()
class ASPPHead(BaseDecodeHead):
    """Rethinking Atrous Convolution for Semantic Image Segmentation.

    This head is the implementation of `DeepLabV3`. It concatenates an
    image-level pooled branch (resized back to the input size) with the
    ASPP branches, then fuses them with a 3x3 bottleneck conv.

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
        super(ASPPHead, self).__init__(**kwargs)
        assert isinstance(dilations, (list, tuple))
        self.dilations = dilations
        layer_cfgs = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(self.in_channels, self.channels, 1, **layer_cfgs))
        self.aspp_modules = ASPPModule(
            dilations, self.in_channels, self.channels, **layer_cfgs)
        self.bottleneck = ConvModule(
            (len(dilations) + 1) * self.channels,
            self.channels,
            3,
            padding=1,
            **layer_cfgs)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        global_branch = resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        fused = torch.cat([global_branch, *self.aspp_modules(x)], dim=1)
        return self.cls_seg(self.bottleneck(fused))
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/cascade_decode_head.py
================================================
from abc import ABCMeta, abstractmethod
from .decode_head import BaseDecodeHead
class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
    """Base class for cascade decode heads used in
    :class:`CascadeEncoderDecoder`.

    Unlike a plain decode head, :meth:`forward` also receives the output of
    the previous decode head in the cascade.
    """

    def __init__(self, *args, **kwargs):
        super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs)

    @abstractmethod
    def forward(self, inputs, prev_output):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
                      train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not used by this method.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config. Not used by this method.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs, prev_output)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, prev_output, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not used by this method.
            test_cfg (dict): The testing config. Not used by this method.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs, prev_output)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/cc_head.py
================================================
import torch
from ..builder import HEADS
from .fcn_head import FCNHead
try:
from mmcv.ops import CrissCrossAttention
except ModuleNotFoundError:
CrissCrossAttention = None
@HEADS.register_module()
class CCHead(FCNHead):
    """CCNet: Criss-Cross Attention for Semantic Segmentation.

    This head is the implementation of `CCNet`: an ``FCNHead`` with exactly
    two convs, with a ``CrissCrossAttention`` module applied ``recurrence``
    times between them.

    Args:
        recurrence (int): Number of recurrence of Criss Cross Attention
            module. Default: 2.

    Raises:
        RuntimeError: If ``CrissCrossAttention`` could not be imported from
            ``mmcv.ops``.
    """

    def __init__(self, recurrence=2, **kwargs):
        if CrissCrossAttention is None:
            raise RuntimeError(
                'Please install mmcv-full for CrissCrossAttention ops')
        super(CCHead, self).__init__(num_convs=2, **kwargs)
        self.recurrence = recurrence
        self.cca = CrissCrossAttention(self.channels)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        feat = self.convs[0](x)
        for _ in range(self.recurrence):
            feat = self.cca(feat)
        feat = self.convs[1](feat)
        if self.concat_input:
            feat = self.conv_cat(torch.cat([x, feat], dim=1))
        return self.cls_seg(feat)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/da_head.py
================================================
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from torch import nn
from mmseg.core import add_prefix
from ..builder import HEADS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead
class PAM(_SelfAttentionBlock):
    """Position Attention Module (PAM).

    Plain self-attention of the feature map over itself. There is no
    norm/activation, no matmul normalisation and no output projection. The
    result is scaled by ``gamma`` (a ``Scale(0)`` module) and added back to
    the input.

    Args:
        in_channels (int): Input channels of key/query feature.
        channels (int): Output channels of key/query transform.
    """

    def __init__(self, in_channels, channels):
        super(PAM, self).__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=1,
            key_query_norm=False,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=False,
            with_out=False,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None)
        self.gamma = Scale(0)

    def forward(self, x):
        """Attend ``x`` to itself and add the scaled result as a residual."""
        attended = super(PAM, self).forward(x, x)
        return self.gamma(attended) + x
class CAM(nn.Module):
    """Channel Attention Module (CAM).

    Mixes channels using a softmax over channel-to-channel affinities. The
    mixed map is scaled by ``gamma`` (a ``Scale(0)`` module) and added back
    to the input.
    """

    def __init__(self):
        super(CAM, self).__init__()
        self.gamma = Scale(0)

    def forward(self, x):
        """Forward function; ``x`` is a 4-D ``(b, c, h, w)`` tensor."""
        b, c, h, w = x.size()
        flat = x.view(b, c, -1)
        # channel-by-channel affinity, shape (b, c, c)
        affinity = torch.bmm(flat, flat.permute(0, 2, 1))
        # subtract each entry from its row maximum before the softmax
        peak = torch.max(affinity, -1, keepdim=True)[0]
        weights = F.softmax(peak.expand_as(affinity) - affinity, dim=-1)
        mixed = torch.bmm(weights, flat).view(b, c, h, w)
        return self.gamma(mixed) + x
@HEADS.register_module()
class DAHead(BaseDecodeHead):
    """Dual Attention Network (DANet) head for scene segmentation.

    Two parallel branches, one with a Position Attention Module (PAM) and
    one with a Channel Attention Module (CAM), each have their own
    classifier. The sum of both branch features is classified by the shared
    ``cls_seg``.

    Args:
        pam_channels (int): The channels of Position Attention Module(PAM).
    """

    def __init__(self, pam_channels, **kwargs):
        super(DAHead, self).__init__(**kwargs)
        self.pam_channels = pam_channels

        def conv3x3(in_ch):
            # 3x3 ConvModule (padding 1) producing ``self.channels`` maps.
            return ConvModule(
                in_ch,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

        # Keep this construction order: it fixes module registration order.
        self.pam_in_conv = conv3x3(self.in_channels)
        self.pam = PAM(self.channels, pam_channels)
        self.pam_out_conv = conv3x3(self.channels)
        self.pam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

        self.cam_in_conv = conv3x3(self.in_channels)
        self.cam = CAM()
        self.cam_out_conv = conv3x3(self.channels)
        self.cam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

    def _classify_with(self, classifier, feat):
        """Apply the shared dropout (if configured), then ``classifier``."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        return classifier(feat)

    def pam_cls_seg(self, feat):
        """PAM feature classification."""
        return self._classify_with(self.pam_conv_seg, feat)

    def cam_cls_seg(self, feat):
        """CAM feature classification."""
        return self._classify_with(self.cam_conv_seg, feat)

    def forward(self, inputs):
        """Return ``(pam_cam_out, pam_out, cam_out)`` logits."""
        x = self._transform_inputs(inputs)
        pam_feat = self.pam_out_conv(self.pam(self.pam_in_conv(x)))
        pam_out = self.pam_cls_seg(pam_feat)
        cam_feat = self.cam_out_conv(self.cam(self.cam_in_conv(x)))
        cam_out = self.cam_cls_seg(cam_feat)
        pam_cam_out = self.cls_seg(pam_feat + cam_feat)
        return pam_cam_out, pam_out, cam_out

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing, only ``pam_cam`` is used."""
        pam_cam_out, _, _ = self.forward(inputs)
        return pam_cam_out

    def losses(self, seg_logit, seg_label):
        """Compute ``pam_cam``, ``pam``, ``cam`` loss."""
        pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
        loss = dict()
        branches = (('pam_cam', pam_cam_seg_logit),
                    ('pam', pam_seg_logit),
                    ('cam', cam_seg_logit))
        for prefix, logit in branches:
            branch_loss = super(DAHead, self).losses(logit, seg_label)
            loss.update(add_prefix(branch_loss, prefix))
        return loss
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/decode_head.py
================================================
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmcv.runner import auto_fp16, force_fp32
from mmseg.core import build_pixel_sampler
from mmseg.ops import resize
from ..builder import build_loss
from ..losses import accuracy
class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
    """Base class for decode heads.

    Holds the input-selection logic (``in_channels`` / ``in_index`` /
    ``input_transform``), the final 1x1 classification conv ``conv_seg``
    with optional ``Dropout2d``, and the default segmentation loss.
    Subclasses must implement :meth:`forward`.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps are resized to the
                size of the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps are bundled into
                a list and passed into the decode head.
            None: Only one selected feature map is allowed.
            Default: None.
        loss_decode (dict): Config of decode loss.
            Default: dict(type='CrossEntropyLoss', use_sigmoid=False,
            loss_weight=1.0).
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 sampler=None,
                 align_corners=False):
        # NOTE(review): the dict defaults above are shared between calls;
        # this is only safe while the builders do not mutate them — confirm.
        super(BaseDecodeHead, self).__init__()
        # Validates the input spec and sets self.input_transform,
        # self.in_index and self.in_channels.
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        self.loss_decode = build_loss(loss_decode)
        self.ignore_index = ignore_index
        self.align_corners = align_corners
        # The sampler receives ``self`` as context before conv_seg and
        # dropout are created below.
        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        # Final per-pixel classifier: 1x1 conv, channels -> num_classes.
        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        # NOTE(review): presumably read by mmcv's auto_fp16 / force_fp32
        # decorators used below; confirm against the mmcv version in use.
        self.fp16_enabled = False

    def extra_repr(self):
        """Extra repr: input_transform, ignore_index and align_corners."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature
        map will be selected, so in_channels and in_index must be of type
        int. When input_transform is not None, in_channels and in_index
        must be lists or tuples of the same length; for 'resize_concat'
        ``self.in_channels`` becomes their sum, otherwise the sequence is
        stored as given.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps are resized to the
                    size of the first one and then concatenated together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps are bundled into
                    a list and passed into the decode head.
                None: Only one selected feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                # Concatenation along channels: total channel count.
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer (normal, std=0.01)."""
        normal_init(self.conv_seg, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: For 'resize_concat', the selected
                features bilinearly resized to the first one's spatial size
                and concatenated along dim 1; for 'multiple_select', the
                list of selected features; otherwise the single feature at
                ``self.in_index``.
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @auto_fp16()
    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function; implemented by subclasses."""
        pass

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Forward function for training.

        ``img_metas`` and ``train_cfg`` are accepted for interface
        compatibility but are not used here.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        ``img_metas`` and ``test_cfg`` are accepted for interface
        compatibility but are not used here.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs)

    def cls_seg(self, feat):
        """Classify each pixel: optional dropout, then ``conv_seg``."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    @force_fp32(apply_to=('seg_logit', ))
    def losses(self, seg_logit, seg_label):
        """Compute segmentation loss.

        Returns:
            dict[str, Tensor]: ``loss_seg`` from ``self.loss_decode`` and
                ``acc_seg`` from ``accuracy``.
        """
        loss = dict()
        # Upsample logits to the label's spatial size before the loss.
        seg_logit = resize(
            input=seg_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logit, seg_label)
        else:
            seg_weight = None
        # NOTE(review): labels presumably arrive as (N, 1, H, W); squeeze(1)
        # drops that singleton channel dim — confirm with the data pipeline.
        seg_label = seg_label.squeeze(1)
        loss['loss_seg'] = self.loss_decode(
            seg_logit,
            seg_label,
            weight=seg_weight,
            ignore_index=self.ignore_index)
        loss['acc_seg'] = accuracy(seg_logit, seg_label)
        return loss
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/dm_head.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from ..builder import HEADS
from .decode_head import BaseDecodeHead
class DCM(nn.Module):
    """Dynamic Convolutional Module used in DMNet.

    Predicts one ``filter_size x filter_size`` kernel per (sample, channel)
    from the adaptively pooled input, then convolves the channel-reduced
    input with those kernels as a single grouped convolution.

    Args:
        filter_size (int): The filter size of generated convolution kernel
            used in Dynamic Convolutional Module.
        fusion (bool): Add one conv to fuse DCM output feature.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
                 norm_cfg, act_cfg):
        super(DCM, self).__init__()
        self.filter_size = filter_size
        self.fusion = fusion
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        # 1x1 conv predicting the dynamic kernel weights.
        self.filter_gen_conv = nn.Conv2d(in_channels, channels, 1, 1, 0)
        self.input_redu_conv = ConvModule(
            in_channels,
            channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.norm = (build_norm_layer(norm_cfg, channels)[1]
                     if norm_cfg is not None else None)
        self.activate = build_activation_layer(act_cfg)
        if fusion:
            self.fusion_conv = ConvModule(
                channels,
                channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def forward(self, x):
        """Apply the dynamic convolution; output is ``(b, channels, h, w)``."""
        k = self.filter_size
        kernels = self.filter_gen_conv(F.adaptive_avg_pool2d(x, k))
        x = self.input_redu_conv(x)
        b, c, h, w = x.shape
        # Fold the batch into the channel axis so one grouped conv applies
        # each sample's own kernels: x -> [1, b*c, h, w],
        # kernels -> [b*c, 1, k, k].
        x = x.view(1, b * c, h, w)
        kernels = kernels.view(b * c, 1, k, k)
        # "Same" padding as (left, right, top, bottom); even kernel sizes
        # put the extra pixel on the left/top.
        half = (k - 1) // 2
        if (k - 1) % 2 == 0:
            p2d = (half, ) * 4
        else:
            p2d = (half + 1, half, half + 1, half)
        x = F.pad(input=x, pad=p2d, mode='constant', value=0)
        out = F.conv2d(input=x, weight=kernels, groups=b * c)
        out = out.view(b, c, h, w)
        if self.norm is not None:
            out = self.norm(out)
        out = self.activate(out)
        return self.fusion_conv(out) if self.fusion else out
@HEADS.register_module()
class DMHead(BaseDecodeHead):
    """Dynamic Multi-scale Filters for Semantic Segmentation (DMNet head).

    One ``DCM`` branch is built per entry of ``filter_sizes``. The input and
    all branch outputs are concatenated along channels, fused by a 3x3
    bottleneck and classified.

    Args:
        filter_sizes (tuple[int]): The size of generated convolutional filters
            used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
        fusion (bool): Add one conv to fuse DCM output feature.
            Default: False.
    """

    def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
        super(DMHead, self).__init__(**kwargs)
        assert isinstance(filter_sizes, (list, tuple))
        self.filter_sizes = filter_sizes
        self.fusion = fusion
        self.dcm_modules = nn.ModuleList([
            DCM(size,
                self.fusion,
                self.in_channels,
                self.channels,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg) for size in self.filter_sizes
        ])
        fused_channels = self.in_channels + len(filter_sizes) * self.channels
        self.bottleneck = ConvModule(
            fused_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Concatenate input with every DCM branch, fuse and classify."""
        x = self._transform_inputs(inputs)
        branches = [x] + [dcm(x) for dcm in self.dcm_modules]
        return self.cls_seg(self.bottleneck(torch.cat(branches, dim=1)))
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/dnl_head.py
================================================
import torch
from mmcv.cnn import NonLocal2d
from torch import nn
from ..builder import HEADS
from .fcn_head import FCNHead
class DisentangledNonLocal2d(NonLocal2d):
    """Disentangled Non-Local Block.

    Extends ``NonLocal2d`` with mean-subtracted query/key embeddings for
    the pairwise term, a temperature on the pairwise softmax, and a unary
    term obtained from a 1x1 ``conv_mask`` softmax over positions.

    Args:
        *arg: Positional arguments forwarded to ``NonLocal2d``.
        temperature (float): Temperature dividing the pairwise attention
            logits. Required keyword argument (``DNLHead`` passes 0.05 by
            default).
        **kwargs: Keyword arguments forwarded to ``NonLocal2d``.
    """

    def __init__(self, *arg, temperature, **kwargs):
        super().__init__(*arg, **kwargs)
        self.temperature = temperature
        self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)

    def embedded_gaussian(self, theta_x, phi_x):
        """Embedded gaussian with temperature."""

        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= theta_x.shape[-1]**0.5
        pairwise_weight /= self.temperature
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def forward(self, x):
        """Return ``x`` plus the pairwise and unary non-local responses."""
        # x: [N, C, H, W]
        n = x.size(0)

        # g_x: [N, HxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
        if self.mode == 'gaussian':
            theta_x = x.view(n, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            if self.sub_sample:
                phi_x = self.phi(x).view(n, self.in_channels, -1)
            else:
                phi_x = x.view(n, self.in_channels, -1)
        elif self.mode == 'concatenation':
            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
        else:
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, -1)

        # Subtract the mean over positions. This is deliberately out of
        # place: in 'gaussian' mode theta_x / phi_x are views of ``x``, and
        # an in-place ``-=`` would overwrite the input that is reused below
        # (conv_mask and the residual) and that autograd may need.
        theta_x = theta_x - theta_x.mean(dim=-2, keepdim=True)
        phi_x = phi_x - phi_x.mean(dim=-1, keepdim=True)

        pairwise_func = getattr(self, self.mode)
        # pairwise_weight: [N, HxW, HxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # y: [N, HxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # y: [N, C, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        # unary_mask: [N, 1, HxW]
        unary_mask = self.conv_mask(x)
        unary_mask = unary_mask.view(n, 1, -1)
        unary_mask = unary_mask.softmax(dim=-1)
        # unary_x: [N, 1, C]
        unary_x = torch.matmul(unary_mask, g_x)
        # unary_x: [N, C, 1, 1]
        unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
            n, self.inter_channels, 1, 1)

        output = x + self.conv_out(y + unary_x)

        return output
@HEADS.register_module()
class DNLHead(FCNHead):
    """Disentangled Non-Local Neural Networks (DNLNet head).

    An FCN head with ``num_convs=2`` whose two convs are separated by a
    ``DisentangledNonLocal2d`` block.

    Args:
        reduction (int): Reduction factor of projection transform. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode, e.g. 'embedded_gaussian' or
            'dot_product'. Default: 'embedded_gaussian'.
        temperature (float): Temperature to adjust attention. Default: 0.05
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 temperature=0.05,
                 **kwargs):
        super(DNLHead, self).__init__(num_convs=2, **kwargs)
        self.reduction = reduction
        self.use_scale = use_scale
        self.mode = mode
        self.temperature = temperature
        self.dnl_block = DisentangledNonLocal2d(
            in_channels=self.channels,
            reduction=self.reduction,
            use_scale=self.use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=self.mode,
            temperature=self.temperature)

    def forward(self, inputs):
        """conv -> DNL block -> conv, optional input concat, classify."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        output = self.dnl_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/ema_head.py
================================================
import math
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from ..builder import HEADS
from .decode_head import BaseDecodeHead
def reduce_mean(tensor):
    """Average ``tensor`` element-wise across all distributed processes.

    When ``torch.distributed`` is unavailable or not initialised, the very
    same ``tensor`` object is returned unchanged. Otherwise a copy is divided
    by the world size and SUM-all-reduced, giving the mean over ranks. The
    input itself is left untouched.
    """
    distributed = dist.is_available() and dist.is_initialized()
    if not distributed:
        return tensor
    averaged = tensor.clone().div_(dist.get_world_size())
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    return averaged
class EMAModule(nn.Module):
    """Expectation Maximization Attention Module used in EMANet.

    Keeps ``num_bases`` L2-normalised bases in the ``bases`` buffer of shape
    ``[1, channels, num_bases]``. Each forward pass runs ``num_stages``
    EM iterations without gradients and reconstructs the features from the
    refined bases. In training mode it also moves the buffer towards the
    batch-averaged bases.

    Args:
        channels (int): Channels of the whole module.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        momentum (float): Weight of the new bases in the running update of
            the ``bases`` buffer during training.
    """

    def __init__(self, channels, num_bases, num_stages, momentum):
        super(EMAModule, self).__init__()
        assert num_stages >= 1, 'num_stages must be at least 1!'
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.momentum = momentum

        init_bases = torch.zeros(1, channels, self.num_bases)
        init_bases.normal_(0, math.sqrt(2. / self.num_bases))
        # L2-normalise every basis over the channel axis.
        self.register_buffer('bases', F.normalize(init_bases, dim=1, p=2))

    def forward(self, feats):
        """Reconstruct ``feats`` (B, C, H, W) from EM-refined bases."""
        b, c, h, w = feats.size()
        flat = feats.view(b, c, h * w)  # [B, C, HW]
        mu = self.bases.repeat(b, 1, 1)  # [B, C, K]
        with torch.no_grad():
            for _ in range(self.num_stages):
                # Responsibilities [B, HW, K]: softmax over the bases.
                z = F.softmax(torch.einsum('bcn,bck->bnk', flat, mu), dim=2)
                # New bases: features weighted by L1-normalised (over
                # positions) responsibilities, then L2-normalised.
                z_l1 = F.normalize(z, dim=1, p=1)
                mu = torch.einsum('bcn,bnk->bck', flat, z_l1)
                mu = F.normalize(mu, dim=1, p=2)
        recon = torch.einsum('bck,bnk->bcn', mu, z).view(b, c, h, w)

        if self.training:
            fresh = mu.mean(dim=0, keepdim=True)
            fresh = F.normalize(reduce_mean(fresh), dim=1, p=2)
            self.bases = (1 -
                          self.momentum) * self.bases + self.momentum * fresh

        return recon
@HEADS.register_module()
class EMAHead(BaseDecodeHead):
    """Expectation Maximization Attention Networks for Semantic Segmentation.

    This head implements EMANet: a 3x3 input conv, a frozen 1x1 conv, an
    ``EMAModule`` reconstruction and a 1x1 output conv added residually,
    followed by a 3x3 bottleneck and optional concatenation with the input.

    Args:
        ema_channels (int): EMA module channels
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer. Default: True
        momentum (float): Momentum to update the base. Default: 0.1.
    """

    def __init__(self,
                 ema_channels,
                 num_bases,
                 num_stages,
                 concat_input=True,
                 momentum=0.1,
                 **kwargs):
        super(EMAHead, self).__init__(**kwargs)
        self.ema_channels = ema_channels
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.concat_input = concat_input
        self.momentum = momentum
        # Keep this construction order: it fixes module registration order.
        self.ema_module = EMAModule(ema_channels, num_bases, num_stages,
                                    momentum)

        self.ema_in_conv = ConvModule(
            self.in_channels,
            ema_channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # project (0, inf) -> (-inf, inf); a plain 1x1 conv kept frozen.
        self.ema_mid_conv = ConvModule(
            ema_channels,
            ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=None,
            act_cfg=None)
        self.ema_mid_conv.requires_grad_(False)

        self.ema_out_conv = ConvModule(
            ema_channels,
            ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.bottleneck = ConvModule(
            ema_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        identity = self.ema_in_conv(x)
        recon = self.ema_module(self.ema_mid_conv(identity))
        recon = self.ema_out_conv(F.relu(recon, inplace=True))
        out = self.bottleneck(F.relu(identity + recon, inplace=True))
        if self.concat_input:
            out = self.conv_cat(torch.cat([x, out], dim=1))
        return self.cls_seg(out)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/enc_head.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer
from mmseg.ops import Encoding, resize
from ..builder import HEADS, build_loss
from .decode_head import BaseDecodeHead
class EncModule(nn.Module):
    """Encoding Module used in EncNet.

    Projects the input with a 1x1 ConvModule and encodes it with
    ``Encoding``, a 1d norm and ReLU. It then averages over codes and predicts
    per-channel sigmoid scales used to re-weight the input.

    Args:
        in_channels (int): Input channels.
        num_codes (int): Number of code words.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
        super(EncModule, self).__init__()
        self.encoding_project = ConvModule(
            in_channels,
            in_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # TODO: resolve this hack
        # The encoding output is normalised with a 1d norm, so translate
        # the 2d norm type into its 1d variant (BN1d when none is given).
        if norm_cfg is None:
            encoding_norm_cfg = dict(type='BN1d')
        else:
            encoding_norm_cfg = norm_cfg.copy()
            norm_type = encoding_norm_cfg['type']
            if norm_type in ['BN', 'IN']:
                encoding_norm_cfg['type'] = norm_type + '1d'
            else:
                encoding_norm_cfg['type'] = norm_type.replace('2d', '1d')
        self.encoding = nn.Sequential(
            Encoding(channels=in_channels, num_codes=num_codes),
            build_norm_layer(encoding_norm_cfg, num_codes)[1],
            nn.ReLU(inplace=True))
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels), nn.Sigmoid())

    def forward(self, x):
        """Return ``(encoding_feat, relu(x * (1 + scale)))``."""
        codes = self.encoding(self.encoding_project(x))
        encoding_feat = codes.mean(dim=1)
        batch_size, channels, _, _ = x.size()
        scale = self.fc(encoding_feat).view(batch_size, channels, 1, 1)
        return encoding_feat, F.relu_(x + x * scale)
@HEADS.register_module()
class EncHead(BaseDecodeHead):
    """Context Encoding for Semantic Segmentation (EncNet head).

    The last selected input goes through a 3x3 bottleneck, optionally fused
    with 1x1 lateral projections of the other inputs, and then through an
    ``EncModule``. When ``use_se_loss`` is True, a linear layer also
    predicts image-level class presence for the Semantic Encoding loss.

    Args:
        num_codes (int): Number of code words. Default: 32.
        use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to
            regularize the training. Default: True.
        add_lateral (bool): Whether use lateral connection to fuse features.
            Default: False.
        loss_se_decode (dict): Config of decode loss.
            Default: dict(type='CrossEntropyLoss', use_sigmoid=True,
            loss_weight=0.2).
    """

    def __init__(self,
                 num_codes=32,
                 use_se_loss=True,
                 add_lateral=False,
                 loss_se_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=0.2),
                 **kwargs):
        super(EncHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        self.use_se_loss = use_se_loss
        self.add_lateral = add_lateral
        self.num_codes = num_codes
        self.bottleneck = ConvModule(
            self.in_channels[-1],
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if add_lateral:
            self.lateral_convs = nn.ModuleList()
            for in_channels in self.in_channels[:-1]:  # skip the last one
                self.lateral_convs.append(
                    ConvModule(
                        in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            self.fusion = ConvModule(
                len(self.in_channels) * self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        self.enc_module = EncModule(
            self.channels,
            num_codes=num_codes,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.use_se_loss:
            self.loss_se_decode = build_loss(loss_se_decode)
            self.se_layer = nn.Linear(self.channels, self.num_classes)

    def forward(self, inputs):
        """Forward function.

        Returns:
            Tensor | tuple[Tensor, Tensor]: Segmentation logits, or
                ``(seg_logits, se_logits)`` when ``use_se_loss`` is True.
        """
        inputs = self._transform_inputs(inputs)
        feat = self.bottleneck(inputs[-1])
        if self.add_lateral:
            laterals = [
                resize(
                    lateral_conv(inputs[i]),
                    size=feat.shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners)
                for i, lateral_conv in enumerate(self.lateral_convs)
            ]
            feat = self.fusion(torch.cat([feat, *laterals], 1))
        encode_feat, output = self.enc_module(feat)
        output = self.cls_seg(output)
        if self.use_se_loss:
            se_output = self.se_layer(encode_feat)
            return output, se_output
        else:
            return output

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing, ignore se_loss."""
        if self.use_se_loss:
            return self.forward(inputs)[0]
        else:
            return self.forward(inputs)

    @staticmethod
    def _convert_to_onehot_labels(seg_label, num_classes):
        """Convert segmentation label to onehot.

        Labels outside ``[0, num_classes - 1]`` (e.g. an ignore index such
        as 255) fall outside the histogram range and are not counted.

        Args:
            seg_label (Tensor): Segmentation label of shape (N, H, W).
            num_classes (int): Number of classes.

        Returns:
            Tensor: Onehot labels of shape (N, num_classes).
        """

        batch_size = seg_label.size(0)
        onehot_labels = seg_label.new_zeros((batch_size, num_classes))
        for i in range(batch_size):
            hist = seg_label[i].float().histc(
                bins=num_classes, min=0, max=num_classes - 1)
            onehot_labels[i] = hist > 0
        return onehot_labels

    def losses(self, seg_logit, seg_label):
        """Compute segmentation and (if enabled) semantic encoding loss.

        ``seg_logit`` is what :meth:`forward` returned: a
        ``(seg_logits, se_logits)`` tuple when ``use_se_loss`` is True,
        otherwise a single tensor.
        """
        if not self.use_se_loss:
            # forward() returned a plain tensor and no SE loss was built;
            # unpacking it would silently split the batch dimension.
            return super(EncHead, self).losses(seg_logit, seg_label)
        seg_logit, se_seg_logit = seg_logit
        loss = dict()
        loss.update(super(EncHead, self).losses(seg_logit, seg_label))
        se_loss = self.loss_se_decode(
            se_seg_logit,
            self._convert_to_onehot_labels(seg_label, self.num_classes))
        loss['loss_se'] = se_loss
        return loss
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/fcn_head.py
================================================
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
class FCNHead(BaseDecodeHead):
    """Fully Convolution Networks for Semantic Segmentation (FCN head).

    A stack of ``num_convs`` ConvModules, optionally followed by a conv over
    the concatenation of the head input and the stack output, then the
    shared classifier.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer. Default: True.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 **kwargs):
        assert num_convs >= 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(FCNHead, self).__init__(**kwargs)
        if num_convs == 0:
            assert self.in_channels == self.channels

        def conv_block(in_ch):
            # "Same"-padded ConvModule producing ``self.channels`` maps.
            return ConvModule(
                in_ch,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

        # NOTE: the first conv is constructed even when num_convs == 0 (it is
        # then discarded in favour of nn.Identity).
        stack = [conv_block(self.in_channels)]
        stack.extend(conv_block(self.channels) for _ in range(num_convs - 1))
        self.convs = nn.Sequential(*stack) if num_convs else nn.Identity()
        if self.concat_input:
            self.conv_cat = conv_block(self.in_channels + self.channels)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        out = self.convs(x)
        if self.concat_input:
            out = self.conv_cat(torch.cat([x, out], dim=1))
        return self.cls_seg(out)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/fpn_head.py
================================================
import numpy as np
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
class FPNHead(BaseDecodeHead):
    """Panoptic Feature Pyramid Networks (Semantic FPN head).

    Each input level gets a stack of 3x3 ConvModules; for every level whose
    stride differs from the first, each conv is followed by a 2x bilinear
    upsample. The per-level outputs are resized to the first level's size,
    summed and classified.

    Args:
        feature_strides (tuple[int]): The strides for input feature maps.
            All strides are supposed to be powers of 2. The first one must
            be the smallest (largest resolution).
    """

    def __init__(self, feature_strides, **kwargs):
        super(FPNHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides

        self.scale_heads = nn.ModuleList()
        base_log2 = np.log2(feature_strides[0])
        for in_ch, stride in zip(self.in_channels, feature_strides):
            depth = max(1, int(np.log2(stride) - base_log2))
            layers = []
            for k in range(depth):
                layers.append(
                    ConvModule(
                        in_ch if k == 0 else self.channels,
                        self.channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                if stride != feature_strides[0]:
                    layers.append(
                        nn.Upsample(
                            scale_factor=2,
                            mode='bilinear',
                            align_corners=self.align_corners))
            self.scale_heads.append(nn.Sequential(*layers))

    def forward(self, inputs):
        """Sum the resized per-level head outputs and classify."""
        feats = self._transform_inputs(inputs)
        out = self.scale_heads[0](feats[0])
        for level in range(1, len(self.feature_strides)):
            upsampled = resize(
                self.scale_heads[level](feats[level]),
                size=out.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            # Deliberately not in place.
            out = out + upsampled
        return self.cls_seg(out)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/gc_head.py
================================================
import torch
from mmcv.cnn import ContextBlock
from ..builder import HEADS
from .fcn_head import FCNHead
@HEADS.register_module()
class GCHead(FCNHead):
    """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.

    An FCN head with ``num_convs=2`` whose two convs are separated by a
    ``ContextBlock``.

    Args:
        ratio (float): Multiplier of channels ratio. Default: 1/4.
        pooling_type (str): The pooling type of context aggregation.
            Options are 'att', 'avg'. Default: 'att'.
        fusion_types (tuple[str]): The fusion type for feature fusion.
            Options are 'channel_add', 'channel_mul'.
            Default: ('channel_add',)
    """

    def __init__(self,
                 ratio=1 / 4.,
                 pooling_type='att',
                 fusion_types=('channel_add', ),
                 **kwargs):
        super(GCHead, self).__init__(num_convs=2, **kwargs)
        self.ratio = ratio
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        self.gc_block = ContextBlock(
            in_channels=self.channels,
            ratio=self.ratio,
            pooling_type=self.pooling_type,
            fusion_types=self.fusion_types)

    def forward(self, inputs):
        """conv -> context block -> conv, optional input concat, classify."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        output = self.gc_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/lraspp_head.py
================================================
import torch
import torch.nn as nn
from mmcv import is_tuple_of
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
class LRASPPHead(BaseDecodeHead):
    """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.

    This head is the improved implementation of `Searching for MobileNetV3
    <https://arxiv.org/abs/1905.02244>`_.

    Requires ``input_transform='multiple_select'``. The last input level
    feeds the gated ASPP branch; every other level ``i`` feeds a 1x1 skip
    branch with ``branch_channels[i]`` outputs.

    Args:
        branch_channels (tuple[int]): The number of output channels in each
            skip branch, one per input level except the last.
            Default: (32, 64).
    """

    def __init__(self, branch_channels=(32, 64), **kwargs):
        super(LRASPPHead, self).__init__(**kwargs)
        if self.input_transform != 'multiple_select':
            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
                             f'must be \'multiple_select\'. But received '
                             f'\'{self.input_transform}\'')
        assert is_tuple_of(branch_channels, int)
        assert len(branch_channels) == len(self.in_channels) - 1
        self.branch_channels = branch_channels

        self.convs = nn.Sequential()
        self.conv_ups = nn.Sequential()
        for i in range(len(branch_channels)):
            self.convs.add_module(
                f'conv{i}',
                nn.Conv2d(
                    self.in_channels[i], branch_channels[i], 1, bias=False))
            self.conv_ups.add_module(
                f'conv_up{i}',
                ConvModule(
                    self.channels + branch_channels[i],
                    self.channels,
                    1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=False))

        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)

        self.aspp_conv = ConvModule(
            self.in_channels[-1],
            self.channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            bias=False)
        # Sigmoid gate computed from the same (last) feature map that
        # aspp_conv consumes in forward(), so its input width must be
        # in_channels[-1]. A fixed index (2) would only be correct for
        # exactly three input levels.
        self.image_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
            ConvModule(
                self.in_channels[-1],
                self.channels,
                1,
                act_cfg=dict(type='Sigmoid'),
                bias=False))

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        x = inputs[-1]

        # Gate the ASPP features with the pooled sigmoid attention.
        x = self.aspp_conv(x) * resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.conv_up_input(x)

        # Walk the skip branches from coarsest to finest.
        for i in reversed(range(len(self.branch_channels))):
            x = resize(
                x,
                size=inputs[i].size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](inputs[i])], 1)
            x = self.conv_ups[i](x)

        return self.cls_seg(x)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/nl_head.py
================================================
import torch
from mmcv.cnn import NonLocal2d
from ..builder import HEADS
from .fcn_head import FCNHead
@HEADS.register_module()
class NLHead(FCNHead):
    """Non-local Neural Networks.

    This head is the implementation of `NLNet
    <https://arxiv.org/abs/1711.07971>`_: an ``FCNHead`` with exactly two
    convs and a ``NonLocal2d`` block placed between them.

    Args:
        reduction (int): Reduction factor of projection transform. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
            'dot_product'. Default: 'embedded_gaussian'.
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 **kwargs):
        super(NLHead, self).__init__(num_convs=2, **kwargs)
        self.reduction, self.use_scale, self.mode = reduction, use_scale, mode
        self.nl_block = NonLocal2d(
            in_channels=self.channels,
            reduction=reduction,
            use_scale=use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=mode)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        feat = self.convs[1](self.nl_block(self.convs[0](x)))
        if self.concat_input:
            feat = self.conv_cat(torch.cat([x, feat], dim=1))
        return self.cls_seg(feat)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/ocr_head.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .cascade_decode_head import BaseCascadeDecodeHead
class SpatialGatherModule(nn.Module):
    """Aggregate the context features according to the initial predicted
    probability distribution.

    Employs the soft-weighted method to aggregate the context: for every
    class, the output is the average of ``feats`` over all spatial positions,
    weighted by ``softmax(scale * probs)`` taken over those positions.

    Args:
        scale (float): Multiplier applied to ``probs`` before the spatial
            softmax.
    """

    def __init__(self, scale):
        super(SpatialGatherModule, self).__init__()
        self.scale = scale

    def forward(self, feats, probs):
        """Forward function.

        Args:
            feats (Tensor): Features, shape (batch_size, channels, height,
                width).
            probs (Tensor): Class scores with the same spatial size, shape
                (batch_size, num_classes, height, width).

        Returns:
            Tensor: Per-class context, shape (batch_size, channels,
            num_classes, 1).
        """
        batch_size, num_classes, height, width = probs.size()
        channels = feats.size(1)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        # [batch_size, num_classes, height*width]
        probs = probs.reshape(batch_size, num_classes, -1)
        # [batch_size, channels, height*width]
        feats = feats.reshape(batch_size, channels, -1)
        # [batch_size, height*width, channels]
        feats = feats.permute(0, 2, 1)
        # Softmax over spatial positions: [batch_size, num_classes, h*w]
        probs = F.softmax(self.scale * probs, dim=2)
        # [batch_size, num_classes, channels]
        ocr_context = torch.matmul(probs, feats)
        # [batch_size, channels, num_classes, 1]
        ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
        return ocr_context
class ObjectAttentionBlock(_SelfAttentionBlock):
    """Make a OCR used SelfAttentionBlock.

    Attention from the pixel features (query) to the object-region context
    (key/value), followed by a 1x1 ``bottleneck`` ConvModule that fuses the
    attended context with the query features back to ``in_channels``.

    Args:
        in_channels (int): Channels of the query and key inputs and of the
            block output.
        channels (int): Intermediate key/query channels.
        scale (int): If > 1, ``nn.MaxPool2d(kernel_size=scale)`` is passed to
            the parent as ``query_downsample`` (presumably applied to the
            query inside the attention; confirm in SelfAttentionBlock).
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
                 act_cfg):
        if scale > 1:
            query_downsample = nn.MaxPool2d(kernel_size=scale)
        else:
            query_downsample = None
        super(ObjectAttentionBlock, self).__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=query_downsample,
            key_downsample=None,
            key_query_num_convs=2,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=True,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # Fuses cat([context, query_feats]) (2 * in_channels) -> in_channels.
        self.bottleneck = ConvModule(
            in_channels * 2,
            in_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, query_feats, key_feats):
        """Forward function.

        Args:
            query_feats (Tensor): Pixel features (``in_channels`` channels).
            key_feats (Tensor): Object-region context (``in_channels``
                channels).

        Returns:
            Tensor: Fused features with ``in_channels`` channels.
        """
        context = super(ObjectAttentionBlock,
                        self).forward(query_feats, key_feats)
        output = self.bottleneck(torch.cat([context, query_feats], dim=1))
        # NOTE(review): this branch (scale > 1) looks broken. ``resize`` is
        # called with neither ``size`` nor ``scale_factor`` (presumably an
        # interpolate error), and it would discard ``output`` anyway. The
        # cat above also needs ``context`` and ``query_feats`` to share a
        # spatial size, which query downsampling likely breaks. TODO confirm
        # against SelfAttentionBlock and fix, e.g. by upsampling the context.
        if self.query_downsample is not None:
            output = resize(query_feats)

        return output
@HEADS.register_module()
class OCRHead(BaseCascadeDecodeHead):
    """Object-Contextual Representations for Semantic Segmentation.

    This head is the implementation of `OCRNet
    <https://arxiv.org/abs/1909.11065>`_. As a cascade head it uses the
    previous head's prediction (``prev_output``) as soft object regions to
    gather per-class context. That context is then attended back onto the
    pixel features before classification.

    Args:
        ocr_channels (int): The intermediate channels of OCR block.
        scale (int): The scale of probability map in SpatialGatherModule
            (also handed to the object attention block). Default: 1.
    """

    def __init__(self, ocr_channels, scale=1, **kwargs):
        super(OCRHead, self).__init__(**kwargs)
        self.ocr_channels = ocr_channels
        self.scale = scale
        layer_cfgs = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.object_context_block = ObjectAttentionBlock(
            self.channels, ocr_channels, scale, **layer_cfgs)
        self.spatial_gather_module = SpatialGatherModule(scale)
        self.bottleneck = ConvModule(
            self.in_channels, self.channels, 3, padding=1, **layer_cfgs)

    def forward(self, inputs, prev_output):
        """Forward function."""
        feats = self.bottleneck(self._transform_inputs(inputs))
        region_context = self.spatial_gather_module(feats, prev_output)
        return self.cls_seg(
            self.object_context_block(feats, region_context))
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/point_head.py
================================================
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init
from mmcv.ops import point_sample
from mmseg.models.builder import HEADS
from mmseg.ops import resize
from ..losses import accuracy
from .cascade_decode_head import BaseCascadeDecodeHead
def calculate_uncertainty(seg_logits):
    """Estimate per-pixel uncertainty from segmentation logits.

    The score at each location is ``second_best - best`` of the class
    logits. It is always <= 0 and is highest (closest to 0) where the two
    leading classes are nearly tied, i.e. where the prediction is least
    certain.

    Args:
        seg_logits (Tensor): Semantic segmentation logits,
            shape (batch_size, num_classes, height, width).

    Returns:
        Tensor: Uncertainty scores, shape (batch_size, 1, height, width).
    """
    best_two, _ = seg_logits.topk(2, dim=1)
    best, runner_up = best_two[:, 0], best_two[:, 1]
    return (runner_up - best)[:, None]
@HEADS.register_module()
class PointHead(BaseCascadeDecodeHead):
    """A mask point head used in PointRend.

    ``PointHead`` uses a shared multi-layer perceptron (equivalent to
    nn.Conv1d) to predict the logits of input points. The fine-grained
    features and the coarse prediction are concatenated together for
    prediction.

    Args:
        num_fcs (int): Number of fc layers in the head. Default: 3.
        coarse_pred_each_layer (bool): Whether to concatenate the coarse
            prediction with the output of each fc layer. Default: True.
        conv_cfg (dict|None): Dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict|None): Dictionary to construct and config norm layer.
            Default: None.
        act_cfg (dict): Dictionary to construct and config activation layer.
            Default: dict(type='ReLU', inplace=False).
        **kwargs: Forwarded to ``BaseCascadeDecodeHead``. This head reads
            ``in_channels``, ``channels``, ``num_classes``, ``dropout_ratio``,
            ``align_corners``, ``ignore_index`` and ``loss_decode`` from it.
            ``input_transform`` is always 'multiple_select'.
    """

    def __init__(self,
                 num_fcs=3,
                 coarse_pred_each_layer=True,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU', inplace=False),
                 **kwargs):
        super(PointHead, self).__init__(
            input_transform='multiple_select',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        self.num_fcs = num_fcs
        self.coarse_pred_each_layer = coarse_pred_each_layer

        # The first fc sees every sampled feature level concatenated with the
        # coarse prediction (num_classes channels).
        fc_in_channels = sum(self.in_channels) + self.num_classes
        fc_channels = self.channels
        self.fcs = nn.ModuleList()
        for _ in range(num_fcs):
            fc = ConvModule(
                fc_in_channels,
                fc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.fcs.append(fc)
            fc_in_channels = fc_channels
            # forward() re-appends the coarse prediction after every fc.
            fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
                else 0
        self.fc_seg = nn.Conv1d(
            fc_in_channels,
            self.num_classes,
            kernel_size=1,
            stride=1,
            padding=0)
        # Point features are (batch, channels, num_points); use plain
        # element-wise Dropout for them.
        if self.dropout_ratio > 0:
            self.dropout = nn.Dropout(self.dropout_ratio)
        # fc_seg takes over the role of the base class's conv_seg classifier.
        delattr(self, 'conv_seg')

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.fc_seg, std=0.001)

    def cls_seg(self, feat):
        """Classify each point with fc_seg (after optional dropout)."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.fc_seg(feat)
        return output

    def forward(self, fine_grained_point_feats, coarse_point_feats):
        """Predict point logits from sampled fine and coarse features.

        Args:
            fine_grained_point_feats (Tensor): Shape (batch_size,
                sum(in_channels), num_points).
            coarse_point_feats (Tensor): Shape (batch_size, num_classes,
                num_points).

        Returns:
            Tensor: Point logits, shape (batch_size, num_classes, num_points).
        """
        x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
        for fc in self.fcs:
            x = fc(x)
            if self.coarse_pred_each_layer:
                x = torch.cat((x, coarse_point_feats), dim=1)
        return self.cls_seg(x)

    def _get_fine_grained_point_feats(self, x, points):
        """Sample from fine grained features.

        Args:
            x (list[Tensor]): Feature pyramid from by neck or backbone.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            fine_grained_feats (Tensor): Sampled fine grained feature,
                shape (batch_size, sum(channels of x), num_points).
        """

        fine_grained_feats_list = [
            point_sample(feat, points, align_corners=self.align_corners)
            for feat in x
        ]
        if len(fine_grained_feats_list) > 1:
            fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
        else:
            fine_grained_feats = fine_grained_feats_list[0]

        return fine_grained_feats

    def _get_coarse_point_feats(self, prev_output, points):
        """Sample from the coarse prediction of the previous decode head.

        Args:
            prev_output (Tensor): Prediction of previous decode head, shape
                (batch_size, num_classes, height, width).
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
                num_classes, num_points).
        """

        coarse_feats = point_sample(
            prev_output, points, align_corners=self.align_corners)

        return coarse_feats

    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
                      train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self._transform_inputs(inputs)
        # Point selection is not differentiated through.
        with torch.no_grad():
            points = self.get_points_train(
                prev_output, calculate_uncertainty, cfg=train_cfg)
        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, points)
        coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
        point_logits = self.forward(fine_grained_point_feats,
                                    coarse_point_feats)
        # Labels are sampled with nearest interpolation so they stay integer.
        point_label = point_sample(
            gt_semantic_seg.float(),
            points,
            mode='nearest',
            align_corners=self.align_corners)
        point_label = point_label.squeeze(1).long()

        losses = self.losses(point_logits, point_label)

        return losses

    def forward_test(self, inputs, prev_output, img_metas, test_cfg):
        """Forward function for testing.

        Repeatedly upsamples the current logits by ``test_cfg.scale_factor``
        and overwrites the most uncertain points with point-head predictions.
        This runs ``test_cfg.subdivision_steps`` times.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """

        x = self._transform_inputs(inputs)
        refined_seg_logits = prev_output.clone()
        for _ in range(test_cfg.subdivision_steps):
            refined_seg_logits = resize(
                refined_seg_logits,
                scale_factor=test_cfg.scale_factor,
                mode='bilinear',
                align_corners=self.align_corners)
            batch_size, channels, height, width = refined_seg_logits.shape
            point_indices, points = self.get_points_test(
                refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
            fine_grained_point_feats = self._get_fine_grained_point_feats(
                x, points)
            coarse_point_feats = self._get_coarse_point_feats(
                prev_output, points)
            point_logits = self.forward(fine_grained_point_feats,
                                        coarse_point_feats)

            # Scatter the refined point logits back into the flattened map.
            point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
            refined_seg_logits = refined_seg_logits.reshape(
                batch_size, channels, height * width)
            refined_seg_logits = refined_seg_logits.scatter_(
                2, point_indices, point_logits)
            refined_seg_logits = refined_seg_logits.view(
                batch_size, channels, height, width)

        return refined_seg_logits

    def losses(self, point_logits, point_label):
        """Compute segmentation loss."""
        loss = dict()
        loss['loss_point'] = self.loss_decode(
            point_logits, point_label, ignore_index=self.ignore_index)
        loss['acc_point'] = accuracy(point_logits, point_label)
        return loss

    def get_points_train(self, seg_logits, uncertainty_func, cfg):
        """Sample points for training.

        Sample points in [0, 1] x [0, 1] coordinate space based on their
        uncertainty. The uncertainties are calculated for each point using
        'uncertainty_func' function that takes point's logit prediction as
        input.

        Args:
            seg_logits (Tensor): Semantic segmentation logits, shape (
                batch_size, num_classes, height, width).
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Training config of point head.

        Returns:
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains the coordinates of ``num_points`` sampled
                points.
        """
        num_points = cfg.num_points
        oversample_ratio = cfg.oversample_ratio
        importance_sample_ratio = cfg.importance_sample_ratio
        assert oversample_ratio >= 1
        assert 0 <= importance_sample_ratio <= 1
        batch_size = seg_logits.shape[0]
        num_sampled = int(num_points * oversample_ratio)
        point_coords = torch.rand(
            batch_size, num_sampled, 2, device=seg_logits.device)
        point_logits = point_sample(seg_logits, point_coords)
        # It is crucial to calculate uncertainty based on the sampled
        # prediction value for the points. Calculating uncertainties of the
        # coarse predictions first and sampling them for points leads to
        # incorrect results.  To illustrate this: assume uncertainty func(
        # logits)=-abs(logits), a sampled point between two coarse
        # predictions with -1 and 1 logits has 0 logits, and therefore 0
        # uncertainty value. However, if we calculate uncertainties for the
        # coarse predictions first, both will have -1 uncertainty,
        # and sampled point will get -1 uncertainty.
        point_uncertainties = uncertainty_func(point_logits)
        num_uncertain_points = int(importance_sample_ratio * num_points)
        num_random_points = num_points - num_uncertain_points
        idx = torch.topk(
            point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
        # Offset per-sample indices so they address the flattened
        # (batch_size * num_sampled, 2) coordinate table.
        shift = num_sampled * torch.arange(
            batch_size, dtype=torch.long, device=seg_logits.device)
        idx += shift[:, None]
        point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
            batch_size, num_uncertain_points, 2)
        if num_random_points > 0:
            rand_point_coords = torch.rand(
                batch_size, num_random_points, 2, device=seg_logits.device)
            point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
        return point_coords

    def get_points_test(self, seg_logits, uncertainty_func, cfg):
        """Sample points for testing.

        Find ``num_points`` most uncertain points from ``uncertainty_map``.

        Args:
            seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
                height, width) for class-specific or class-agnostic prediction.
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Testing config of point head.

        Returns:
            point_indices (Tensor): A tensor of shape (batch_size, num_points)
                that contains indices from [0, height x width) of the most
                uncertain points.
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the ``height x width`` grid .
        """

        num_points = cfg.subdivision_num_points
        uncertainty_map = uncertainty_func(seg_logits)
        batch_size, _, height, width = uncertainty_map.shape
        h_step = 1.0 / height
        w_step = 1.0 / width

        uncertainty_map = uncertainty_map.view(batch_size, height * width)
        num_points = min(height * width, num_points)
        point_indices = uncertainty_map.topk(num_points, dim=1)[1]
        point_coords = torch.zeros(
            batch_size,
            num_points,
            2,
            dtype=torch.float,
            device=seg_logits.device)
        # Convert flat indices to (x, y) cell centres in [0, 1].
        point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
                                                width).float() * w_step
        point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
                                                width).float() * h_step
        return point_indices, point_coords
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/psa_head.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
try:
from mmcv.ops import PSAMask
except ModuleNotFoundError:
PSAMask = None
@HEADS.register_module()
class PSAHead(BaseDecodeHead):
    """Point-wise Spatial Attention Network for Scene Parsing.

    This head is the implementation of PSANet.

    Args:
        mask_size (tuple[int]): The PSA mask size ``(mask_h, mask_w)``. It
            usually equals input size.
        psa_type (str): The type of psa module. Options are 'collect',
            'distribute', 'bi-direction'. Default: 'bi-direction'.
        compact (bool): Whether to use the compact map instead of PSAMask.
            In 'collect' mode the map is transposed. In 'bi-direction' mode
            only the distribute map is transposed. Default: False.
        shrink_factor (int): The downsample factors of psa mask. Default: 2.
        normalization_factor (float | None): The aggregated features are
            multiplied by ``1 / normalization_factor``. If None,
            ``mask_h * mask_w`` is used. Default: 1.0.
        psa_softmax (bool): Whether to use softmax for attention.
            Default: True.
    """

    def __init__(self,
                 mask_size,
                 psa_type='bi-direction',
                 compact=False,
                 shrink_factor=2,
                 normalization_factor=1.0,
                 psa_softmax=True,
                 **kwargs):
        if PSAMask is None:
            raise RuntimeError('Please install mmcv-full for PSAMask ops')
        super(PSAHead, self).__init__(**kwargs)
        assert psa_type in ['collect', 'distribute', 'bi-direction']
        self.psa_type = psa_type
        self.compact = compact
        self.shrink_factor = shrink_factor
        self.mask_size = mask_size
        mask_h, mask_w = mask_size
        self.psa_softmax = psa_softmax
        if normalization_factor is None:
            normalization_factor = mask_h * mask_w
        self.normalization_factor = normalization_factor

        # Collect branch: reduce channels, then predict a mask_h * mask_w
        # attention map per position.
        self.reduce = ConvModule(
            self.in_channels,
            self.channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.attention = nn.Sequential(
            ConvModule(
                self.channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            nn.Conv2d(
                self.channels, mask_h * mask_w, kernel_size=1, bias=False))
        if psa_type == 'bi-direction':
            # Second (distribute) branch with its own reduce/attention.
            self.reduce_p = ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            self.attention_p = nn.Sequential(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                nn.Conv2d(
                    self.channels, mask_h * mask_w, kernel_size=1, bias=False))
            self.psamask_collect = PSAMask('collect', mask_size)
            self.psamask_distribute = PSAMask('distribute', mask_size)
        else:
            self.psamask = PSAMask(psa_type, mask_size)
        # NOTE(review): a 1x1 conv with padding=1 grows the output by one
        # pixel per side. forward() resizes back to the input size, which
        # hides this. Presumably unintended, since a 1x1 projection normally
        # has no padding. Confirm before changing, because it alters outputs
        # of already-trained models.
        self.proj = ConvModule(
            self.channels * (2 if psa_type == 'bi-direction' else 1),
            self.in_channels,
            kernel_size=1,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Fuses cat(identity, projected attention features).
        self.bottleneck = ConvModule(
            self.in_channels * 2,
            self.channels,
            kernel_size=3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        identity = x
        # May be overridden by the shrink logic below. The final resize back
        # to the identity size reuses this local value.
        align_corners = self.align_corners
        if self.psa_type in ['collect', 'distribute']:
            out = self.reduce(x)
            n, c, h, w = out.size()
            if self.shrink_factor != 1:
                # Ceil-divide with align_corners=True only when both h and w
                # are not divisible by shrink_factor; otherwise floor-divide.
                if h % self.shrink_factor and w % self.shrink_factor:
                    h = (h - 1) // self.shrink_factor + 1
                    w = (w - 1) // self.shrink_factor + 1
                    align_corners = True
                else:
                    h = h // self.shrink_factor
                    w = w // self.shrink_factor
                    align_corners = False
                out = resize(
                    out,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y = self.attention(out)
            if self.compact:
                if self.psa_type == 'collect':
                    y = y.view(n, h * w,
                               h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y = self.psamask(y)
            if self.psa_softmax:
                y = F.softmax(y, dim=1)
            # Aggregate: (n, c, hw) @ (n, hw, hw), then normalize.
            out = torch.bmm(
                out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
        else:
            x_col = self.reduce(x)
            x_dis = self.reduce_p(x)
            n, c, h, w = x_col.size()
            if self.shrink_factor != 1:
                # Same shrink rule as the single-branch path above.
                if h % self.shrink_factor and w % self.shrink_factor:
                    h = (h - 1) // self.shrink_factor + 1
                    w = (w - 1) // self.shrink_factor + 1
                    align_corners = True
                else:
                    h = h // self.shrink_factor
                    w = w // self.shrink_factor
                    align_corners = False
                x_col = resize(
                    x_col,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
                x_dis = resize(
                    x_dis,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y_col = self.attention(x_col)
            y_dis = self.attention_p(x_dis)
            if self.compact:
                # Only the distribute map is transposed in compact mode.
                y_dis = y_dis.view(n, h * w,
                                   h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y_col = self.psamask_collect(y_col)
                y_dis = self.psamask_distribute(y_dis)
            if self.psa_softmax:
                y_col = F.softmax(y_col, dim=1)
                y_dis = F.softmax(y_dis, dim=1)
            x_col = torch.bmm(
                x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
            x_dis = torch.bmm(
                x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
            out = torch.cat([x_col, x_dis], 1)
        out = self.proj(out)
        out = resize(
            out,
            size=identity.shape[2:],
            mode='bilinear',
            align_corners=align_corners)
        out = self.bottleneck(torch.cat((identity, out), dim=1))
        out = self.cls_seg(out)
        return out
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/psp_head.py
================================================
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Each entry is ``AdaptiveAvgPool2d(scale)`` followed by a 1x1
    ConvModule. ``forward`` returns one map per pooling scale, each resized
    back to the input's spatial size.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for scale in pool_scales:
            pooled_conv = ConvModule(
                in_channels,
                channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.append(nn.Sequential(nn.AdaptiveAvgPool2d(scale), pooled_conv))

    def forward(self, x):
        """Return the list of upsampled pyramid features, one per scale."""
        target_size = x.size()[2:]
        return [
            resize(
                branch(x),
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners) for branch in self
        ]
@HEADS.register_module()
class PSPHead(BaseDecodeHead):
    """Pyramid Scene Parsing Network.

    This head is the implementation of
    `PSPNet <https://arxiv.org/abs/1612.01105>`_: the input is concatenated
    with its pooled pyramid features and fused by a 3x3 bottleneck.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        layer_cfgs = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels,
            self.channels,
            align_corners=self.align_corners,
            **layer_cfgs)
        fused_channels = self.in_channels + len(pool_scales) * self.channels
        self.bottleneck = ConvModule(
            fused_channels, self.channels, 3, padding=1, **layer_cfgs)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        pyramid = torch.cat([x, *self.psp_modules(x)], dim=1)
        return self.cls_seg(self.bottleneck(pyramid))
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/sep_aspp_head.py
================================================
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmseg.ops import resize
from ..builder import HEADS
from .aspp_head import ASPPHead, ASPPModule
class DepthwiseSeparableASPPModule(ASPPModule):
    """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
    conv.

    Identical to ``ASPPModule`` except that every dilated branch
    (dilation > 1) is swapped for a 3x3 ``DepthwiseSeparableConvModule``.
    """

    def __init__(self, **kwargs):
        super(DepthwiseSeparableASPPModule, self).__init__(**kwargs)
        for branch_idx, rate in enumerate(self.dilations):
            if not rate > 1:
                continue
            self[branch_idx] = DepthwiseSeparableConvModule(
                self.in_channels,
                self.channels,
                3,
                dilation=rate,
                padding=rate,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
@HEADS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
    """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation.

    This head is the implementation of `DeepLabV3+
    <https://arxiv.org/abs/1802.02611>`_.

    Args:
        c1_in_channels (int): The input channels of c1 decoder. If it is 0,
            no decoder will be used.
        c1_channels (int): The intermediate channels of c1 decoder. Ignored
            when ``c1_in_channels`` is 0.
    """

    def __init__(self, c1_in_channels, c1_channels, **kwargs):
        super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
        assert c1_in_channels >= 0
        self.aspp_modules = DepthwiseSeparableASPPModule(
            dilations=self.dilations,
            in_channels=self.in_channels,
            channels=self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if c1_in_channels > 0:
            self.c1_bottleneck = ConvModule(
                c1_in_channels,
                c1_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            # forward() concatenates the c1 features to the ASPP output.
            sep_in_channels = self.channels + c1_channels
        else:
            self.c1_bottleneck = None
            # No c1 decoder: forward() feeds the ASPP output alone, so the
            # decoder must not expect the extra c1_channels.
            sep_in_channels = self.channels
        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                sep_in_channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            DepthwiseSeparableConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(aspp_outs)
        if self.c1_bottleneck is not None:
            # Low-level features come from the raw inputs[0], not from the
            # transformed x.
            c1_output = self.c1_bottleneck(inputs[0])
            output = resize(
                input=output,
                size=c1_output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            output = torch.cat([output, c1_output], dim=1)
        output = self.sep_bottleneck(output)
        output = self.cls_seg(output)
        return output
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/sep_fcn_head.py
================================================
from mmcv.cnn import DepthwiseSeparableConvModule
from ..builder import HEADS
from .fcn_head import FCNHead
@HEADS.register_module()
class DepthwiseSeparableFCNHead(FCNHead):
    """Depthwise-Separable Fully Convolutional Network for Semantic
    Segmentation.

    This head is implemented according to the Fast-SCNN paper. It builds a
    regular ``FCNHead`` and then replaces every conv (and ``conv_cat`` when
    ``concat_input`` is set) with a ``DepthwiseSeparableConvModule`` of the
    same kernel size.

    Args:
        in_channels(int): Number of output channels of FFM.
        channels(int): Number of middle-stage channels in the decode head.
        concat_input(bool): Whether to concatenate original decode input into
            the result of several consecutive convolution layers.
            Default: True.
        num_classes(int): Used to determine the dimension of
            final prediction tensor.
        in_index(int): Correspond with 'out_indices' in FastSCNN backbone.
        norm_cfg (dict | None): Config of norm layers.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        loss_decode(dict): Config of loss type and some
            relevant additional options.
    """

    def __init__(self, **kwargs):
        super(DepthwiseSeparableFCNHead, self).__init__(**kwargs)

        def _dw_conv(in_ch):
            # "Same" padding for the configured odd kernel size.
            return DepthwiseSeparableConvModule(
                in_ch,
                self.channels,
                kernel_size=self.kernel_size,
                padding=self.kernel_size // 2,
                norm_cfg=self.norm_cfg)

        self.convs[0] = _dw_conv(self.in_channels)
        for idx in range(1, self.num_convs):
            self.convs[idx] = _dw_conv(self.channels)
        if self.concat_input:
            self.conv_cat = _dw_conv(self.in_channels + self.channels)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/uper_head.py
================================================
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
from .psp_head import PPM
@HEADS.register_module()
class UPerHead(BaseDecodeHead):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_: a PSP module on the coarsest
    input feature feeds the top of an FPN-style top-down pathway, and all
    pyramid levels are upsampled to the finest resolution, concatenated and
    fused before classification.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(UPerHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        conv_kwargs = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        top_channels = self.in_channels[-1]

        # PSP Module on the last (top) input feature.
        self.psp_modules = PPM(
            pool_scales,
            top_channels,
            self.channels,
            align_corners=self.align_corners,
            **conv_kwargs)
        self.bottleneck = ConvModule(
            top_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            **conv_kwargs)

        # FPN Module: a 1x1 lateral conv and a 3x3 output conv for every
        # level except the top one, which is handled by the PSP branch.
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for level_channels in self.in_channels[:-1]:
            self.lateral_convs.append(
                ConvModule(
                    level_channels,
                    self.channels,
                    1,
                    inplace=False,
                    **conv_kwargs))
            self.fpn_convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    3,
                    padding=1,
                    inplace=False,
                    **conv_kwargs))

        self.fpn_bottleneck = ConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            3,
            padding=1,
            **conv_kwargs)

    def psp_forward(self, inputs):
        """Run the PSP module plus bottleneck on the last input feature."""
        top = inputs[-1]
        pyramid = [top]
        pyramid.extend(self.psp_modules(top))
        return self.bottleneck(torch.cat(pyramid, dim=1))

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        # Lateral projections for the lower levels, PSP output on top.
        laterals = [
            conv(feat) for conv, feat in zip(self.lateral_convs, inputs)
        ]
        laterals.append(self.psp_forward(inputs))
        num_levels = len(laterals)

        # Top-down pathway: add each upsampled coarser level into the next
        # finer one (in place, as in the reference implementation).
        for idx in reversed(range(1, num_levels)):
            laterals[idx - 1] += resize(
                laterals[idx],
                size=laterals[idx - 1].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        # 3x3 convs on all but the top level; the PSP feature is used as-is.
        fpn_outs = [
            conv(lateral) for conv, lateral in zip(self.fpn_convs, laterals)
        ]
        fpn_outs.append(laterals[-1])

        # Bring every level to the finest resolution before fusing.
        finest_size = fpn_outs[0].shape[2:]
        for idx in reversed(range(1, num_levels)):
            fpn_outs[idx] = resize(
                fpn_outs[idx],
                size=finest_size,
                mode='bilinear',
                align_corners=self.align_corners)

        fused = self.fpn_bottleneck(torch.cat(fpn_outs, dim=1))
        return self.cls_seg(fused)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/losses/__init__.py
================================================
from .accuracy import Accuracy, accuracy
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy, mask_cross_entropy)
from .lovasz_loss import LovaszLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
__all__ = [
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss'
]
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/losses/accuracy.py
================================================
import torch.nn as nn
def accuracy(pred, target, topk=1, thresh=None):
    """Calculate accuracy according to the prediction and target.

    Args:
        pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
        target (torch.Tensor): The target of each prediction, shape (N, , ...)
        topk (int | tuple[int], optional): If the predictions in ``topk``
            matches the target, the predictions will be regarded as
            correct ones. Defaults to 1.
        thresh (float, optional): If not None, predictions with scores under
            this threshold are considered incorrect. Default to None.

    Returns:
        torch.Tensor | list[torch.Tensor]: If ``topk`` is a single integer,
            a one-element tensor holding the accuracy in percent. If
            ``topk`` is a tuple, a list with one such tensor per ``topk``
            entry.
    """
    assert isinstance(topk, (int, tuple))
    if isinstance(topk, int):
        topk = (topk, )
        return_single = True
    else:
        return_single = False

    maxk = max(topk)
    if pred.size(0) == 0:
        accu = [pred.new_tensor(0.) for _ in range(len(topk))]
        return accu[0] if return_single else accu
    assert pred.ndim == target.ndim + 1
    assert pred.size(0) == target.size(0)
    assert maxk <= pred.size(1), \
        f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
    pred_value, pred_label = pred.topk(maxk, dim=1)
    # transpose to shape (maxk, N, ...)
    pred_label = pred_label.transpose(0, 1)
    correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
    if thresh is not None:
        # Only prediction values larger than thresh are counted as correct.
        # transpose(0, 1) rather than .t(): .t() only accepts <= 2-D tensors,
        # whereas segmentation scores are (N, maxk, H, W).
        correct = correct & (pred_value > thresh).transpose(0, 1)
    res = []
    for k in topk:
        # reshape instead of view: `correct` may be non-contiguous after the
        # transpose above, in which case view(-1) raises.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / target.numel()))
    return res[0] if return_single else res
class Accuracy(nn.Module):
    """``nn.Module`` wrapper around :func:`accuracy` with fixed settings."""

    def __init__(self, topk=(1, ), thresh=None):
        """Store the accuracy settings.

        Args:
            topk (tuple, optional): The criterion used to calculate the
                accuracy. Defaults to (1,).
            thresh (float, optional): If not None, predictions with scores
                under this threshold are considered incorrect. Default to None.
        """
        super().__init__()
        self.topk, self.thresh = topk, thresh

    def forward(self, pred, target):
        """Compute accuracy of ``pred`` against ``target``.

        Args:
            pred (torch.Tensor): Prediction of models.
            target (torch.Tensor): Target for each prediction.

        Returns:
            Whatever :func:`accuracy` returns for the stored ``topk``
            (a list of tensors when ``topk`` is a tuple).
        """
        return accuracy(pred, target, self.topk, self.thresh)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/losses/cross_entropy_loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from .utils import weight_reduce_loss
def cross_entropy(pred,
                  label,
                  weight=None,
                  class_weight=None,
                  reduction='mean',
                  avg_factor=None,
                  ignore_index=-100):
    """The wrapper function for :func:`F.cross_entropy`.

    The element-wise loss is computed unreduced, then optionally multiplied
    by ``weight`` and reduced via :func:`weight_reduce_loss`.

    Args:
        pred (torch.Tensor): Class scores.
        label (torch.Tensor): Ground-truth class indices.
        weight (torch.Tensor, optional): Element-wise loss weight.
        class_weight (torch.Tensor, optional): Manual rescaling weight for
            each class (a tensor of size C), forwarded to F.cross_entropy.
        reduction (str): "none", "mean" or "sum". Default: 'mean'.
        avg_factor (float, optional): Custom denominator for 'mean'.
        ignore_index (int): Label ignored by F.cross_entropy. Default: -100.

    Returns:
        torch.Tensor: The calculated loss.
    """
    elementwise = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)
    weight = None if weight is None else weight.float()
    return weight_reduce_loss(
        elementwise, weight=weight, reduction=reduction, avg_factor=avg_factor)
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
"""Expand onehot labels to match the size of prediction."""
bin_labels = labels.new_zeros(target_shape)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(valid_mask, as_tuple=True)
if inds[0].numel() > 0:
if labels.dim() == 3:
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
else:
bin_labels[inds[0], labels[valid_mask]] = 1
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
bin_label_weights *= valid_mask
return bin_labels, bin_label_weights
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=255):
    """Calculate the binary CrossEntropy loss.

    If ``label`` holds class indices (one dim fewer than ``pred``), it is
    first expanded to one-hot targets and ignored positions get zero weight.

    Args:
        pred (torch.Tensor): The prediction, [N, C] or [N, C, H, W].
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): Forwarded as ``pos_weight`` to
            F.binary_cross_entropy_with_logits.
        ignore_index (int | None): The label index to be ignored. Default: 255

    Returns:
        torch.Tensor: The calculated loss
    """
    if pred.dim() != label.dim():
        supported_dims = ((2, 1), (4, 3))
        assert (pred.dim(), label.dim()) in supported_dims, \
            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
            'H, W], label shape [N, H, W] are supported'
        label, weight = _expand_onehot_labels(
            label, weight, pred.shape, ignore_index)

    # weighted element-wise losses
    if weight is not None:
        weight = weight.float()
    elementwise = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    return weight_reduce_loss(
        elementwise, weight, reduction=reduction, avg_factor=avg_factor)
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None):
    """Calculate the CrossEntropy loss for masks.

    For each RoI, the prediction channel of that RoI's class is selected and
    scored against ``target`` with sigmoid binary cross entropy.

    Args:
        pred (torch.Tensor): Per-class mask prediction of shape (N, C, *),
            where C is the number of classes. At least one trailing dim is
            required because the selected slice is squeezed along dim 1.
        target (torch.Tensor): The learning target, shaped like the selected
            prediction slice.
        label (torch.Tensor): Class label of the object each mask belongs
            to; it selects which class channel of ``pred`` is used when the
            mask prediction is not class-agnostic.
        reduction (str, optional): Only 'mean' is currently supported.
        avg_factor (int, optional): Reserved; must be None.
        class_weight (torch.Tensor, optional): Forwarded as ``weight`` to
            F.binary_cross_entropy_with_logits.
        ignore_index (None): Placeholder, to be consistent with other loss.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss, shape (1, ).
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    roi_index = torch.arange(
        0, pred.size(0), dtype=torch.long, device=pred.device)
    selected = pred[roi_index, label].squeeze(1)
    loss = F.binary_cross_entropy_with_logits(
        selected, target, weight=class_weight, reduction='mean')
    return loss[None]
@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    """CrossEntropyLoss.

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            or softmax. Defaults to False.
        use_mask (bool, optional): Whether to use mask cross entropy loss.
            Defaults to False.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0):
        super(CrossEntropyLoss, self).__init__()
        # Sigmoid BCE and mask BCE are mutually exclusive modes.
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.cls_criterion = (
            binary_cross_entropy if self.use_sigmoid else
            mask_cross_entropy if self.use_mask else cross_entropy)

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the weighted loss with the configured criterion."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction
        # Class weights are materialised on cls_score's device/dtype.
        class_weight = (None if self.class_weight is None else
                        cls_score.new_tensor(self.class_weight))
        return self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/losses/lovasz_loss.py
================================================
"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from .utils import weight_reduce_loss
def lovasz_grad(gt_sorted):
    """Computes gradient of the Lovasz extension w.r.t sorted errors.

    See Alg. 1 in paper.

    Args:
        gt_sorted (torch.Tensor): 1-D ground-truth indicators ordered by
            decreasing error.

    Returns:
        torch.Tensor: Float tensor of the same length.
    """
    positives = gt_sorted.sum()
    intersection = positives - gt_sorted.float().cumsum(0)
    union = positives + (1 - gt_sorted).float().cumsum(0)
    iou_loss = 1. - intersection / union
    if len(gt_sorted) <= 1:  # cover 1-pixel (and empty) case
        return iou_loss
    # Discrete derivative: keep the first entry, difference the rest.
    return torch.cat((iou_loss[:1], iou_loss[1:] - iou_loss[:-1]))
def flatten_binary_logits(logits, labels, ignore_index=None):
    """Flatten binary-case predictions and labels of a batch to 1-D.

    Entries whose label equals ``ignore_index`` are dropped; with
    ``ignore_index=None`` nothing is dropped.
    """
    flat_logits = logits.view(-1)
    flat_labels = labels.view(-1)
    if ignore_index is not None:
        keep = flat_labels != ignore_index
        flat_logits, flat_labels = flat_logits[keep], flat_labels[keep]
    return flat_logits, flat_labels
def flatten_probs(probs, labels, ignore_index=None):
    """Flatten batched class probabilities and labels, dropping ignored ones.

    Args:
        probs (torch.Tensor): [B, C, H, W] class probabilities, or [B, H, W]
            which is treated as a single-channel (sigmoid) output.
        labels (torch.Tensor): Ground-truth labels with B*H*W elements.
        ignore_index (int | None): Label value to drop. Default: None.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Probabilities of shape [P, C] and
            labels of shape [P].
    """
    if probs.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probs.size()
        probs = probs.view(B, 1, H, W)
    B, C, H, W = probs.size()
    probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B*H*W, C=P,C
    labels = labels.view(-1)
    if ignore_index is None:
        return probs, labels
    valid = (labels != ignore_index)
    # Boolean-mask indexing always keeps the [P, C] shape. The previous
    # `probs[valid.nonzero().squeeze()]` squeezed to a 0-d index when exactly
    # one pixel was valid and returned a [C] tensor instead.
    vprobs = probs[valid]
    vlabels = labels[valid]
    return vprobs, vlabels
def lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss.

    Args:
        logits (torch.Tensor): [P], logits at each prediction
            (between -infty and +infty).
        labels (torch.Tensor): [P], binary ground truth labels (0 or 1).

    Returns:
        torch.Tensor: The calculated loss.
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    # Map {0, 1} labels to {-1, +1} and compute per-pixel hinge margins.
    margins = 1. - logits * (2. * labels.float() - 1.)
    margins_sorted, order = torch.sort(margins, dim=0, descending=True)
    labels_sorted = labels[order.data]
    return torch.dot(F.relu(margins_sorted), lovasz_grad(labels_sorted))
def lovasz_hinge(logits,
                 labels,
                 classes='present',
                 per_image=False,
                 class_weight=None,
                 reduction='mean',
                 avg_factor=None,
                 ignore_index=255):
    """Binary Lovasz hinge loss.

    Args:
        logits (torch.Tensor): [B, H, W], logits at each pixel
            (between -infty and +infty).
        labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
        classes (str | list[int], optional): Placeholder, to be consistent with
            other loss; unused. Default: 'present'.
        per_image (bool, optional): If per_image is True, compute the loss per
            image instead of per batch. Default: False.
        class_weight (list[float], optional): Placeholder, to be consistent
            with other loss; unused. Default: None.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. This parameter only works when per_image is True.
            Default: None.
        ignore_index (int | None): The label index to be ignored. Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if not per_image:
        return lovasz_hinge_flat(
            *flatten_binary_logits(logits, labels, ignore_index))
    image_losses = []
    for logit, label in zip(logits, labels):
        flat_logit, flat_label = flatten_binary_logits(
            logit.unsqueeze(0), label.unsqueeze(0), ignore_index)
        image_losses.append(lovasz_hinge_flat(flat_logit, flat_label))
    return weight_reduce_loss(
        torch.stack(image_losses), None, reduction, avg_factor)
def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probs (torch.Tensor): [P, C], class probabilities at each prediction
            (between 0 and 1).
        labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        class_weight (list[float], optional): The weight for each class.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss as a 0-d tensor. It is zero (but
            still connected to ``probs``) when there are no pixels or no
            class left to average over.
    """
    if probs.numel() == 0:
        # Only void pixels: the gradients should be 0. Return a 0-d tensor
        # (``probs * 0.`` would be an empty [0, C] tensor, which breaks
        # torch.stack in the per-image path and scalar loss logging).
        return probs.sum() * 0.
    C = probs.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if (classes == 'present' and fg.sum() == 0):
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probs[:, 0]
        else:
            class_pred = probs[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
        if class_weight is not None:
            loss = loss * class_weight[c]
        losses.append(loss)
    if not losses:
        # No requested class occurs in `labels` (e.g. all labels are outside
        # [0, C - 1]); torch.stack([]) would raise here.
        return probs.sum() * 0.
    return torch.stack(losses).mean()
def lovasz_softmax(probs,
                   labels,
                   classes='present',
                   per_image=False,
                   class_weight=None,
                   reduction='mean',
                   avg_factor=None,
                   ignore_index=255):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probs (torch.Tensor): [B, C, H, W], class probabilities at each
            prediction (between 0 and 1).
        labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
            C - 1).
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        per_image (bool, optional): If per_image is True, compute the loss per
            image instead of per batch. Default: False.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. This parameter only works when per_image is True.
            Default: None.
        ignore_index (int | None): The label index to be ignored. Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if not per_image:
        return lovasz_softmax_flat(
            *flatten_probs(probs, labels, ignore_index),
            classes=classes,
            class_weight=class_weight)
    image_losses = []
    for prob, label in zip(probs, labels):
        flat_prob, flat_label = flatten_probs(
            prob.unsqueeze(0), label.unsqueeze(0), ignore_index)
        image_losses.append(
            lovasz_softmax_flat(
                flat_prob,
                flat_label,
                classes=classes,
                class_weight=class_weight))
    return weight_reduce_loss(
        torch.stack(image_losses), None, reduction, avg_factor)
@LOSSES.register_module()
class LovaszLoss(nn.Module):
    """LovaszLoss.

    This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
    for the optimization of the intersection-over-union measure in neural
    networks <https://arxiv.org/abs/1705.08790>`_.

    Args:
        loss_type (str, optional): Binary or multi-class loss.
            Default: 'multi_class'. Options are "binary" and "multi_class".
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        per_image (bool, optional): If per_image is True, compute the loss per
            image instead of per batch. Default: False.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True, and must be 'none' otherwise. Default: 'mean'.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self,
                 loss_type='multi_class',
                 classes='present',
                 per_image=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0):
        super(LovaszLoss, self).__init__()
        # Implicit string concatenation keeps the messages free of the source
        # indentation that a backslash continuation inside a literal embeds.
        assert loss_type in ('binary', 'multi_class'), (
            "loss_type should be 'binary' or 'multi_class'.")

        if loss_type == 'binary':
            self.cls_criterion = lovasz_hinge
        else:
            self.cls_criterion = lovasz_softmax
        assert classes in ('all', 'present') or mmcv.is_list_of(classes, int)
        if not per_image:
            assert reduction == 'none', (
                "reduction should be 'none' when per_image is False.")

        self.classes = classes
        self.per_image = per_image
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function.

        ``weight`` is accepted for interface compatibility with other losses
        but is not used.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None

        # if multi-class loss, transform logits to probs
        if self.cls_criterion == lovasz_softmax:
            cls_score = F.softmax(cls_score, dim=1)

        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            self.classes,
            self.per_image,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/losses/utils.py
================================================
import functools
import torch.nn.functional as F
def reduce_loss(loss, reduction):
    """Reduce an element-wise loss tensor.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: ``loss`` itself for "none", otherwise its mean or sum.
    """
    # torch's private helper maps: none -> 0, (elementwise_)mean -> 1, sum -> 2
    mode = F._Reduction.get_enum(reduction)
    if mode == 1:
        return loss.mean()
    if mode == 2:
        return loss.sum()
    if mode == 0:
        return loss
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights; must have the same number of
            dims as ``loss`` and, if more than 1-D, size 1 or
            ``loss.size(1)`` along dim 1.
        reduction (str): Same as built-in losses of PyTorch.
        avg_factor (float): Average factor when computing the mean of losses.

    Returns:
        Tensor: Processed loss values.

    Raises:
        ValueError: If ``avg_factor`` is given with a reduction other than
            'mean' or 'none'.
    """
    if weight is not None:
        assert weight.dim() == loss.dim()
        if weight.dim() > 1:
            assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    if avg_factor is None:
        return reduce_loss(loss, reduction)
    # With an explicit avg_factor, 'mean' divides the sum by it instead of
    # by the element count; 'none' returns the weighted loss unchanged.
    if reduction == 'mean':
        return loss.sum() / avg_factor
    if reduction == 'none':
        return loss
    raise ValueError('avg_factor can not be used with reduction="sum"')
def weighted_loss(loss_func):
    """Decorate an element-wise loss with weighting and reduction.

    The decorated ``loss_func(pred, target, **kwargs)`` must return the
    unreduced element-wise loss. The returned wrapper has the signature
    ``loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)`` and hands the element-wise loss to
    :func:`weight_reduce_loss`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction, avg_factor)

    return wrapper
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/necks/__init__.py
================================================
from .fpn import FPN
__all__ = ['FPN']
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/necks/fpn.py
================================================
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init
from ..builder import NECKS
@NECKS.register_module()
class FPN(nn.Module):
    """Feature Pyramid Network.

    This is an implementation of - Feature Pyramid Networks for Object
    Detection (https://arxiv.org/abs/1612.03144)

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale)
        num_outs (int): Number of output scales.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        add_extra_convs (bool | str): If bool, it decides whether to add conv
            layers on top of the original feature maps. Default to False.
            If True, its actual mode is specified by `extra_convs_on_inputs`.
            If str, it specifies the source feature map of the extra convs.
            Only the following options are allowed

            - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
            - 'on_lateral': Last feature map after lateral convs.
            - 'on_output': The last output feature map after fpn convs.
        extra_convs_on_inputs (bool, deprecated): Only consulted when
            ``add_extra_convs`` is ``True``. If True, it is equivalent to
            `add_extra_convs='on_input'`. If False, it is equivalent to
            `add_extra_convs='on_output'`. Default to False.
        relu_before_extra_convs (bool): Whether to apply relu before every
            extra conv except the first one. Default: False.
        no_norm_on_lateral (bool): If True, the lateral (1x1) convs are built
            without a norm layer. Default: False.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
        upsample_cfg (dict): Config dict for interpolate layer. It is copied,
            so the shared default is never mutated.
            Default: `dict(mode='nearest')`

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = FPN(in_channels, 11, len(in_channels)).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 extra_convs_on_inputs=False,
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=dict(mode='nearest')):
        super(FPN, self).__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.fp16_enabled = False
        # Copy so the mutable default dict is never shared/mutated.
        self.upsample_cfg = upsample_cfg.copy()

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs
        assert isinstance(add_extra_convs, (str, bool))
        if isinstance(add_extra_convs, str):
            # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
            assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
        elif add_extra_convs:  # True
            if extra_convs_on_inputs:
                # For compatibility with previous release
                # TODO: deprecate `extra_convs_on_inputs`
                self.add_extra_convs = 'on_input'
            else:
                self.add_extra_convs = 'on_output'

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                act_cfg=act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)

            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # add extra conv layers (e.g., RetinaNet)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        if self.add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):
                # Only the first extra conv may read the raw backbone feature;
                # every later one consumes the previous FPN output.
                if i == 0 and self.add_extra_convs == 'on_input':
                    in_channels = self.in_channels[self.backbone_end_level - 1]
                else:
                    in_channels = out_channels
                extra_fpn_conv = ConvModule(
                    in_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    inplace=False)
                self.fpn_convs.append(extra_fpn_conv)

    # default init_weights for conv(msra) and norm in ConvModule
    def init_weights(self):
        """Xavier-uniform initialise every ``nn.Conv2d`` in the neck."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

    def forward(self, inputs):
        """Build the pyramid; returns a tuple of ``num_outs`` feature maps."""
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
            # it cannot co-exist with `size` in `F.interpolate`.
            if 'scale_factor' in self.upsample_cfg:
                laterals[i - 1] += F.interpolate(laterals[i],
                                                 **self.upsample_cfg)
            else:
                prev_shape = laterals[i - 1].shape[2:]
                laterals[i - 1] += F.interpolate(
                    laterals[i], size=prev_shape, **self.upsample_cfg)

        # build outputs
        # part 1: from original levels
        outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
        ]
        # part 2: add extra levels
        if self.num_outs > len(outs):
            # use max pool to get more levels on top of outputs
            # (e.g., Faster R-CNN, Mask R-CNN)
            if not self.add_extra_convs:
                for i in range(self.num_outs - used_backbone_levels):
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            # add conv layers on top of original feature maps (RetinaNet)
            else:
                if self.add_extra_convs == 'on_input':
                    extra_source = inputs[self.backbone_end_level - 1]
                elif self.add_extra_convs == 'on_lateral':
                    extra_source = laterals[-1]
                elif self.add_extra_convs == 'on_output':
                    extra_source = outs[-1]
                else:
                    raise NotImplementedError
                outs.append(self.fpn_convs[used_backbone_levels](extra_source))
                for i in range(used_backbone_levels + 1, self.num_outs):
                    if self.relu_before_extra_convs:
                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                    else:
                        outs.append(self.fpn_convs[i](outs[-1]))
        return tuple(outs)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/segmentors/__init__.py
================================================
from .cascade_encoder_decoder import CascadeEncoderDecoder
from .encoder_decoder import EncoderDecoder
__all__ = ['EncoderDecoder', 'CascadeEncoderDecoder']
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/segmentors/base.py
================================================
import logging
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.runner import auto_fp16
class BaseSegmentor(nn.Module, metaclass=ABCMeta):
    """Base class for segmentors.

    Subclasses must implement :meth:`extract_feat`, :meth:`encode_decode`,
    :meth:`forward_train`, :meth:`simple_test` and :meth:`aug_test`.

    The metaclass is given with the Python 3 ``metaclass=`` keyword: a
    ``__metaclass__`` class attribute is ignored by Python 3, which would
    leave the ``@abstractmethod`` markers below unenforced.
    """

    def __init__(self):
        super(BaseSegmentor, self).__init__()
        # Flag consulted by mmcv's ``auto_fp16`` decorator applied to
        # ``forward``; False keeps inputs in fp32.
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the segmentor has neck"""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_auxiliary_head(self):
        """bool: whether the segmentor has auxiliary head"""
        return hasattr(self,
                       'auxiliary_head') and self.auxiliary_head is not None

    @property
    def with_decode_head(self):
        """bool: whether the segmentor has decode head"""
        return hasattr(self, 'decode_head') and self.decode_head is not None

    @abstractmethod
    def extract_feat(self, imgs):
        """Placeholder for extract features from images."""
        pass

    @abstractmethod
    def encode_decode(self, img, img_metas):
        """Placeholder for encode images with backbone and decode into a
        semantic segmentation map of the same size as input."""
        pass

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """Placeholder for Forward function for training."""
        pass

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        """Placeholder for single image test."""
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Placeholder for augmentation test."""
        pass

    def init_weights(self, pretrained=None):
        """Initialize the weights in segmentor.

        This base implementation only logs the checkpoint path; subclasses
        perform the actual initialisation / loading.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if pretrained is not None:
            logger = logging.getLogger()
            logger.info(f'load model from: {pretrained}')

    def forward_test(self, imgs, img_metas, **kwargs):
        """Dispatch to :meth:`simple_test` or :meth:`aug_test`.

        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.

        Raises:
            TypeError: If ``imgs`` or ``img_metas`` is not a list.
            ValueError: If their lengths (number of augmentations) differ.
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got '
                                f'{type(var)}')

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(f'num of augmentations ({len(imgs)}) != '
                             f'num of image meta ({len(img_metas)})')
        # all images in the same aug batch all of the same ori_shape and pad
        # shape
        for img_meta in img_metas:
            ori_shapes = [_['ori_shape'] for _ in img_meta]
            assert all(shape == ori_shapes[0] for shape in ori_shapes)
            img_shapes = [_['img_shape'] for _ in img_meta]
            assert all(shape == img_shapes[0] for shape in img_shapes)
            pad_shapes = [_['pad_shape'] for _ in img_meta]
            assert all(shape == pad_shapes[0] for shape in pad_shapes)

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        """Calls either :func:`forward_train` or :func:`forward_test` depending
        on whether ``return_loss`` is ``True``.

        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)

    def train_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.

        Args:
            data_batch (dict): The output of dataloader; passed as keyword
                arguments to :meth:`forward` and must contain ``img``.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.
                ``loss`` is a tensor for back propagation, which can be a
                weighted sum of multiple losses.
                ``log_vars`` contains all the variables to be sent to the
                logger.
                ``num_samples`` indicates the batch size (when the model is
                DDP, it means the batch size on each GPU), which is used for
                averaging the logs.
        """
        losses = self(**data_batch)
        loss, log_vars = self._parse_losses(losses)

        outputs = dict(
            loss=loss,
            log_vars=log_vars,
            num_samples=len(data_batch['img'].data))

        return outputs

    def val_step(self, data_batch, **kwargs):
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but used
        during val epochs. Note that the evaluation after training epochs is
        not implemented with this method, but an evaluation hook.
        """
        output = self(**data_batch, **kwargs)
        return output

    @staticmethod
    def _parse_losses(losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary information. Values must be a
                tensor or a list of tensors.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
                which is the sum of every entry whose key contains ``'loss'``,
                log_vars maps every entry (plus the total ``'loss'``) to a
                Python float for logging.

        Raises:
            TypeError: If a value is neither a tensor nor a list of tensors.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            # NOTE(review): every rank must produce the same keys in the same
            # order, otherwise the collective all_reduce calls mismatch.
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars

    def show_result(self,
                    img,
                    result,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or np.ndarray): The image (or its path) to be displayed;
                it is loaded with ``mmcv.imread``.
            result (Sequence): Segmentation results; only ``result[0]``, a 2-D
                label map, is drawn.
            palette (list[list[int]]] | np.ndarray | None): The palette of
                segmentation map, one RGB triple per class. If None is given,
                ``self.PALETTE`` is used, and if that is also None a random
                palette is generated from a fixed seed so colours stay the
                same between calls. Default: None
            win_name (str): The window name.
            show (bool): Whether to show the image.
                Default: False.
            wait_time (int): Value of waitKey param.
                Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            np.ndarray | None: The blended uint8 image, returned only when
                neither ``show`` nor ``out_file`` is set.
        """
        img = mmcv.imread(img)
        img = img.copy()
        seg = result[0]
        if palette is None:
            if self.PALETTE is None:
                # A private, fixed-seed generator keeps the colours stable
                # between calls without disturbing the global NumPy RNG.
                rng = np.random.RandomState(42)
                palette = rng.randint(0, 255, size=(len(self.CLASSES), 3))
            else:
                palette = self.PALETTE
        palette = np.array(palette)
        # Check the rank first so a malformed (e.g. 1-D) palette fails the
        # assertion rather than raising IndexError on ``shape[1]``.
        assert len(palette.shape) == 2
        assert palette.shape[0] == len(self.CLASSES)
        assert palette.shape[1] == 3
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # convert to BGR
        color_seg = color_seg[..., ::-1]

        img = img * 0.5 + color_seg * 0.5
        img = img.astype(np.uint8)
        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False

        if show:
            mmcv.imshow(img, win_name, wait_time)
        if out_file is not None:
            mmcv.imwrite(img, out_file)

        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, only '
                          'result image will be returned')
            return img
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/segmentors/cascade_encoder_decoder.py
================================================
from torch import nn
from mmseg.core import add_prefix
from mmseg.ops import resize
from .. import builder
from ..builder import SEGMENTORS
from .encoder_decoder import EncoderDecoder
@SEGMENTORS.register_module()
class CascadeEncoderDecoder(EncoderDecoder):
    """Cascade Encoder Decoder segmentors.

    CascadeEncoderDecoder is almost the same as EncoderDecoder, while the
    decoders of CascadeEncoderDecoder are cascaded: ``decode_head`` is a list
    of ``num_stages`` heads, stage 0 sees only the extracted features, and
    every later stage also receives the output of the previous stage.
    """

    def __init__(self,
                 num_stages,
                 backbone,
                 decode_head,
                 neck=None,
                 auxiliary_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # Set before the parent constructor: EncoderDecoder.__init__ calls
        # ``_init_decode_head``, which reads ``num_stages``.
        self.num_stages = num_stages
        super(CascadeEncoderDecoder, self).__init__(
            backbone=backbone,
            decode_head=decode_head,
            neck=neck,
            auxiliary_head=auxiliary_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)

    def _init_decode_head(self, decode_head):
        """Build one decode head per stage.

        Args:
            decode_head (list[dict]): ``num_stages`` head configs.

        The segmentor's ``align_corners`` and ``num_classes`` are taken from
        the last stage's head.
        """
        assert isinstance(decode_head, list)
        assert len(decode_head) == self.num_stages
        self.decode_head = nn.ModuleList()
        for i in range(self.num_stages):
            self.decode_head.append(builder.build_head(decode_head[i]))
        self.align_corners = self.decode_head[-1].align_corners
        self.num_classes = self.decode_head[-1].num_classes

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone and heads.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        self.backbone.init_weights(pretrained=pretrained)
        for i in range(self.num_stages):
            self.decode_head[i].init_weights()
        if self.with_auxiliary_head:
            if isinstance(self.auxiliary_head, nn.ModuleList):
                for aux_head in self.auxiliary_head:
                    aux_head.init_weights()
            else:
                self.auxiliary_head.init_weights()

    def encode_decode(self, img, img_metas):
        """Encode images with backbone and decode into a semantic segmentation
        map of the same size as input.

        Stage 0 is called as ``forward_test(x, img_metas, test_cfg)``; every
        later stage as ``forward_test(x, prev_output, img_metas, test_cfg)``.
        The final logits are bilinearly resized to ``img.shape[2:]``.
        """
        x = self.extract_feat(img)
        out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg)
        for i in range(1, self.num_stages):
            out = self.decode_head[i].forward_test(x, out, img_metas,
                                                   self.test_cfg)
        out = resize(
            input=out,
            size=img.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        return out

    def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
        """Run forward function and calculate loss for decode head in
        training.

        Returns:
            dict: Losses of every stage, prefixed ``decode_{i}``.
        """
        losses = dict()

        loss_decode = self.decode_head[0].forward_train(
            x, img_metas, gt_semantic_seg, self.train_cfg)
        losses.update(add_prefix(loss_decode, 'decode_0'))

        for i in range(1, self.num_stages):
            # forward test again, maybe unnecessary for most methods.
            # Stage 0 takes no previous output; later (cascaded) heads need
            # the output of the stage before them, exactly as in
            # ``encode_decode``. Calling them with the stage-0 signature
            # broke any cascade with more than two stages.
            if i == 1:
                prev_outputs = self.decode_head[0].forward_test(
                    x, img_metas, self.test_cfg)
            else:
                prev_outputs = self.decode_head[i - 1].forward_test(
                    x, prev_outputs, img_metas, self.test_cfg)
            loss_decode = self.decode_head[i].forward_train(
                x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)
            losses.update(add_prefix(loss_decode, f'decode_{i}'))

        return losses
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/segmentors/encoder_decoder.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.core import add_prefix
from mmseg.ops import resize
from .. import builder
from ..builder import SEGMENTORS
from .base import BaseSegmentor
@SEGMENTORS.register_module()
class EncoderDecoder(BaseSegmentor):
    """Single decode-head segmentor: backbone (+ optional neck) -> head.

    Optional auxiliary heads add deep supervision during training only; they
    are never consulted at inference time and could be dropped for deployment.
    """

    def __init__(self,
                 backbone,
                 decode_head,
                 neck=None,
                 auxiliary_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(EncoderDecoder, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
        self._init_decode_head(decode_head)
        self._init_auxiliary_head(auxiliary_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained=pretrained)

        assert self.with_decode_head

    def _init_decode_head(self, decode_head):
        """Build the decode head and copy its ``align_corners`` and
        ``num_classes`` onto the segmentor."""
        head = builder.build_head(decode_head)
        self.decode_head = head
        self.align_corners = head.align_corners
        self.num_classes = head.num_classes

    def _init_auxiliary_head(self, auxiliary_head):
        """Build the auxiliary head(s), if any.

        A list of configs yields an ``nn.ModuleList``; a single config yields
        a single head; ``None`` leaves the attribute unset.
        """
        if auxiliary_head is None:
            return
        if isinstance(auxiliary_head, list):
            self.auxiliary_head = nn.ModuleList(
                [builder.build_head(cfg) for cfg in auxiliary_head])
        else:
            self.auxiliary_head = builder.build_head(auxiliary_head)

    def init_weights(self, pretrained=None):
        """Initialise backbone, decode head and auxiliary head(s).

        Args:
            pretrained (str, optional): Checkpoint path forwarded to the
                backbone. Defaults to None.
        """
        super(EncoderDecoder, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        self.decode_head.init_weights()
        if self.with_auxiliary_head:
            aux_heads = self.auxiliary_head
            if not isinstance(aux_heads, nn.ModuleList):
                aux_heads = [aux_heads]
            for aux_head in aux_heads:
                aux_head.init_weights()

    def extract_feat(self, img):
        """Run the backbone, then the neck when one is configured."""
        feats = self.backbone(img)
        return self.neck(feats) if self.with_neck else feats

    def encode_decode(self, img, img_metas):
        """Predict segmentation logits resized to the spatial size of
        ``img`` (``img.shape[2:]``)."""
        feats = self.extract_feat(img)
        seg_logits = self._decode_head_forward_test(feats, img_metas)
        return resize(
            input=seg_logits,
            size=img.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

    def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
        """Compute the decode-head losses, keys prefixed with ``decode``."""
        head_losses = self.decode_head.forward_train(x, img_metas,
                                                     gt_semantic_seg,
                                                     self.train_cfg)
        return dict(add_prefix(head_losses, 'decode'))

    def _decode_head_forward_test(self, x, img_metas):
        """Run the decode head in inference mode and return its logits
        (no loss is computed)."""
        return self.decode_head.forward_test(x, img_metas, self.test_cfg)

    def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):
        """Compute auxiliary-head losses.

        Keys are prefixed ``aux_{idx}`` for a ModuleList of heads and ``aux``
        for a single head.
        """
        if isinstance(self.auxiliary_head, nn.ModuleList):
            named_heads = [(f'aux_{idx}', head)
                           for idx, head in enumerate(self.auxiliary_head)]
        else:
            named_heads = [('aux', self.auxiliary_head)]

        losses = dict()
        for prefix, head in named_heads:
            head_losses = head.forward_train(x, img_metas, gt_semantic_seg,
                                             self.train_cfg)
            losses.update(add_prefix(head_losses, prefix))
        return losses

    def forward_dummy(self, img):
        """Forward pass without image metas (e.g. for FLOPs counting)."""
        return self.encode_decode(img, None)

    def forward_train(self, img, img_metas, gt_semantic_seg):
        """Compute all training losses.

        Args:
            img (Tensor): Input images.
            img_metas (list[dict]): Per-image info dicts with 'img_shape',
                'scale_factor', 'flip', and optionally 'filename',
                'ori_shape', 'pad_shape' and 'img_norm_cfg' (see
                `mmseg/datasets/pipelines/formatting.py:Collect`).
            gt_semantic_seg (Tensor): Ground-truth semantic segmentation
                masks.

        Returns:
            dict[str, Tensor]: Decode-head losses plus, when configured,
                auxiliary-head losses.
        """
        feats = self.extract_feat(img)

        losses = dict()
        losses.update(
            self._decode_head_forward_train(feats, img_metas,
                                            gt_semantic_seg))
        if self.with_auxiliary_head:
            losses.update(
                self._auxiliary_head_forward_train(feats, img_metas,
                                                   gt_semantic_seg))
        return losses

    def slide_inference(self, img, img_meta, rescale):
        """Sliding-window inference with overlapping crops.

        Crops of ``test_cfg.crop_size`` are taken every ``test_cfg.stride``
        pixels; overlapping logits are averaged. When the crop is larger than
        the image along an axis, the (smaller) whole extent is decoded
        without padding.
        """
        stride_h, stride_w = self.test_cfg.stride
        crop_h, crop_w = self.test_cfg.crop_size
        batch_size, _, img_h, img_w = img.size()
        # Window count per axis; the final window is shifted back so that it
        # ends exactly on the image border.
        grid_rows = max(img_h - crop_h + stride_h - 1, 0) // stride_h + 1
        grid_cols = max(img_w - crop_w + stride_w - 1, 0) // stride_w + 1
        preds = img.new_zeros((batch_size, self.num_classes, img_h, img_w))
        count_mat = img.new_zeros((batch_size, 1, img_h, img_w))
        for row in range(grid_rows):
            for col in range(grid_cols):
                y2 = min(row * stride_h + crop_h, img_h)
                x2 = min(col * stride_w + crop_w, img_w)
                y1 = max(y2 - crop_h, 0)
                x1 = max(x2 - crop_w, 0)
                crop_seg_logit = self.encode_decode(
                    img[:, :, y1:y2, x1:x2], img_meta)
                # Zero-pad the crop's logits back to full-image size so they
                # accumulate into ``preds`` at the right location.
                padding = (int(x1), int(preds.shape[3] - x2), int(y1),
                           int(preds.shape[2] - y2))
                preds += F.pad(crop_seg_logit, padding)
                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        if torch.onnx.is_in_onnx_export():
            # cast count_mat to constant while exporting to ONNX
            count_mat = torch.from_numpy(
                count_mat.cpu().detach().numpy()).to(device=img.device)
        preds = preds / count_mat
        if rescale:
            preds = resize(
                preds,
                size=img_meta[0]['ori_shape'][:2],
                mode='bilinear',
                align_corners=self.align_corners,
                warning=False)
        return preds

    def whole_inference(self, img, img_meta, rescale):
        """Decode the full image in one pass, optionally resizing the logits
        to ``img_meta[0]['ori_shape']``."""
        seg_logit = self.encode_decode(img, img_meta)
        if not rescale:
            return seg_logit
        return resize(
            seg_logit,
            size=img_meta[0]['ori_shape'][:2],
            mode='bilinear',
            align_corners=self.align_corners,
            warning=False)

    def inference(self, img, img_meta, rescale):
        """Run slide- or whole-image inference and return class probabilities.

        Args:
            img (Tensor): Input images of shape (N, 3, H, W).
            img_meta (list[dict]): Per-image info dicts; all entries must
                share the same 'ori_shape'. 'flip' (and 'flip_direction' when
                flipped) is read from the first entry.
            rescale (bool): Whether to resize back to 'ori_shape'.

        Returns:
            Tensor: Softmax probabilities over the class dimension, with any
                test-time flip undone.
        """
        assert self.test_cfg.mode in ['slide', 'whole']
        ori_shape = img_meta[0]['ori_shape']
        assert all(_['ori_shape'] == ori_shape for _ in img_meta)

        if self.test_cfg.mode == 'slide':
            seg_logit = self.slide_inference(img, img_meta, rescale)
        else:
            seg_logit = self.whole_inference(img, img_meta, rescale)
        output = F.softmax(seg_logit, dim=1)

        if img_meta[0]['flip']:
            flip_direction = img_meta[0]['flip_direction']
            assert flip_direction in ['horizontal', 'vertical']
            # horizontal flips mirror the width axis, vertical the height.
            dims_by_direction = {'horizontal': (3, ), 'vertical': (2, )}
            if flip_direction in dims_by_direction:
                output = output.flip(dims=dims_by_direction[flip_direction])

        return output

    def simple_test(self, img, img_meta, rescale=True):
        """Single-scale test; returns one label map per image."""
        seg_pred = self.inference(img, img_meta, rescale).argmax(dim=1)
        if torch.onnx.is_in_onnx_export():
            # our inference backend only support 4D output
            return seg_pred.unsqueeze(0)
        # unravel batch dim into a list of numpy label maps
        return list(seg_pred.cpu().numpy())

    def aug_test(self, imgs, img_metas, rescale=True):
        """Test-time augmentation: average probabilities over all augmented
        views, then take the arg-max. Only ``rescale=True`` is supported."""
        # every view must be mapped back to ori_shape before averaging
        assert rescale
        # accumulate in place to save memory
        seg_logit = self.inference(imgs[0], img_metas[0], rescale)
        for idx, aug_img in enumerate(imgs[1:], start=1):
            seg_logit += self.inference(aug_img, img_metas[idx], rescale)
        seg_logit /= len(imgs)
        seg_pred = seg_logit.argmax(dim=1).cpu().numpy()
        # unravel batch dim
        return list(seg_pred)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/__init__.py
================================================
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
from .res_layer import ResLayer
from .self_attention_block import SelfAttentionBlock
from .up_conv_block import UpConvBlock
__all__ = [
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
'UpConvBlock', 'InvertedResidualV3'
]
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/inverted_residual.py
================================================
from mmcv.cnn import ConvModule
from torch import nn as nn
from torch.utils import checkpoint as cp
from .se_layer import SELayer
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted residual block.

    Layout: optional 1x1 expansion ConvModule (only when
    ``expand_ratio != 1``) -> 3x3 depthwise ConvModule -> 1x1 projection
    ConvModule without activation. An identity shortcut is added when
    ``stride == 1`` and ``in_channels == out_channels``.

    Args:
        in_channels (int): Input channels of the block.
        out_channels (int): Output channels of the block.
        stride (int): Stride of the depthwise 3x3 convolution; must be 1 or 2.
        expand_ratio (int): Hidden width is
            ``int(round(in_channels * expand_ratio))``.
        dilation (int): Dilation (and padding) of the depthwise conv.
            Default: 1
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        with_cp (bool): Use gradient checkpointing (saves memory, slower
            training). Only applied when the input requires grad.
            Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 expand_ratio,
                 dilation=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 with_cp=False):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2], \
            f'stride must in [1, 2]. But received {stride}.'
        self.with_cp = with_cp
        self.use_res_connect = self.stride == 1 and in_channels == out_channels
        hidden_dim = int(round(in_channels * expand_ratio))

        common = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)
        layers = []
        if expand_ratio != 1:
            layers.append(
                ConvModule(
                    in_channels=in_channels,
                    out_channels=hidden_dim,
                    kernel_size=1,
                    act_cfg=act_cfg,
                    **common))
        # depthwise: one group per channel
        layers.append(
            ConvModule(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=3,
                stride=stride,
                padding=dilation,
                dilation=dilation,
                groups=hidden_dim,
                act_cfg=act_cfg,
                **common))
        # linear bottleneck: projection without activation
        layers.append(
            ConvModule(
                in_channels=hidden_dim,
                out_channels=out_channels,
                kernel_size=1,
                act_cfg=None,
                **common))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block, optionally under gradient checkpointing."""

        def _inner_forward(x):
            out = self.conv(x)
            return x + out if self.use_res_connect else out

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)
class InvertedResidualV3(nn.Module):
    """MobileNetV3 inverted residual block.

    Layout: optional 1x1 expansion conv -> depthwise conv -> optional
    squeeze-and-excitation -> 1x1 linear conv (no activation). An identity
    shortcut is added when ``stride == 1`` and ``in_channels ==
    out_channels``.

    Args:
        in_channels (int): Input channels of this module.
        out_channels (int): Output channels of this module.
        mid_channels (int): Input channels of the depthwise convolution.
        kernel_size (int): Kernel size of the depthwise convolution.
            Default: 3.
        stride (int): Stride of the depthwise convolution (1 or 2).
            Default: 1.
        se_cfg (dict): Keyword arguments for ``SELayer``. Default: None,
            which means no SE layer.
        with_expand_conv (bool): Use the expansion conv. If False,
            ``mid_channels`` must equal ``in_channels``. Default: True.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use gradient checkpointing (saves memory, slower
            training). Only applied when the input requires grad.
            Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 stride=1,
                 se_cfg=None,
                 with_expand_conv=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super(InvertedResidualV3, self).__init__()
        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
        assert stride in [1, 2]
        self.with_cp = with_cp
        self.with_se = se_cfg is not None
        self.with_expand_conv = with_expand_conv

        if self.with_se:
            assert isinstance(se_cfg, dict)
        if not self.with_expand_conv:
            assert mid_channels == in_channels

        if self.with_expand_conv:
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

        # Strided depthwise convs switch to mmcv's Conv2dAdaptivePadding
        # (TensorFlow-style "same" padding, per mmcv docs).
        if stride == 2:
            depthwise_conv_cfg = dict(type='Conv2dAdaptivePadding')
        else:
            depthwise_conv_cfg = conv_cfg
        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=mid_channels,
            conv_cfg=depthwise_conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        if self.with_se:
            self.se = SELayer(**se_cfg)

        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, x):
        """Apply the block, optionally under gradient checkpointing."""

        def _inner_forward(x):
            out = self.expand_conv(x) if self.with_expand_conv else x
            out = self.depthwise_conv(out)
            if self.with_se:
                out = self.se(out)
            out = self.linear_conv(out)
            return x + out if self.with_res_shortcut else out

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/make_divisible.py
================================================
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
    """Round a channel count to a nearby multiple of ``divisor``.

    Ported from the TensorFlow MobileNet reference implementation
    (https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py).  # noqa
    The value is rounded to the nearest multiple of ``divisor`` and clamped
    below at ``min_value``. If that rounding shrank it below
    ``min_ratio * value``, one more ``divisor`` is added.

    Args:
        value (int): The original channel number.
        divisor (int): The divisor to fully divide the channel number.
        min_value (int): Lower bound for the result. Default: None, which
            uses ``divisor``.
        min_ratio (float): Minimum allowed ratio of the result to ``value``.
            Default: 0.9.

    Returns:
        int: The adjusted channel number.
    """
    lower_bound = divisor if min_value is None else min_value
    rounded = int(value + divisor / 2) // divisor * divisor
    result = max(lower_bound, rounded)
    # Rounding down must not lose more than (1 - min_ratio) of the channels.
    return result + divisor if result < min_ratio * value else result
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/res_layer.py
================================================
from mmcv.cnn import build_conv_layer, build_norm_layer
from torch import nn as nn
class ResLayer(nn.Sequential):
    """A stage of ``num_blocks`` residual blocks for ResNet-style backbones.

    Only the first block receives ``stride`` and a projection shortcut
    (``downsample``). The shortcut is built when the stride is not 1 or the
    channel count changes.

    Args:
        block (nn.Module): Block class used to build the layer; must expose
            an ``expansion`` attribute.
        inplanes (int): Input channels of the first block.
        planes (int): Base width of each block; outputs have
            ``planes * block.expansion`` channels.
        num_blocks (int): Number of blocks.
        stride (int): Stride of the first block. Default: 1
        dilation (int): Dilation of the blocks. Default: 1
        avg_down (bool): Downsample the shortcut with AvgPool followed by a
            stride-1 1x1 conv instead of a strided 1x1 conv. Default: False
        conv_cfg (dict): Config used to build the shortcut conv.
            Default: None
        norm_cfg (dict): Config used to build the shortcut norm.
            Default: dict(type='BN')
        multi_grid (int | None): Per-block dilation rates (indexed by block
            position); overrides ``dilation`` when given. Default: None
        contract_dilation (bool): Halve the first block's dilation when
            ``dilation > 1`` (ignored with ``multi_grid``). Default: False
    """

    def __init__(self,
                 block,
                 inplanes,
                 planes,
                 num_blocks,
                 stride=1,
                 dilation=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 multi_grid=None,
                 contract_dilation=False,
                 **kwargs):
        self.block = block

        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or inplanes != out_planes:
            shortcut_layers = []
            conv_stride = stride
            if avg_down:
                # Pool first so the projection conv can run at stride 1.
                conv_stride = 1
                shortcut_layers.append(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False))
            shortcut_layers.append(
                build_conv_layer(
                    conv_cfg,
                    inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=conv_stride,
                    bias=False))
            shortcut_layers.append(build_norm_layer(norm_cfg, out_planes)[1])
            downsample = nn.Sequential(*shortcut_layers)

        if multi_grid is not None:
            first_dilation = multi_grid[0]
        elif dilation > 1 and contract_dilation:
            first_dilation = dilation // 2
        else:
            first_dilation = dilation

        layers = [
            block(
                inplanes=inplanes,
                planes=planes,
                stride=stride,
                dilation=first_dilation,
                downsample=downsample,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                **kwargs)
        ]
        for idx in range(1, num_blocks):
            block_dilation = dilation if multi_grid is None else multi_grid[idx]
            layers.append(
                block(
                    inplanes=out_planes,
                    planes=planes,
                    stride=1,
                    dilation=block_dilation,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))
        super(ResLayer, self).__init__(*layers)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/se_layer.py
================================================
import mmcv
import torch.nn as nn
from mmcv.cnn import ConvModule
from .make_divisible import make_divisible
class SELayer(nn.Module):
    """Squeeze-and-Excitation module.

    Global average pooling, then two 1x1 ConvModules (squeeze to
    ``make_divisible(channels // ratio, 8)`` channels, then back to
    ``channels``). The resulting per-channel weights rescale the input.

    Args:
        channels (int): Input (and output) channels of the SE layer.
        ratio (int): Squeeze ratio; the hidden width is
            ``make_divisible(channels // ratio, 8)``. Default: 16.
        conv_cfg (None or dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        act_cfg (dict or Sequence[dict]): Activation config(s). A single dict
            is used for both convs; a pair of dicts configures the first and
            second conv respectively.
            Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
            divisor=6.0)).
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 act_cfg=(dict(type='ReLU'),
                          dict(type='HSigmoid', bias=3.0, divisor=6.0))):
        super(SELayer, self).__init__()
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)

        squeezed_channels = make_divisible(channels // ratio, 8)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=squeezed_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=squeezed_channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        """Rescale ``x`` channel-wise by its excitation weights."""
        weights = self.conv2(self.conv1(self.global_avgpool(x)))
        return x * weights
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/self_attention_block.py
================================================
import torch
from mmcv.cnn import ConvModule, constant_init
from torch import nn as nn
from torch.nn import functional as F
class SelfAttentionBlock(nn.Module):
    """General self-attention block/non-local block.

    Please refer to https://arxiv.org/abs/1706.03762 for details about key,
    query and value.

    Args:
        key_in_channels (int): Input channels of key feature.
        query_in_channels (int): Input channels of query feature.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether share projection weight between key
            and query projection. Requires
            ``key_in_channels == query_in_channels``.
        query_downsample (nn.Module | None): Query downsample module, applied
            to the projected query.
        key_downsample (nn.Module | None): Key downsample module, applied to
            both the projected key and the projected value.
        key_query_num_convs (int): Number of convs for key/query projection.
        value_out_num_convs (int): Number of convs for the value projection
            and for the out projection.
        key_query_norm (bool): If True, the key/query projections are built
            from ``ConvModule`` (which receives ``conv_cfg``, ``norm_cfg``
            and ``act_cfg``); otherwise from plain ``nn.Conv2d``.
        value_out_norm (bool): Same as ``key_query_norm``, for the value and
            out projections.
        matmul_norm (bool): Whether normalize attention map with sqrt of
            channels.
        with_out (bool): Whether use out projection. When False the value
            projection maps directly to ``out_channels``.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, key_in_channels, query_in_channels, channels,
                 out_channels, share_key_query, query_downsample,
                 key_downsample, key_query_num_convs, value_out_num_convs,
                 key_query_norm, value_out_norm, matmul_norm, with_out,
                 conv_cfg, norm_cfg, act_cfg):
        super(SelfAttentionBlock, self).__init__()
        if share_key_query:
            assert key_in_channels == query_in_channels
        self.key_in_channels = key_in_channels
        self.query_in_channels = query_in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.share_key_query = share_key_query
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.key_project = self.build_project(
            key_in_channels,
            channels,
            num_convs=key_query_num_convs,
            use_conv_module=key_query_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if share_key_query:
            # Same module object: key and query share their weights.
            self.query_project = self.key_project
        else:
            self.query_project = self.build_project(
                query_in_channels,
                channels,
                num_convs=key_query_num_convs,
                use_conv_module=key_query_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        # Without an out projection the value already has to produce the
        # final number of channels.
        self.value_project = self.build_project(
            key_in_channels,
            channels if with_out else out_channels,
            num_convs=value_out_num_convs,
            use_conv_module=value_out_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if with_out:
            self.out_project = self.build_project(
                channels,
                out_channels,
                num_convs=value_out_num_convs,
                use_conv_module=value_out_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.out_project = None

        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm

        self.init_weights()

    def init_weights(self):
        """Initialize weight of later layer.

        Applies mmcv ``constant_init(..., 0)`` to the out projection when one
        exists and it is not a ``ConvModule``.
        """
        if self.out_project is not None:
            if not isinstance(self.out_project, ConvModule):
                constant_init(self.out_project, 0)

    def build_project(self, in_channels, channels, num_convs, use_conv_module,
                      conv_cfg, norm_cfg, act_cfg):
        """Build projection layer for key/query/value/out.

        Returns a stack of ``num_convs`` 1x1 convolutions (``ConvModule`` when
        ``use_conv_module`` is true, else ``nn.Conv2d``); the first maps
        ``in_channels -> channels`` and the rest ``channels -> channels``.
        A single conv is returned bare, several are wrapped in
        ``nn.Sequential``.
        """
        if use_conv_module:
            convs = [
                ConvModule(
                    in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            ]
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        channels,
                        channels,
                        1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
        else:
            convs = [nn.Conv2d(in_channels, channels, 1)]
            for _ in range(num_convs - 1):
                convs.append(nn.Conv2d(channels, channels, 1))
        if len(convs) > 1:
            convs = nn.Sequential(*convs)
        else:
            convs = convs[0]
        return convs

    def forward(self, query_feats, key_feats):
        """Forward function.

        Args:
            query_feats (Tensor): Query feature map; assumed 4-D
                (N, C_q, H, W) since it goes through 2-D convolutions.
            key_feats (Tensor): Key/value feature map, (N, C_k, H', W').

        Returns:
            Tensor: Context of shape (N, C_out, h, w), where (h, w) is the
            spatial size of the projected (and, if configured, downsampled)
            query.
        """
        batch_size = query_feats.size(0)
        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        # The context has one position per *query* position, so it must be
        # reshaped to the query's spatial size after downsampling, not to
        # the size of ``query_feats``.
        query_spatial_shape = query.shape[2:]
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()

        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        key = key.reshape(*key.shape[:2], -1)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()

        # sim_map: (N, query positions, key positions)
        sim_map = torch.matmul(query, key)
        if self.matmul_norm:
            sim_map = (self.channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_spatial_shape)
        if self.out_project is not None:
            context = self.out_project(context)
        return context
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/up_conv_block.py
================================================
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_upsample_layer
class UpConvBlock(nn.Module):
    """Upsample convolution block in decoder for UNet.

    This upsample convolution block consists of one upsample module
    followed by one convolution block. The upsample module expands the
    high-level low-resolution feature map and the convolution block fuses
    the upsampled high-level low-resolution feature map and the low-level
    high-resolution feature map from encoder.

    Args:
        conv_block (type | callable): Constructor of the fusion block. It is
            called with ``in_channels=2 * skip_channels`` plus
            ``out_channels``, ``num_convs``, ``stride``, ``dilation``,
            ``with_cp``, ``conv_cfg``, ``norm_cfg``, ``act_cfg``,
            ``dcn=None`` and ``plugins=None``.
        in_channels (int): Number of input channels of the high-level
            low-resolution feature map from the decoder.
        skip_channels (int): Number of input channels of the low-level
            high-resolution feature map from encoder.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers in the conv_block.
            Default: 2.
        stride (int): Stride of convolutional layer in conv_block. Default: 1.
        dilation (int): Dilation rate of convolutional layer in conv_block.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict | None): The upsample config of the upsample module
            in decoder. Default: dict(type='InterpConv'). If the size of
            high-level feature map is the same as that of skip feature map
            (low-level feature map from encoder), it does not need upsample
            the high-level feature map and the upsample_cfg is None; a 1x1
            ``ConvModule`` mapping ``in_channels -> skip_channels`` is used
            instead.
        dcn (dict | None): Deformable convolution config. Must be None
            (not implemented yet). Default: None.
        plugins (dict | None): Plugins for convolutional layers. Must be None
            (not implemented yet). Default: None.
    """

    def __init__(self,
                 conv_block,
                 in_channels,
                 skip_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 dcn=None,
                 plugins=None):
        super(UpConvBlock, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        # The fusion block consumes ``skip`` concatenated with the upsampled
        # ``x``; both carry ``skip_channels`` channels.
        self.conv_block = conv_block(
            in_channels=2 * skip_channels,
            out_channels=out_channels,
            num_convs=num_convs,
            stride=stride,
            dilation=dilation,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            dcn=None,
            plugins=None)
        if upsample_cfg is not None:
            self.upsample = build_upsample_layer(
                cfg=upsample_cfg,
                in_channels=in_channels,
                out_channels=skip_channels,
                with_cp=with_cp,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.upsample = ConvModule(
                in_channels,
                skip_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def forward(self, skip, x):
        """Forward function.

        Args:
            skip (Tensor): Low-level high-resolution feature map from the
                encoder.
            x (Tensor): High-level low-resolution feature map.

        Returns:
            Tensor: ``conv_block`` applied to the channel-wise concatenation
            of ``skip`` and the upsampled ``x`` (``skip`` first).
        """
        x = self.upsample(x)
        out = torch.cat([skip, x], dim=1)
        out = self.conv_block(out)

        return out
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/ops/__init__.py
================================================
from .encoding import Encoding
from .wrappers import Upsample, resize
__all__ = ['Upsample', 'resize', 'Encoding']
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/ops/encoding.py
================================================
import torch
from torch import nn as nn
from torch.nn import functional as F
class Encoding(nn.Module):
    """Encoding Layer: a learnable residual encoder.

    Input is of shape (batch_size, channels, height, width).
    Output is of shape (batch_size, num_codes, channels).

    Args:
        channels: dimension of the features or feature channels
        num_codes: number of code words
    """

    def __init__(self, channels, num_codes):
        super(Encoding, self).__init__()
        # init codewords and smoothing factor
        self.channels, self.num_codes = channels, num_codes
        std = 1. / ((num_codes * channels)**0.5)
        # [num_codes, channels]
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std),
            requires_grad=True)
        # [num_codes]
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)

    @staticmethod
    def scaled_l2(x, codewords, scale):
        """Scaled squared distance of every position to every codeword.

        Args:
            x (Tensor): Features, (batch_size, num_positions, channels).
            codewords (Tensor): (num_codes, channels).
            scale (Tensor): Per-code scale, (num_codes,).

        Returns:
            Tensor: (batch_size, num_positions, num_codes) holding
            ``scale[k] * ||x[b, i] - codewords[k]||^2``.
        """
        num_codes, channels = codewords.size()
        batch_size = x.size(0)
        reshaped_scale = scale.view((1, 1, num_codes))
        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))

        scaled_l2_norm = reshaped_scale * (
            expanded_x - reshaped_codewords).pow(2).sum(dim=3)
        return scaled_l2_norm

    @staticmethod
    def aggregate(assigment_weights, x, codewords):
        """Weighted sum of residuals ``x - codeword`` over positions.

        Args:
            assigment_weights (Tensor): (batch_size, num_positions, num_codes).
            x (Tensor): (batch_size, num_positions, channels).
            codewords (Tensor): (num_codes, channels).

        Returns:
            Tensor: (batch_size, num_codes, channels).
        """
        num_codes, channels = codewords.size()
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
        batch_size = x.size(0)

        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        encoded_feat = (assigment_weights.unsqueeze(3) *
                        (expanded_x - reshaped_codewords)).sum(dim=1)
        return encoded_feat

    def forward(self, x):
        """Encode a (batch_size, channels, H, W) map into
        (batch_size, num_codes, channels) residual descriptors."""
        assert x.dim() == 4 and x.size(1) == self.channels
        # [batch_size, channels, height, width]
        batch_size = x.size(0)
        # [batch_size, height x width, channels]
        # ``reshape`` (not ``view``) so non-contiguous inputs, e.g. after a
        # permute/transpose, are accepted as well.
        x = x.reshape(batch_size, self.channels,
                      -1).transpose(1, 2).contiguous()
        # assignment_weights: [batch_size, height x width, num_codes],
        # softmax over the codes for every position
        assigment_weights = F.softmax(
            self.scaled_l2(x, self.codewords, self.scale), dim=2)
        # aggregate
        encoded_feat = self.aggregate(assigment_weights, x, self.codewords)
        return encoded_feat

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
                    f'x{self.channels})'
        return repr_str
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/ops/wrappers.py
================================================
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Resize ``input`` with :func:`torch.nn.functional.interpolate`.

    When ``warning`` is true, ``size`` is given and ``align_corners`` is
    truthy, a warning is emitted if the call upsamples along at least one
    spatial axis and, on both axes, ``(out - 1)`` is not a multiple of
    ``(in - 1)`` (the grid does not line up with the corners).

    Args:
        input (Tensor): Tensor to resize. The warning check unpacks
            ``input.shape[2:]`` into (H, W), so it expects a 4-D tensor.
        size (tuple[int] | torch.Size | None): Output spatial size.
        scale_factor (float | tuple[float] | None): Passed to interpolate.
        mode (str): Interpolation mode. Default: 'nearest'.
        align_corners (bool | None): Passed to interpolate.
        warning (bool): Whether to run the alignment check. Default: True.

    Returns:
        Tensor: The resized tensor.
    """
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            # Compare each output dimension against the matching *input*
            # dimension: the check only concerns upsampling.
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    if isinstance(size, torch.Size):
        size = tuple(int(x) for x in size)
    return F.interpolate(input, size, scale_factor, mode, align_corners)
class Upsample(nn.Module):
    """``nn.Module`` wrapper around ``resize``.

    Args:
        size (tuple[int] | None): Fixed output size. When truthy it takes
            precedence over ``scale_factor``.
        scale_factor (float | tuple[float] | None): Multiplier for the last
            two dimensions of the input. A tuple gives one factor per
            dimension (height, width); a scalar is used for both.
        mode (str): Interpolation mode passed to ``resize``.
        align_corners (bool | None): Passed to ``resize``.
    """

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 mode='nearest',
                 align_corners=None):
        super(Upsample, self).__init__()
        self.size = size
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        if not self.size:
            factors = self.scale_factor
            # A scalar factor applies to both spatial dims; a tuple (as
            # accepted by ``__init__``) is applied per dimension instead of
            # attempting ``int * tuple``.
            if not isinstance(factors, tuple):
                factors = (factors, factors)
            size = [int(t * f) for t, f in zip(x.shape[-2:], factors)]
        else:
            size = self.size
        return resize(x, size, None, self.mode, self.align_corners)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/utils/__init__.py
================================================
from .collect_env import collect_env
from .logger import get_root_logger
__all__ = ['get_root_logger', 'collect_env']
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/utils/collect_env.py
================================================
from mmcv.utils import collect_env as collect_base_env
from mmcv.utils import get_git_hash
import mmseg
def collect_env():
    """Gather details about the runtime environment.

    Returns:
        dict: The base environment info from mmcv, extended with an
        ``'MMSegmentation'`` entry of the form ``<version>+<short git hash>``.
    """
    info = collect_base_env()
    short_hash = get_git_hash()[:7]
    info['MMSegmentation'] = f'{mmseg.__version__}+{short_hash}'
    return info


if __name__ == '__main__':
    for key, value in collect_env().items():
        print(f'{key}: {value}')
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/utils/logger.py
================================================
import logging
from mmcv.utils import get_logger
def get_root_logger(log_file=None, log_level=logging.INFO):
    """Return the package-level ``'mmseg'`` logger.

    This is a thin wrapper over mmcv's ``get_logger`` with the logger name
    fixed to ``'mmseg'``. Per the mmcv helper (not verified here), the logger
    is initialised on first use with a StreamHandler, plus a FileHandler when
    ``log_file`` is given. Only the rank-0 process uses ``log_level``; other
    processes are quieted to "Error".

    Args:
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level.

    Returns:
        logging.Logger: The root logger.
    """
    return get_logger(name='mmseg', log_file=log_file, log_level=log_level)
================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/version.py
================================================
# Copyright (c) Open-MMLab. All rights reserved.
__version__ = '0.11.0'


def parse_version_info(version_str):
    """Split a dotted version string into a tuple of components.

    Each purely numeric part becomes an ``int``. A part containing ``'rc'``
    (e.g. ``'0rc1'``) contributes the number before ``'rc'`` as an ``int``,
    followed by the string ``'rc<N>'``. Any other part is dropped.

    Example:
        >>> parse_version_info('1.2.0rc3')
        (1, 2, 0, 'rc3')
    """
    components = []
    for token in version_str.split('.'):
        if token.isdigit():
            components.append(int(token))
            continue
        if 'rc' in token:
            pieces = token.split('rc')
            components.append(int(pieces[0]))
            components.append(f'rc{pieces[1]}')
    return tuple(components)


version_info = parse_version_info(__version__)
================================================
FILE: downstream_tasks/semantic_segmentation/tools/dist_test.sh
================================================
#!/usr/bin/env bash
# Usage: dist_test.sh CONFIG CHECKPOINT GPUS [extra tools/test.py args...]
# The master port defaults to 7956 and can be overridden via $PORT.

CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=${PORT:-7956}

# Quote every expansion so paths/arguments containing spaces survive intact.
OMP_NUM_THREADS=1 python -m torch.distributed.launch \
    --nproc_per_node="$GPUS" \
    --master_port="$PORT" \
    tools/test.py "$CONFIG" "$CHECKPOINT" --launcher pytorch "${@:4}"
================================================
FILE: downstream_tasks/semantic_segmentation/tools/dist_train.sh
================================================
#!/usr/bin/env bash
# Usage: dist_train.sh CONFIG GPUS [extra tools/train.py args...]
# The master port defaults to 7956 and can be overridden via $PORT.

CONFIG=$1
GPUS=$2
PORT=${PORT:-7956}

# Quote every expansion so paths/arguments containing spaces survive intact.
OMP_NUM_THREADS=1 python -m torch.distributed.launch \
    --nproc_per_node="$GPUS" \
    --master_port="$PORT" \
    tools/train.py "$CONFIG" --launcher pytorch "${@:3}"
================================================
FILE: downstream_tasks/semantic_segmentation/tools/test.py
================================================
import argparse
import os
import mmcv
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmcv.utils import DictAction
from mmseg.apis import multi_gpu_test, single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor
from backbone import beit
from backbone import cae
def parse_args():
    """Parse the command line of the test/eval tool.

    Side effect: if ``LOCAL_RANK`` is not already in the environment it is
    set from ``--local_rank``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='mmseg test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--aug-test', action='store_true', help='Use Flip and Multi scale aug')
    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--format-only',
        action='store_true',
        # Note the trailing space: adjacent literals are concatenated.
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
        ' for generic datasets, and "cityscapes" for Cityscapes')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu_collect is not specified')
    parser.add_argument(
        '--options', nargs='+', action=DictAction, help='custom options')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def main():
    """Run inference on the test split described by a config.

    At least one of ``--out`` (pickle dump), ``--eval``, ``--format-only``,
    ``--show`` or ``--show-dir`` must be requested. Only rank 0 dumps,
    formats and evaluates the collected results.

    Raises:
        ValueError: if ``--eval`` and ``--format-only`` are both given, or if
            ``--out`` does not end in ``.pkl``/``.pickle``.
    """
    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        # Name the flag the way the CLI spells it.
        raise ValueError('--eval and --format-only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = mmcv.Config.fromfile(args.config)
    if args.options is not None:
        cfg.merge_from_dict(args.options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    if args.aug_test:
        # hard code index
        cfg.data.test.pipeline[1].img_ratios = [
            0.5, 0.75, 1.0, 1.25, 1.5, 1.75
        ]
        cfg.data.test.pipeline[1].flip = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # Checkpoints not written by tools/train.py (which stores CLASSES and
    # PALETTE in checkpoint meta) may lack these entries; fall back to the
    # dataset's values instead of failing with a KeyError.
    ckpt_meta = checkpoint.get('meta', {})
    if 'CLASSES' in ckpt_meta:
        model.CLASSES = ckpt_meta['CLASSES']
    else:
        print('"CLASSES" not found in meta, use dataset.CLASSES instead')
        model.CLASSES = dataset.CLASSES
    if 'PALETTE' in ckpt_meta:
        model.PALETTE = ckpt_meta['PALETTE']
    else:
        print('"PALETTE" not found in meta, use dataset.PALETTE instead')
        model.PALETTE = dataset.PALETTE

    efficient_test = False
    if args.eval_options is not None:
        efficient_test = args.eval_options.get('efficient_test', False)

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
                                  efficient_test)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect, efficient_test)

    rank, _ = get_dist_info()
    if rank == 0:
        if args.out:
            print(f'\nwriting results to {args.out}')
            mmcv.dump(outputs, args.out)
        kwargs = {} if args.eval_options is None else args.eval_options
        if args.format_only:
            dataset.format_results(outputs, **kwargs)
        if args.eval:
            dataset.evaluate(outputs, args.eval, **kwargs)


if __name__ == '__main__':
    main()
================================================
FILE: downstream_tasks/semantic_segmentation/tools/train.py
================================================
import argparse
import copy
import os
import os.path as osp
import time
import mmcv
import mmcv_custom
import torch
from mmcv.runner import init_dist
from mmcv.utils import Config, DictAction, get_git_hash
from mmseg import __version__
from mmseg.apis import set_random_seed
from mmcv_custom import train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_root_logger
from backbone import beit
from backbone import mae
from backbone import beit_fapn
from backbone import cae
def parse_args():
    """Build the training CLI and parse ``sys.argv``.

    Side effect: when ``LOCAL_RANK`` is absent from the environment it is
    filled in from ``--local_rank``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Train a segmentor')
    parser.add_argument('config', help='train config file path')

    # Plain string options, registered in their original order.
    path_options = (
        ('--work-dir', 'the dir to save logs and models'),
        ('--load-from', 'the checkpoint file to load weights from'),
        ('--resume-from', 'the checkpoint file to resume from'),
    )
    for flag, help_text in path_options:
        parser.add_argument(flag, help=help_text)

    parser.add_argument(
        '--no-validate', action='store_true',
        help='whether not to evaluate the checkpoint during training')

    non_dist_note = '(only applicable to non-distributed training)'
    gpu_choice = parser.add_mutually_exclusive_group()
    gpu_choice.add_argument(
        '--gpus', type=int, help='number of gpus to use ' + non_dist_note)
    gpu_choice.add_argument(
        '--gpu-ids', type=int, nargs='+',
        help='ids of gpus to use ' + non_dist_note)

    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic', action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options', nargs='+', action=DictAction, help='custom options')
    parser.add_argument(
        '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)

    parsed = parser.parse_args()
    os.environ.setdefault('LOCAL_RANK', str(parsed.local_rank))
    return parsed
def main():
    """Train a segmentor described by a config file.

    Steps, in order: parse the CLI, load the config and apply ``--options``,
    resolve work dir / checkpoint paths / GPU ids, initialise distributed
    mode when a launcher is given, create the work dir and dump the config
    into it, set up a timestamped log file, log environment and config info,
    optionally seed the RNGs, build the model and datasets, store version,
    config text, CLASSES and PALETTE in the checkpoint meta, and finally call
    ``train_segmentor``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # CLI --options override values loaded from the config file.
    if args.options is not None:
        cfg.merge_from_dict(args.options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.load_from is not None:
        cfg.load_from = args.load_from
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        # Default to a single GPU (id 0), or ids 0..gpus-1 with --gpus.
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, deterministic: '
                    f'{args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    # Recorded unconditionally, so these are None when no seed was given.
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_segmentor(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))

    logger.info(model)

    datasets = [build_dataset(cfg.data.train)]
    # A two-stage workflow also needs a validation dataset; it reuses the
    # training pipeline.
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmseg version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES,
            PALETTE=datasets[0].PALETTE)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_segmentor(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()
================================================
FILE: furnace/dataset_folder.py
================================================
from torchvision.datasets.vision import VisionDataset
from PIL import Image
import os
import os.path
import random
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Tell whether ``filename`` ends with one of ``extensions``.

    Only ``filename`` is lower-cased before the comparison, so
    ``extensions`` should already be lower case.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    normalized = filename.lower()
    return normalized.endswith(extensions)
def is_image_file(filename: str) -> bool:
    """Tell whether ``filename`` carries one of the ``IMG_EXTENSIONS``
    suffixes (case-insensitive on the filename).

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    allowed = IMG_EXTENSIONS
    return has_file_allowed_extension(filename, allowed)
def make_dataset(
    directory: str,
    class_to_idx: Dict[str, int],
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """List ``(path, class_index)`` pairs for every accepted file.

    Exactly one of ``extensions`` / ``is_valid_file`` must be supplied; with
    ``extensions`` a file is accepted when its name ends with one of them.
    Class folders missing from ``directory`` are skipped, symlinked
    directories are followed, and the result is ordered by class name, then
    walk root, then file name.

    Raises:
        ValueError: if both or neither of ``extensions`` and
            ``is_valid_file`` are given.
    """
    directory = os.path.expanduser(directory)
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    collected = []
    for class_name in sorted(class_to_idx):
        class_dir = os.path.join(directory, class_name)
        if not os.path.isdir(class_dir):
            continue
        label = class_to_idx[class_name]
        for walk_root, _, file_names in sorted(os.walk(class_dir, followlinks=True)):
            candidate_paths = (os.path.join(walk_root, name) for name in sorted(file_names))
            collected.extend((p, label) for p in candidate_paths if is_valid_file(p))
    return collected
class DatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
            self,
            root: str,
            loader: Callable[[str], Any],
            extensions: Optional[Tuple[str, ...]] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super(DatasetFolder, self).__init__(root, transform=transform,
                                            target_transform=target_transform)
        classes, class_to_idx = self._find_classes(self.root)
        samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        if len(samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        # Close the scandir iterator deterministically instead of leaving the
        # directory handle to the garbage collector.
        with os.scandir(dir) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        while True:
            # Look the sample up outside the ``try``: an out-of-range index
            # must raise IndexError (sequence protocol) rather than being
            # swallowed and silently replaced by a random index.
            path, target = self.samples[index]
            try:
                sample = self.loader(path)
                break
            except Exception as e:
                # Best-effort skip of unreadable files: report the error and
                # retry with a randomly chosen sample.
                print(e)
                index = random.randint(0, len(self.samples) - 1)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)
# Lower-case file suffixes used by is_image_file and ImageFolder.
IMG_EXTENSIONS = (
    '.jpg', '.jpeg', '.png', '.ppm', '.bmp',
    '.pgm', '.tif', '.tiff', '.webp',
)
def pil_loader(path: str) -> Image.Image:
    """Open ``path`` with PIL and return the image converted to RGB.

    The file is opened explicitly and converted while the handle is still
    open, to avoid the ResourceWarning described in
    https://github.com/python-pillow/Pillow/issues/835.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    """Open ``path`` with accimage, falling back to ``pil_loader`` when
    accimage raises ``IOError``."""
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        image = pil_loader(path)
    return image
def default_loader(path: str) -> Any:
    """Load ``path`` with the loader matching torchvision's image backend:
    ``accimage_loader`` for ``'accimage'``, ``pil_loader`` otherwise."""
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if use_accimage else pil_loader(path)
class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files).
            When omitted, files are filtered by ``IMG_EXTENSIONS`` instead.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples; the same list
            object as ``samples``.
    """

    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            loader: Callable[[str], Any] = default_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # ``extensions`` and ``is_valid_file`` must not both be set, so the
        # image suffix filter is used only when no custom check is given.
        extensions = None if is_valid_file is not None else IMG_EXTENSIONS
        super(ImageFolder, self).__init__(root, loader, extensions,
                                          transform=transform,
                                          target_transform=target_transform,
                                          is_valid_file=is_valid_file)
        self.imgs = self.samples
================================================
FILE: furnace/datasets.py
================================================
import os
import torch
from torchvision import datasets, transforms
from timm.data.constants import \
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from furnace.transforms import RandomResizedCropAndInterpolationWithTwoPic
from timm.data import create_transform
from dall_e.utils import map_pixels
from furnace.masking_generator import MaskingGenerator, RandomMaskingGenerator
from furnace.dataset_folder import ImageFolder
def preprocess_vqgan(x):
    """Apply the affine map ``2 * x - 1`` (0 -> -1, 0.5 -> 0, 1 -> 1)."""
    return x * 2. - 1.
class DataAugmentationForCAE(object):
    """Pre-training augmentation producing the three inputs CAE needs per image.

    Calling an instance on an image returns
    ``(patch_input, visual_token_input, mask)``:

    * ``patch_input`` -- the crop at ``args.input_size``, converted to a tensor
      and normalized with ImageNet default or Inception mean/std
      (``args.imagenet_default_mean_and_std``).
    * ``visual_token_input`` -- the same crop window resized to
      ``args.second_input_size`` and preprocessed for the tokenizer chosen by
      ``args.discrete_vae_type``.
    * ``mask`` -- output of the generator chosen by ``args.mask_generator``.

    Raises:
        NotImplementedError: for an unknown ``discrete_vae_type`` or
            ``mask_generator``.
    """

    def __init__(self, args):
        if args.imagenet_default_mean_and_std:
            mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
        else:
            mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

        # Shared augmentation: optional colour jitter, flip, then one crop window
        # resized to two sizes (requires args.second_input_size to be set so the
        # crop returns a pair -- __call__ unpacks two images).
        common_ops = []
        if args.color_jitter > 0:
            common_ops.append(
                transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter))
        common_ops.append(transforms.RandomHorizontalFlip(p=0.5))
        common_ops.append(
            RandomResizedCropAndInterpolationWithTwoPic(
                size=args.input_size, second_size=args.second_input_size,
                interpolation=args.train_interpolation, second_interpolation=args.second_interpolation,
                scale=(args.crop_min_size, args.crop_max_size),
            ))
        self.common_transform = transforms.Compose(common_ops)

        self.patch_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ])

        if args.discrete_vae_type == "dall-e":
            self.visual_token_transform = transforms.Compose([
                transforms.ToTensor(),
                map_pixels,
            ])
        elif args.discrete_vae_type == "vqgan_gumbel_f8_8192":
            self.visual_token_transform = transforms.Compose([
                transforms.ToTensor(),
                preprocess_vqgan,
            ])
        elif args.discrete_vae_type == "customized":
            self.visual_token_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=IMAGENET_INCEPTION_MEAN,
                    std=IMAGENET_INCEPTION_STD,
                ),
            ])
        else:
            raise NotImplementedError(
                "Unsupported discrete_vae_type: {!r}".format(args.discrete_vae_type))

        if args.mask_generator == 'block':
            self.masked_position_generator = MaskingGenerator(
                args.window_size, num_masking_patches=args.num_mask_patches,
                max_num_patches=args.max_mask_patches_per_block,
                min_num_patches=args.min_mask_patches_per_block,
            )
        elif args.mask_generator == 'random':
            self.masked_position_generator = RandomMaskingGenerator(
                args.window_size, ratio_masking_patches=args.ratio_mask_patches
            )
        else:
            # Previously fell through silently and failed later with an
            # AttributeError on first use.
            raise NotImplementedError(
                "Unsupported mask_generator: {!r}".format(args.mask_generator))

    def __call__(self, image):
        for_patches, for_visual_tokens = self.common_transform(image)
        return \
            self.patch_transform(for_patches), self.visual_token_transform(for_visual_tokens), \
            self.masked_position_generator()

    def __repr__(self):
        text = "(DataAugmentationForCAE,\n"
        text += " common_transform = %s,\n" % str(self.common_transform)
        text += " patch_transform = %s,\n" % str(self.patch_transform)
        text += " visual_tokens_transform = %s,\n" % str(self.visual_token_transform)
        text += " Masked position generator = %s,\n" % str(self.masked_position_generator)
        text += ")"
        return text
def build_cae_pretraining_dataset(args):
    """Build the image-folder dataset for CAE pre-training.

    Every sample is passed through ``DataAugmentationForCAE(args)``; the
    augmentation pipeline is printed once for logging.
    """
    augmentation = DataAugmentationForCAE(args)
    print("Data Aug = %s" % str(augmentation))
    dataset = ImageFolder(args.data_path, transform=augmentation)
    return dataset
def build_dataset(is_train, args):
    """Build the labelled train/eval dataset selected by ``args.data_set``.

    Supported values: ``'CIFAR'`` (CIFAR-100), ``'IMNET'`` (``train``/``val``
    sub-folders of ``args.data_path``) and ``'image_folder'``
    (``args.data_path`` for training, ``args.eval_data_path`` for eval).
    The transform pipeline is printed first.

    Returns:
        ``(dataset, nb_classes)``.

    Raises:
        NotImplementedError: for any other ``args.data_set``.
        AssertionError: if the class count disagrees with ``args.nb_classes``.
    """
    transform = build_transform(is_train, args)

    print("Transform = ")
    if not isinstance(transform, tuple):
        for op in transform.transforms:
            print(op)
    else:
        for pipeline in transform:
            print(" - - - - - - - - - - ")
            for op in pipeline.transforms:
                print(op)
    print("---------------------------")

    kind = args.data_set
    if kind == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif kind == 'IMNET':
        split_dir = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(split_dir, transform=transform)
        nb_classes = 1000
    elif kind == "image_folder":
        folder = args.data_path if is_train else args.eval_data_path
        dataset = ImageFolder(folder, transform=transform)
        nb_classes = args.nb_classes
        assert len(dataset.class_to_idx) == nb_classes
    else:
        raise NotImplementedError()

    assert nb_classes == args.nb_classes
    print("Number of the class = %d" % args.nb_classes)
    return dataset, nb_classes
def build_transform(is_train, args):
    """Return the train or eval image transform for fine-tuning.

    Training uses timm's ``create_transform`` (colour jitter, auto-augment,
    random erasing per ``args``). For ``input_size <= 32`` its first op is
    swapped for a padded ``RandomCrop``.

    Evaluation resizes to ``input_size / crop_pct`` (bicubic, ``interpolation=3``)
    and center-crops to ``input_size``; both steps are skipped for
    ``input_size <= 32``. Then ToTensor + Normalize. If ``args.crop_pct`` is
    None, it is filled in on ``args`` (224/256 below 384 px, else 1.0).
    """
    resize_im = args.input_size > 32
    if args.imagenet_default_mean_and_std:
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    else:
        mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        train_tf = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with RandomCrop
            train_tf.transforms[0] = transforms.RandomCrop(args.input_size, padding=4)
        return train_tf

    eval_ops = []
    if resize_im:
        if args.crop_pct is None:
            args.crop_pct = 224 / 256 if args.input_size < 384 else 1.0
        resize_to = int(args.input_size / args.crop_pct)
        # to maintain same ratio w.r.t. 224 images
        eval_ops.append(transforms.Resize(resize_to, interpolation=3))
        eval_ops.append(transforms.CenterCrop(args.input_size))
    eval_ops.append(transforms.ToTensor())
    eval_ops.append(transforms.Normalize(mean, std))
    return transforms.Compose(eval_ops)
================================================
FILE: furnace/engine_for_finetuning.py
================================================
import math
import sys
import time
from typing import Iterable, Optional
import torch
from timm.data import Mixup
from timm.utils import accuracy, ModelEma
import furnace.utils as utils
def train_class_batch(model, samples, target, criterion):
    """Forward ``samples`` through ``model`` and score the result with ``criterion``.

    Returns:
        ``(loss, outputs)`` -- the raw outputs are returned so callers can
        compute accuracy from them.
    """
    logits = model(samples)
    return criterion(logits, target), logits
def get_loss_scale_for_deepspeed(model):
    """Return the current loss scale of ``model.optimizer``.

    Prefers the optimizer's ``loss_scale`` attribute and falls back to
    ``cur_scale`` when it is absent (presumably covering different DeepSpeed
    optimizer wrappers -- confirm).
    """
    opt = model.optimizer
    if hasattr(opt, "loss_scale"):
        return opt.loss_scale
    return opt.cur_scale
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
                    start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
                    num_training_steps_per_epoch=None, update_freq=None):
    """Run one epoch of supervised fine-tuning with gradient accumulation.

    Two execution modes are supported:

    * ``loss_scaler is None`` -- inputs are cast to half precision and the
      model drives its own ``backward``/``step`` (DeepSpeed-style engine, per
      the comments below).
    * otherwise -- the forward runs under ``torch.cuda.amp.autocast`` and
      ``loss_scaler`` performs backward/clipping, updating weights only on the
      last micro-batch of each accumulation window.

    Args:
        start_steps: global index of this epoch's first optimizer step, used
            to index the per-step schedules (``None`` -> 0).
        lr_schedule_values / wd_schedule_values: optional per-step schedules.
            LR is multiplied by a group's ``lr_scale`` when present; WD is only
            written to groups whose current WD is > 0.
        num_training_steps_per_epoch: optimizer steps to run; batches past
            this are skipped (``None`` -> no cap).
        update_freq: micro-batches per optimizer step (``None`` -> 1).

    Returns:
        dict mapping each logged metric to its global average.
    """
    # These previously crashed (``x // None``, ``None + step``, ``int >= None``)
    # when left at their defaults.
    if update_freq is None:
        update_freq = 1
    if start_steps is None:
        start_steps = 0

    model.train(True)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    if loss_scaler is None:
        model.zero_grad()
        model.micro_steps = 0
    else:
        optimizer.zero_grad()

    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        step = data_iter_step // update_freq
        if num_training_steps_per_epoch is not None and step >= num_training_steps_per_epoch:
            continue
        it = start_steps + step  # global training iteration
        is_first_micro_batch = data_iter_step % update_freq == 0
        is_last_micro_batch = (data_iter_step + 1) % update_freq == 0

        # Update LR & WD on the first micro-batch of each accumulation window.
        # Parenthesised on purpose: ``a or b and c`` parses as ``a or (b and c)``,
        # which re-applied the (identical, since ``it`` is constant within a
        # window) schedule on every micro-batch whenever an LR schedule was set.
        if (lr_schedule_values is not None or wd_schedule_values is not None) and is_first_micro_batch:
            for param_group in optimizer.param_groups:
                if lr_schedule_values is not None:
                    if "lr_scale" in param_group:
                        param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                    else:
                        param_group["lr"] = lr_schedule_values[it]
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if loss_scaler is None:
            samples = samples.half()
            loss, output = train_class_batch(
                model, samples, targets, criterion)
        else:
            with torch.cuda.amp.autocast():
                loss, output = train_class_batch(
                    model, samples, targets, criterion)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        if loss_scaler is None:
            loss /= update_freq
            model.backward(loss)
            model.step()

            if is_last_micro_batch:
                # model.zero_grad()
                # Deepspeed will call step() & model.zero_grad() automatic
                if model_ema is not None:
                    model_ema.update(model)
            grad_norm = None
            loss_scale_value = get_loss_scale_for_deepspeed(model)
        else:
            # this attribute is added by timm on one optimizer (adahessian)
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss /= update_freq
            grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                    parameters=model.parameters(), create_graph=is_second_order,
                                    update_grad=is_last_micro_batch)
            if is_last_micro_batch:
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)
            loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()

        # With mixup the targets are soft, so argmax accuracy is not logged.
        if mixup_fn is None:
            class_acc = (output.max(-1)[-1] == targets).float().mean()
        else:
            class_acc = None
        metric_logger.update(loss=loss_value)
        metric_logger.update(class_acc=class_acc)
        metric_logger.update(loss_scale=loss_scale_value)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)
        metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(class_acc=class_acc, head="loss")
            log_writer.update(loss_scale=loss_scale_value, head="opt")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            log_writer.update(grad_norm=grad_norm, head="opt")
            log_writer.set_step()

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    now_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(now_time, "Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, device):
    """Evaluate ``model`` on ``data_loader`` and report loss, top-1 and top-5 accuracy.

    Each batch is indexed as ``batch[0]`` = inputs and ``batch[-1]`` = integer
    targets. Runs without gradients (decorator), in eval mode and under
    ``torch.cuda.amp.autocast``.

    Returns:
        dict mapping ``loss``/``acc1``/``acc5`` to their global averages after
        cross-process synchronization.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    logger = utils.MetricLogger(delimiter=" ")

    # switch to evaluation mode
    model.eval()

    for batch in logger.log_every(data_loader, 10, 'Test:'):
        images = batch[0].to(device, non_blocking=True)
        target = batch[-1].to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            output = model(images)
            loss = loss_fn(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        n = images.shape[0]
        logger.update(loss=loss.item())
        # Weighted by batch size so the global average is per-sample.
        logger.meters['acc1'].update(acc1.item(), n=n)
        logger.meters['acc5'].update(acc5.item(), n=n)

    # gather the stats from all processes
    logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=logger.acc1, top5=logger.acc5, losses=logger.loss))
    return {name: meter.global_avg for name, meter in logger.meters.items()}
================================================
FILE: furnace/engine_for_pretraining.py
================================================
import math
import sys
import time
from typing import Iterable
import torch
import torch.nn as nn
import furnace.utils as utils
import torch.nn.functional as F
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def loss_selector(loss_type, pred, target):
    """Return a scalar loss between ``pred`` and ``target``.

    Args:
        loss_type: ``'mse'`` -- ``F.mse_loss`` with mean reduction;
            ``'kld'`` -- ``F.kl_div(log_softmax(pred), softmax(target))``, i.e.
            KL(softmax(target) || softmax(pred)) over the last dim, with
            ``reduction='mean'`` (averages over all elements, not per sample).
        pred, target: tensors of the same shape.

    Raises:
        ValueError: for any other ``loss_type``. Previously the function fell
            through and returned ``None``, which only failed later.
    """
    if loss_type == 'mse':
        return F.mse_loss(pred, target, reduction="mean")
    if loss_type == 'kld':
        return F.kl_div(F.log_softmax(pred, dim=-1), F.softmax(target, dim=-1), reduction='mean')
    raise ValueError("Unsupported loss_type: {!r} (expected 'mse' or 'kld')".format(loss_type))
def train_one_epoch(model: torch.nn.Module, d_vae: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    log_writer=None, lr_scheduler=None, start_steps=None,
                    lr_schedule_values=None, wd_schedule_values=None, args=None):
    """Run one epoch of CAE pre-training.

    For each loader item ``((samples, images, bool_masked_pos), _)``:

    * ``d_vae.get_codebook_indices(images)`` gives per-patch token ids (no
      grad); the ids at masked positions are the classification labels.
    * ``model(samples, bool_masked_pos=..., return_all_tokens=False)`` returns
      ``(outputs, latent, latent_target)``.
    * ``loss = CE(outputs, labels) + args.align_loss_weight *
      MSE(latent, latent_target.detach())``, computed under autocast and
      back-propagated through ``loss_scaler``.

    Args:
        start_steps: global index of this epoch's first step, used to index
            the per-step schedules (``None`` -> 0).
        lr_schedule_values / wd_schedule_values: optional per-step schedules.
            LR is multiplied by a group's ``lr_scale`` when present (groups
            built from plain ``model.parameters()`` have none); WD is only
            written to groups whose current WD is > 0.
        lr_scheduler: optional object whose ``step_update`` is called with the
            global step after every iteration.

    Returns:
        dict mapping each logged metric to its global average.
    """
    if start_steps is None:
        start_steps = 0

    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for step, (batch, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # assign learning rate & weight decay for each step
        it = start_steps + step  # global training iteration
        if lr_schedule_values is not None or wd_schedule_values is not None:
            for param_group in optimizer.param_groups:
                if lr_schedule_values is not None:
                    # Same handling as the fine-tuning engine: indexing
                    # param_group["lr_scale"] unconditionally raised KeyError
                    # for optimizers built without per-group lr_scale.
                    if "lr_scale" in param_group:
                        param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                    else:
                        param_group["lr"] = lr_schedule_values[it]
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]

        samples, images, bool_masked_pos = batch
        images = images.to(device, non_blocking=True)
        samples = samples.to(device, non_blocking=True)
        bool_masked_pos = bool_masked_pos.to(device, non_blocking=True)

        with torch.no_grad():
            input_ids = d_vae.get_codebook_indices(images).flatten(1)
            bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
            labels = input_ids[bool_masked_pos]

        with torch.cuda.amp.autocast():
            outputs, latent, latent_target = model(samples, bool_masked_pos=bool_masked_pos, return_all_tokens=False)
            loss_main = nn.CrossEntropyLoss()(input=outputs.float(), target=labels)
            loss_align = args.align_loss_weight * loss_selector('mse', latent.float(), latent_target.detach().float())
            loss = loss_main + loss_align

        loss_value = loss.item()
        loss_main_value = loss_main.item()
        loss_align_value = loss_align.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                parameters=model.parameters(), create_graph=is_second_order)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()

        mlm_acc = (outputs.max(-1)[1] == labels).float().mean().item()
        metric_logger.update(mlm_acc=mlm_acc)
        if log_writer is not None:
            log_writer.update(mlm_acc=mlm_acc, head="loss")

        metric_logger.update(loss=loss_value)
        metric_logger.update(loss_main=loss_main_value)
        metric_logger.update(loss_align=loss_align_value)
        metric_logger.update(loss_scale=loss_scale_value)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)
        metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(loss=loss_main_value, head="loss_main")
            log_writer.update(loss=loss_align_value, head="loss_align")
            log_writer.update(loss_scale=loss_scale_value, head="opt")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            log_writer.update(grad_norm=grad_norm, head="opt")
            log_writer.set_step()

        if lr_scheduler is not None:
            lr_scheduler.step_update(start_steps + step)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    now_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(now_time, "Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
================================================
FILE: furnace/masking_generator.py
================================================
import random
import math
import numpy as np
class MaskingGenerator:
    """Block-wise random masking over a ``height x width`` patch grid.

    Each call returns an int array of shape ``(height, width)`` with exactly
    ``num_masking_patches`` ones (masked) and zeros elsewhere. The mask is
    grown from rectangular blocks. Each block's target area is drawn
    uniformly between ``min_num_patches`` and the remaining budget (capped by
    ``max_num_patches``), and its log aspect ratio is drawn uniformly in
    ``[log(min_aspect), log(max_aspect)]``. Blocks must be strictly smaller
    than the grid in both dimensions.

    Args:
        input_size: int (square grid) or ``(height, width)`` tuple.
        num_masking_patches: exact number of patches masked per call.
        min_num_patches: lower bound of a block's target area.
        max_num_patches: upper bound of a block's target area; defaults to
            ``num_masking_patches``.
        min_aspect: smallest aspect ratio; ``max_aspect`` defaults to
            ``1 / min_aspect``.

    Raises:
        ValueError: for settings under which ``__call__`` could never reach
            the requested count (and would previously loop forever).
    """

    def __init__(
            self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None,
            min_aspect=0.3, max_aspect=None):
        if not isinstance(input_size, tuple):
            input_size = (input_size, ) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
        self._check_feasible()

    def _check_feasible(self):
        """Reject settings for which ``__call__`` would spin forever."""
        wanted = self.num_masking_patches
        if not 0 <= wanted <= self.num_patches:
            raise ValueError(
                "num_masking_patches must be in [0, %d], got %r" % (self.num_patches, wanted))
        if wanted > 0 and (self.height < 2 or self.width < 2):
            # _mask only accepts blocks with h < height and w < width.
            raise ValueError(
                "block masking needs a grid of at least 2x2, got %dx%d" % (self.height, self.width))
        if wanted > 0 and self.max_num_patches < 1:
            raise ValueError(
                "max_num_patches must be >= 1, got %r" % (self.max_num_patches,))

    def __repr__(self):
        repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height, self.width, self.min_num_patches, self.max_num_patches,
            self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
        return repr_str

    def get_shape(self):
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """Try up to 10 times to add one block to ``mask`` in place.

        A block is accepted only if it adds between 1 and ``max_mask_patches``
        new masked cells. Returns the number of newly masked cells
        (0 if every attempt was rejected).
        """
        delta = 0
        for attempt in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top: top + h, left: left + w].sum()
                # Overlap: only count cells that are not already masked.
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self):
        mask = np.zeros(shape=self.get_shape(), dtype=int)
        mask_count = 0
        # _mask never adds more than the remaining budget, so this loop stops
        # exactly at num_masking_patches (feasibility checked in __init__).
        while mask_count != self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            mask_count += delta

        return mask
class RandomMaskingGenerator:
    """Uniformly random patch masking with a fixed ratio.

    Each call returns a flat float array of length ``height * width`` that
    contains exactly ``int(ratio_masking_patches * height * width)`` ones
    (masked) at shuffled positions; the other entries are 0. Randomness comes
    from ``np.random``.

    Args:
        input_size: int (square grid) or ``(height, width)`` tuple.
        ratio_masking_patches: fraction of patches to mask, in ``[0, 1]``.

    Raises:
        ValueError: if the ratio is outside ``[0, 1]``.
    """

    def __init__(
            self, input_size, ratio_masking_patches):
        if not isinstance(input_size, tuple):
            input_size = (input_size, ) * 2
        self.height, self.width = input_size
        self.num_patches = self.height * self.width
        if not 0 <= ratio_masking_patches <= 1:
            raise ValueError(
                "ratio_masking_patches must be in [0, 1], got %r" % (ratio_masking_patches,))
        self.num_masking_patches = int(ratio_masking_patches * self.num_patches)

    def __repr__(self):
        # Previously printed "Maks: ..." (typo).
        repr_str = "Mask: total patches {}, mask patches {}".format(
            self.num_patches, self.num_masking_patches
        )
        return repr_str

    def __call__(self):
        mask = np.hstack([
            np.zeros(self.num_patches - self.num_masking_patches),
            np.ones(self.num_masking_patches),
        ])
        np.random.shuffle(mask)
        return mask
================================================
FILE: furnace/optim_factory.py
================================================
import torch
from torch import optim as optim
from timm.optim.adafactor import Adafactor
from timm.optim.adahessian import Adahessian
from timm.optim.adamp import AdamP
from timm.optim.lookahead import Lookahead
from timm.optim.nadam import Nadam
from timm.optim.novograd import NovoGrad
from timm.optim.nvnovograd import NvNovoGrad
from timm.optim.radam import RAdam
from timm.optim.rmsprop_tf import RMSpropTF
from timm.optim.sgdp import SGDP
import json
try:
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
has_apex = True
except ImportError:
has_apex = False
def get_num_layer_for_vit(var_name, num_max_layer):
    """Map a ViT parameter name to a layer index for layer-wise LR decay.

    * ``cls_token``, ``mask_token``, ``pos_embed`` and ``patch_embed*`` -> 0
    * ``blocks.<k>.*`` -> ``k + 1``
    * everything else (including ``rel_pos_bias*``) -> ``num_max_layer - 1``
    """
    is_embedding = var_name in ("cls_token", "mask_token", "pos_embed") or var_name.startswith("patch_embed")
    if is_embedding:
        return 0
    if var_name.startswith("blocks"):
        block_index = int(var_name.split('.')[1])
        return block_index + 1
    return num_max_layer - 1
class LayerDecayValueAssigner(object):
    """Per-layer LR scale lookup for layer-wise learning-rate decay.

    ``values[k]`` is the scale for layer ``k``; ``len(values)`` is the number
    of layers passed to ``get_num_layer_for_vit``.
    """

    def __init__(self, values):
        self.values = values

    def get_scale(self, layer_id):
        scales = self.values
        return scales[layer_id]

    def get_layer_id(self, var_name):
        num_layers = len(self.values)
        return get_num_layer_for_vit(var_name, num_layers)
def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
    """Split trainable parameters into optimizer groups.

    A parameter gets no weight decay when it is 1-D, its name ends in
    ``.bias``, or its name is in ``skip_list``; otherwise it gets
    ``weight_decay``. If ``get_num_layer`` is given, groups are further keyed
    by layer (``layer_<id>_<decay|no_decay>``). Each group carries an
    ``lr_scale``: ``get_layer_scale(layer_id)`` when provided, else 1.0.
    Frozen parameters (``requires_grad=False``) are skipped. The group
    layout (with parameter names) is printed as JSON.

    Returns:
        list of ``{"weight_decay", "params", "lr_scale"}`` dicts, in first-seen order.
    """
    names_by_group = {}
    params_by_group = {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights

        skip_decay = len(param.shape) == 1 or name.endswith(".bias") or name in skip_list
        if skip_decay:
            group_name, group_decay = "no_decay", 0.
        else:
            group_name, group_decay = "decay", weight_decay

        layer_id = None
        if get_num_layer is not None:
            layer_id = get_num_layer(name)
            group_name = "layer_%d_%s" % (layer_id, group_name)

        if group_name not in names_by_group:
            scale = 1. if get_layer_scale is None else get_layer_scale(layer_id)
            names_by_group[group_name] = {"weight_decay": group_decay, "params": [], "lr_scale": scale}
            params_by_group[group_name] = {"weight_decay": group_decay, "params": [], "lr_scale": scale}

        params_by_group[group_name]["params"].append(param)
        names_by_group[group_name]["params"].append(name)

    print("Param groups = %s" % json.dumps(names_by_group, indent=2))
    return list(params_by_group.values())
def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
    """Build an optimizer for ``model`` from command-line style ``args``.

    ``args.opt`` selects the algorithm, case-insensitively. An optional
    ``lookahead_`` prefix wraps the result in ``Lookahead``, and names
    containing ``fused`` require APEX and CUDA.

    When ``args.weight_decay`` is non-zero and ``filter_bias_and_bn`` is set,
    parameters are split by ``get_parameter_groups`` into decay/no-decay
    groups (optionally per layer). Each group carries its own
    ``weight_decay`` and ``lr_scale``, so the optimizer-level weight decay is
    set to 0. Parameters listed in ``skip_list`` (or returned by
    ``model.no_weight_decay()``) get no decay.

    Reads ``args.opt``, ``args.weight_decay`` and ``args.lr``, plus
    ``args.momentum``, ``args.opt_eps`` and ``args.opt_betas`` where the
    chosen optimizer uses them.

    Raises:
        ValueError: if ``args.opt`` names an unknown optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        if skip_list is not None:
            skip = skip_list
        elif hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        else:
            skip = set()
        parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
        # Decay is now set per group; don't apply it again globally.
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # Was ``assert False and "Invalid optimizer"``. That expression is just
        # ``assert False``, so it raised a message-less AssertionError (or a
        # bare ValueError under ``python -O``).
        raise ValueError("Invalid optimizer: {}".format(args.opt))

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
================================================
FILE: furnace/transforms.py
================================================
import torch
import torchvision.transforms.functional as F
from PIL import Image
import warnings
import math
import random
import numpy as np
class ToNumpy:
    """Convert a PIL image (or array-like) to a ``uint8`` numpy array in CHW layout.

    A 2-D (single-channel) input gains a trailing channel axis first, so the
    result is always channel-first.
    """

    def __call__(self, pil_img):
        arr = np.array(pil_img, dtype=np.uint8)
        if arr.ndim < 3:
            arr = arr[..., np.newaxis]
        # HWC to CHW
        return np.moveaxis(arr, 2, 0)
class ToTensor:
    """Convert a PIL image (or array-like) to a CHW ``torch`` tensor of ``dtype``.

    Unlike ``torchvision.transforms.ToTensor``, values are NOT rescaled: they
    keep their 0-255 ``uint8`` range, just cast to ``dtype``.
    """

    def __init__(self, dtype=torch.float32):
        self.dtype = dtype

    def __call__(self, pil_img):
        arr = np.array(pil_img, dtype=np.uint8)
        if arr.ndim < 3:
            arr = arr[..., np.newaxis]
        # HWC to CHW
        chw = np.moveaxis(arr, 2, 0)
        return torch.from_numpy(chw).to(dtype=self.dtype)
# Display names for PIL resampling filters, used when building the __repr__
# of RandomResizedCropAndInterpolationWithTwoPic below.
_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
    Image.HAMMING: 'PIL.Image.HAMMING',
    Image.BOX: 'PIL.Image.BOX',
}
def _pil_interp(method):
    """Translate an interpolation name into a PIL resampling filter.

    Recognizes ``'bicubic'``, ``'lanczos'`` and ``'hamming'``; any other value
    (including ``'bilinear'``) falls back to ``Image.BILINEAR``.
    """
    for name, pil_filter in (('bicubic', Image.BICUBIC),
                             ('lanczos', Image.LANCZOS),
                             ('hamming', Image.HAMMING)):
        if method == name:
            return pil_filter
    # default bilinear, do we want to allow nearest?
    return Image.BILINEAR
# Candidate filters when interpolation='random': one is chosen per call.
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
class RandomResizedCropAndInterpolationWithTwoPic:
    """Random resized crop that can emit two resized views of one crop window.

    A crop covering a random fraction of the image area (``scale``, default
    0.08-1.0) with a random aspect ratio (``ratio``, default 3/4-4/3) is cut
    out and resized to ``size``. If ``second_size`` is given, the same window
    is also resized to ``second_size`` with ``second_interpolation``, and a
    pair of images is returned.

    Args:
        size: output edge length (int) or ``(h, w)`` tuple.
        second_size: optional size of the second view (int or tuple).
        scale: range of the crop area relative to the original area.
        ratio: range of the crop aspect ratio.
        interpolation: name for ``_pil_interp``, or ``'random'`` to pick
            bilinear/bicubic per call.
        second_interpolation: name for ``_pil_interp`` used for the second view.
    """

    def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                 interpolation='bilinear', second_interpolation='lanczos'):
        self.size = size if isinstance(size, tuple) else (size, size)

        if second_size is None:
            self.second_size = None
        elif isinstance(second_size, tuple):
            self.second_size = second_size
        else:
            self.second_size = (second_size, second_size)

        if scale[0] > scale[1] or ratio[0] > ratio[1]:
            warnings.warn("range should be of kind (min, max)")

        self.interpolation = _RANDOM_INTERPOLATION if interpolation == 'random' else _pil_interp(interpolation)
        self.second_interpolation = _pil_interp(second_interpolation)
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Sample crop parameters ``(i, j, h, w)`` (top, left, height, width).

        Tries 10 random (area, aspect) draws and returns the first that fits
        inside ``img.size == (width, height)``. Otherwise it falls back to a
        centred crop whose aspect ratio is clamped to ``ratio``.
        """
        width, height = img.size[0], img.size[1]
        area = width * height

        for _ in range(10):
            target_area = random.uniform(*scale) * area
            log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
            aspect_ratio = math.exp(random.uniform(log_lo, log_hi))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= width and h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = width / height
        lowest, highest = min(ratio), max(ratio)
        if in_ratio < lowest:
            w, h = width, int(round(width / lowest))
        elif in_ratio > highest:
            w, h = int(round(height * highest)), height
        else:  # whole image
            w, h = width, height
        return (height - h) // 2, (width - w) // 2, h, w

    def __call__(self, img):
        """Crop and resize ``img``; returns one image, or a pair if ``second_size`` is set."""
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if isinstance(self.interpolation, (tuple, list)):
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation

        first = F.resized_crop(img, i, j, h, w, self.size, interpolation)
        if self.second_size is None:
            return first
        second = F.resized_crop(img, i, j, h, w, self.second_size, self.second_interpolation)
        return first, second

    def __repr__(self):
        if isinstance(self.interpolation, (tuple, list)):
            interp_names = ' '.join(_pil_interpolation_to_str[mode] for mode in self.interpolation)
        else:
            interp_names = _pil_interpolation_to_str[self.interpolation]
        parts = [
            self.__class__.__name__ + '(size={0}'.format(self.size),
            ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)),
            ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)),
            ', interpolation={0}'.format(interp_names),
        ]
        if self.second_size is not None:
            parts.append(', second_size={0}'.format(self.second_size))
            parts.append(', second_interpolation={0}'.format(
                _pil_interpolation_to_str[self.second_interpolation]))
        parts.append(')')
        return ''.join(parts)
================================================
FILE: furnace/utils.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, DINO and DeiT code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import io
import os
import math
import time
import json
from collections import defaultdict, deque
import datetime
import numpy as np
from timm.utils import get_state_dict
from pathlib import Path
import torch
import torch.distributed as dist
from torch._six import inf
from models.modeling_discrete_vae import Dalle_VAE, DiscreteVAE, VGGAN
from tensorboardX import SummaryWriter
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _NormBase
class SmoothedValue(object):
    """Track a series of values and expose windowed and global statistics.

    ``median``, ``avg``, ``max`` and ``value`` are computed over the last
    ``window_size`` recorded values. ``global_avg`` is the weighted mean of
    every value ever recorded, where each ``update(value, n)`` contributes
    ``value * n``. While nothing has been recorded, every statistic returns
    ``nan``, so a meter that was added but never updated can still be printed
    or averaged. Previously this raised ZeroDivisionError, IndexError or
    ValueError.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value``; it counts ``n`` times toward the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """All-reduce ``count`` and ``total`` across processes.

        Returns early when ``is_dist_avail_and_initialized()`` is false.
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        if not self.deque:
            return math.nan
        d = torch.tensor(list(self.deque))
        # torch.median returns the lower of the two middle values for even sizes.
        return d.median().item()

    @property
    def avg(self):
        if not self.deque:
            return math.nan
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        if self.count == 0:
            return math.nan
        return self.total / self.count

    @property
    def max(self):
        if not self.deque:
            return math.nan
        return max(self.deque)

    @property
    def value(self):
        if not self.deque:
            return math.nan
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and print periodic progress lines.

    Meters are created on first ``update`` (via ``defaultdict``) and can also
    be read as attributes, e.g. ``logger.loss``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record each keyword as a meter value; ``None`` values are skipped.

        Tensors are converted with ``.item()``; anything else must be a
        float or int.
        """
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. Read ``meters`` through
        # __dict__: on a half-constructed instance (e.g. copy/pickle calling
        # hasattr(obj, '__setstate__') before __init__ ran) ``self.meters``
        # would re-enter __getattr__ and recurse forever.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items of ``iterable`` and print progress every ``print_freq`` items.

        Each progress line shows index, ETA, all meters, iteration and
        data-loading time, and (when CUDA is available) peak allocated memory
        in MB. ``iterable`` must support ``len()``. A total-time summary is
        printed at the end.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the iterator = data loading time.
            data_time.update(time.time() - end)
            yield obj
            # Time until control returns = full iteration time (incl. caller work).
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1): an empty iterable used to raise ZeroDivisionError here.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(len(iterable), 1)))
class TensorboardLogger(object):
    """Thin wrapper around a tensorboardX ``SummaryWriter`` that tracks a global step."""

    def __init__(self, log_dir):
        self.writer = SummaryWriter(logdir=log_dir)
        self.step = 0

    def set_step(self, step=None):
        """Jump to ``step`` when given, otherwise advance the tracked step by one."""
        self.step = self.step + 1 if step is None else step

    def update(self, head='scalar', step=None, **kwargs):
        """Write each non-None keyword as the scalar ``<head>/<name>``.

        Tensors are converted with ``.item()``. ``step`` overrides the tracked
        step for this call only.
        """
        global_step = self.step if step is None else step
        for name, val in kwargs.items():
            if val is None:
                continue
            if isinstance(val, torch.Tensor):
                val = val.item()
            assert isinstance(val, (float, int))
            self.writer.add_scalar(head + "/" + name, val, global_step)

    def flush(self):
        """Flush pending events to disk."""
        self.writer.flush()
def _load_checkpoint_for_ema(model_ema, checkpoint):
"""
Workaround for ModelEma._load_checkpoint to accept an already-loaded object
"""
mem_file = io.BytesIO()
torch.save(checkpoint, mem_file)
mem_file.seek(0)
model_ema._load_checkpoint(mem_file)
def setup_for_distributed_each_gpu(rank):
    """Replace ``builtins.print`` so every rank prints, prefixed by its rank and a timestamp."""
    import builtins as __builtin__
    original_print = __builtin__.print

    def print(*args, **kwargs):
        original_print('rank is: ', rank, end=' ')
        original_print('[{}] '.format(datetime.datetime.now().time()), end='')  # print with time stamp
        original_print(*args, **kwargs)

    __builtin__.print = print
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    original_print = __builtin__.print

    def print(*args, **kwargs):
        # `force` is always removed so it never reaches the real print.
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        stamp = datetime.datetime.now().time()
        original_print('[{}] '.format(stamp), end='')  # print with time stamp
        original_print(*args, **kwargs)

    __builtin__.print = print
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is built in and a process group is initialised."""
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Number of processes in the default group, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Rank of this process in the default group, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """True on rank 0 (and therefore always in a non-distributed run)."""
    rank = get_rank()
    return rank == 0
def save_on_master(*args, **kwargs):
    """Forward all arguments to ``torch.save`` on the main process; other ranks do nothing."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
def init_distributed_mode(args):
    """Detect how the job was launched and initialise torch.distributed accordingly.

    Sets ``args.rank``, ``args.gpu``, ``args.distributed`` and
    ``args.dist_backend`` (always 'nccl'). Depending on the launcher it also sets
    ``args.world_size`` and ``args.dist_url``. Launchers are checked in order:

    * ``args.dist_on_itp``: read OpenMPI's ``OMPI_COMM_WORLD_*`` variables plus
      ``MASTER_ADDR``/``MASTER_PORT``, then export ``LOCAL_RANK``/``RANK``/
      ``WORLD_SIZE`` back into the environment.
    * ``RANK`` and ``WORLD_SIZE`` already in the environment (``LOCAL_RANK`` is
      also read).
    * ``SLURM_PROCID``: the GPU is ``rank % torch.cuda.device_count()``.
    * otherwise: print a notice, set ``args.distributed = False`` and return
      without touching torch.distributed.

    NOTE(review): only the dist_on_itp branch sets ``args.dist_url``, and the
    SLURM branch does not set ``args.world_size``. Those must already be present
    on ``args`` — confirm against the argument parser.

    After ``init_process_group`` and a barrier, ``print`` is replaced. Without
    ``args.enable_multi_print`` only rank 0 prints (via setup_for_distributed).
    With it, every rank prints with a rank prefix (via
    setup_for_distributed_each_gpu).
    """
    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        # Map the global rank onto the GPUs visible on this node.
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return
    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    if not args.enable_multi_print:
        # Silence all ranks except rank 0.
        setup_for_distributed(args.rank == 0)
    else:
        # Let every rank print, prefixed with its rank.
        setup_for_distributed_each_gpu(args.rank)
def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    """Load ``state_dict`` into ``model`` without raising on mismatches, then print a report.

    Each submodule's ``_load_from_state_dict`` is called directly, so missing
    keys, unexpected keys and error messages are collected rather than raised.
    Missing keys containing any of the '|'-separated substrings in
    ``ignore_missing`` are reported on a separate "Ignored" line. Returns None.
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # Work on a shallow copy so _load_from_state_dict may consume it. The copy
    # drops the _metadata attribute, so re-attach it explicitly.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def _visit(module, module_prefix=''):
        local_metadata = {} if metadata is None else metadata.get(module_prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, module_prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for child_name, child in module._modules.items():
            if child is not None:
                _visit(child, module_prefix + child_name + '.')

    _visit(model, module_prefix=prefix)

    patterns = ignore_missing.split('|')
    ignored_keys = [key for key in missing_keys if any(pat in key for pat in patterns)]
    missing_keys = [key for key in missing_keys if not any(pat in key for pat in patterns)]

    model_name = model.__class__.__name__
    if missing_keys:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model_name, missing_keys))
    if unexpected_keys:
        print("Weights from pretrained model not used in {}: {}".format(
            model_name, unexpected_keys))
    if ignored_keys:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            model_name, ignored_keys))
    if error_msgs:
        print('\n'.join(error_msgs))
class NativeScalerWithGradNormCount:
    """AMP ``GradScaler`` wrapper that runs backward, optional clipping, and the optimizer step.

    Calling the instance always performs a scaled backward pass. When
    ``update_grad`` is true it also unscales the gradients and either clips
    them to ``clip_grad`` (returning what ``clip_grad_norm_`` returns) or just
    measures their norm with ``get_grad_norm_``. It then steps the optimizer
    and updates the scale. When ``update_grad`` is false, gradients are only
    accumulated and ``None`` is returned.
    """
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            return None
        if clip_grad is not None:
            assert parameters is not None
        # Gradients must be unscaled in place before they are clipped or measured.
        self._scaler.unscale_(optimizer)
        if clip_grad is None:
            norm = get_grad_norm_(parameters)
        else:
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """State of the wrapped GradScaler."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped GradScaler from ``state_dict``."""
        self._scaler.load_state_dict(state_dict)
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the total ``norm_type``-norm of the gradients of ``parameters``.

    The gradients themselves are not modified. Parameters whose ``.grad`` is
    None are skipped. If none has a gradient, ``torch.tensor(0.)`` is returned.
    Per-parameter norms are moved to the device of the first gradient before
    being combined. ``norm_type=inf`` gives the max absolute gradient entry.

    Args:
        parameters: a single tensor or an iterable of tensors (e.g. ``model.parameters()``).
        norm_type: order of the norm; ``float('inf')`` selects the max-norm.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    # math.inf rather than torch._six.inf: torch._six was removed in PyTorch 2.0.
    if norm_type == math.inf:
        return max(p.grad.detach().abs().max().to(device) for p in parameters)
    return torch.norm(
        torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
        norm_type)
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    """Build a per-iteration schedule: linear warmup followed by a half-cosine decay.

    Args:
        base_value: value at the end of warmup and at the start of the cosine.
        final_value: value the cosine decays towards by the last iteration.
        epochs: number of epochs; with ``niter_per_ep`` it fixes the schedule length.
        niter_per_ep: iterations per epoch.
        warmup_epochs: warmup length in epochs (overridden by ``warmup_steps > 0``).
        start_warmup_value: first value of the linear warmup.
        warmup_steps: if > 0, the warmup length in iterations.

    Returns:
        1-D numpy array of length ``epochs * niter_per_ep``.
    """
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps
    print("Set warmup steps = %d" % warmup_iters)
    # Build the ramp whenever there are warmup iterations. The old guard checked
    # warmup_epochs, so a warmup_steps-only warmup produced no ramp, the
    # schedule came out warmup_steps too short, and the assert below fired.
    if warmup_iters > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = np.array(
        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, exp_name=None):
    """Save a checkpoint for ``epoch`` under ``args.output_dir``.

    With a ``loss_scaler``, a single ``<tag>.pth`` file is written through
    ``save_on_master``. It holds the model weights (minus ``teacher.*``
    entries), optimizer state, epoch, scaler state, ``args`` and, if given,
    the EMA weights. Without one, ``model.save_checkpoint`` is called with the
    same tag and a client state holding the epoch (and EMA weights). That path
    is presumably a DeepSpeed engine — TODO confirm.
    The tag is ``checkpoint-<epoch>``, prefixed by ``<exp_name>_`` when given.
    """
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if exp_name is None:
        tag = 'checkpoint-%s' % epoch_name
    else:
        tag = '{}_checkpoint-{}'.format(exp_name, epoch_name)

    if loss_scaler is None:
        client_state = {'epoch': epoch}
        if model_ema is not None:
            client_state['model_ema'] = get_state_dict(model_ema)
        model.save_checkpoint(save_dir=args.output_dir, tag=tag, client_state=client_state)
        return

    checkpoint_path = output_dir / (tag + '.pth')
    model_state = model_without_ddp.state_dict()
    # Pop in place (rather than building a new dict) so the returned state dict
    # keeps its own type and attributes; teacher.* weights are not saved.
    for key in [k for k in model_state.keys() if k.startswith('teacher.')]:
        model_state.pop(key)
    to_save = {
        'model': model_state,
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }
    if model_ema is not None:
        to_save['model_ema'] = get_state_dict(model_ema)
    save_on_master(to_save, checkpoint_path)
def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
    """Resume training state from a checkpoint, optionally finding the latest one automatically.

    The path depends on ``loss_scaler``:

    * ``loss_scaler`` is not None (native AMP). If ``args.auto_resume`` is set
      and ``args.resume`` is empty, ``args.resume`` becomes the
      ``checkpoint-<N>.pth`` in ``args.output_dir`` with the largest N. Then, if
      ``args.resume`` is set, the checkpoint is loaded from an ``https`` URL
      (hash-checked) or a local path, the model weights are loaded, and, when
      the checkpoint holds optimizer and epoch, the optimizer / EMA / scaler
      state is restored and ``args.start_epoch`` is set to epoch + 1.
    * ``loss_scaler`` is None (DeepSpeed, per the inline comment). Only
      ``args.auto_resume`` is supported. The latest ``checkpoint-<N>`` tag is
      loaded via ``model.load_checkpoint``, and ``args.start_epoch`` / EMA state
      come from the returned client state.

    NOTE(review): in the AMP path, if ``args.model_ema`` is true but
    ``model_ema`` is None, ``_load_checkpoint_for_ema`` will fail — confirm
    callers always pass both together.
    """
    output_dir = Path(args.output_dir)
    if loss_scaler is not None:
        # torch.amp
        if args.auto_resume and len(args.resume) == 0:
            import glob
            all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
            latest_ckpt = -1
            for ckpt in all_checkpoints:
                # Epoch number is the text between the last '-' and the next '.'.
                t = ckpt.split('-')[-1].split('.')[0]
                if t.isdigit():
                    latest_ckpt = max(int(t), latest_ckpt)
            if latest_ckpt >= 0:
                args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
            print("Auto resume checkpoint: %s" % args.resume)
        if args.resume:
            if args.resume.startswith('https'):
                checkpoint = torch.hub.load_state_dict_from_url(
                    args.resume, map_location='cpu', check_hash=True)
            else:
                checkpoint = torch.load(args.resume, map_location='cpu')
            # handle ema model
            need_state_dict = model_without_ddp.state_dict()
            need_ema = False
            for key in need_state_dict.keys():
                if 'teacher' in key:
                    need_ema = True
                    break
            checkpoint_model = checkpoint['model']
            if need_ema:
                # save_model drops teacher.* from checkpoints, so seed each
                # teacher.* entry from the matching encoder.* weights.
                all_keys = list(checkpoint_model.keys())
                all_keys = [key for key in all_keys if key.startswith('encoder.')]
                for key in all_keys:
                    new_key = key.replace('encoder.','teacher.')
                    checkpoint_model[new_key] = checkpoint_model[key]
            model_without_ddp.load_state_dict(checkpoint_model)
            print("Resume checkpoint %s" % args.resume)
            if 'optimizer' in checkpoint and 'epoch' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
                args.start_epoch = checkpoint['epoch'] + 1
                if hasattr(args, 'model_ema') and args.model_ema:
                    _load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
                if 'scaler' in checkpoint:
                    loss_scaler.load_state_dict(checkpoint['scaler'])
                print("With optim & sched!")
    else:
        # deepspeed, only support '--auto_resume'.
        if args.auto_resume:
            import glob
            all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
            latest_ckpt = -1
            for ckpt in all_checkpoints:
                t = ckpt.split('-')[-1].split('.')[0]
                if t.isdigit():
                    latest_ckpt = max(int(t), latest_ckpt)
            if latest_ckpt >= 0:
                args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt)
                print("Auto resume checkpoint: %d" % latest_ckpt)
                # load_checkpoint returns (load_path, client_state); the client
                # state is what save_model passed as client_state.
                _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt)
                args.start_epoch = client_states['epoch'] + 1
                if model_ema is not None:
                    if args.model_ema:
                        _load_checkpoint_for_ema(model_ema, client_states['model_ema'])
def create_d_vae(weight_path, d_vae_type, image_size, device, args=None):
    """Construct the discrete VAE / tokenizer selected by ``d_vae_type``.

    Supported values are "dall-e", "vqgan_gumbel_f8_8192", "customized" (which
    needs ``args``) and "to_tensor" (no tokenizer; returns ``None``). Any
    other value raises NotImplementedError.
    """
    if d_vae_type == "to_tensor":
        return None
    if d_vae_type == "dall-e":
        return get_dalle_vae(weight_path, image_size, device)
    if d_vae_type == "vqgan_gumbel_f8_8192":
        return get_vqgan_gumbel_f8_8192(weight_path, image_size, device)
    if d_vae_type == "customized":
        return get_d_vae(weight_path, image_size, device, args)
    raise NotImplementedError()
def get_vqgan_gumbel_f8_8192(weight_path, image_size, device):
    """Build a ``VGGAN`` for ``image_size`` and load its weights from ``weight_path`` onto ``device``.

    Both construction and loading run under ``torch.no_grad()``.
    """
    with torch.no_grad():
        tokenizer = VGGAN(image_size)
        tokenizer.load_model(weight_path, device)
    return tokenizer
def get_dalle_vae(weight_path, image_size, device):
    """Build a ``Dalle_VAE`` for ``image_size`` and load it (``weight_path`` is passed as ``model_dir``)."""
    dalle = Dalle_VAE(image_size)
    dalle.load_model(model_dir=weight_path, device=device)
    return dalle
def get_d_vae(weight_path, image_size, device, args):
    """Build a ``DiscreteVAE`` with a fixed architecture and load a checkpoint into it.

    The checkpoint at ``weight_path`` is loaded on CPU and its ``"model"``
    entry is used as the state dict. Only the layer count comes from
    ``args.dvae_num_layers``. Token count (8192), codebook dim (512) and
    hidden dim (256) are fixed. The model is moved to ``device`` before
    loading.
    """
    num_layers = args.dvae_num_layers
    weights = torch.load(weight_path, map_location="cpu")["model"]
    vae = DiscreteVAE(
        image_size=image_size,
        num_layers=num_layers,
        num_tokens=8192,
        codebook_dim=512,
        hidden_dim=256,
    ).to(device)
    vae.load_state_dict(weights)
    return vae
def create_ds_config(args):
    """Write a DeepSpeed JSON config to ``<args.output_dir>/deepspeed_config.json``.

    Also stores that path in ``args.deepspeed_config``. The config sets the
    global batch size to ``batch_size * update_freq * world_size`` and the
    micro batch size to ``batch_size``. It selects the "Adam" optimizer with
    ``adam_w_mode`` on, using ``args.lr`` / ``args.weight_decay``, and
    enables fp16 with the loss-scale settings below.
    """
    config_path = os.path.join(args.output_dir, "deepspeed_config.json")
    args.deepspeed_config = config_path
    with open(config_path, mode="w") as writer:
        adam_params = {
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "bias_correction": True,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
        }
        ds_config = {
            "train_batch_size": args.batch_size * args.update_freq * get_world_size(),
            "train_micro_batch_size_per_gpu": args.batch_size,
            "steps_per_print": 1000,
            "optimizer": {"type": "Adam", "adam_w_mode": True, "params": adam_params},
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "initial_scale_power": 7,
                "loss_scale_window": 128,
            },
        }
        writer.write(json.dumps(ds_config, indent=2))
class LP_BatchNorm(_NormBase):
    """BatchNorm whose train/eval behaviour is chosen per call, for linear probing.

    During linear probing the model is set to eval mode to freeze it, while
    the extra BN on top of the encoder (used to calibrate feature magnitudes)
    must still behave as in training. ``forward`` therefore takes an explicit
    ``is_train`` flag and ignores ``self.training``. Accepts 2-D or 3-D input.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(LP_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def _check_input_dim(self, input):
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, input, is_train):
        """Normalise ``input``; ``is_train`` replaces ``self.training``.

        When ``is_train`` is truthy, batch statistics are used and (if
        ``track_running_stats``) the running buffers and
        ``num_batches_tracked`` are updated. Otherwise the running buffers are
        used, or batch statistics if both buffers are None.
        """
        self._check_input_dim(input)
        # With momentum None the factor starts at 0.0 and becomes a cumulative
        # average below; otherwise it is simply the momentum.
        factor = 0.0 if self.momentum is None else self.momentum
        if is_train and self.track_running_stats and self.num_batches_tracked is not None:  # type: ignore
            self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore
            if self.momentum is None:
                factor = 1.0 / float(self.num_batches_tracked)
            else:
                factor = self.momentum
        # Use mini-batch statistics when training, or in eval when no buffers exist.
        use_batch_stats = bool(is_train) or (self.running_mean is None and self.running_var is None)
        assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
        assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
        # Pass the buffers only when they are read (eval) or should be updated
        # (training with tracking); passing None keeps them untouched.
        pass_buffers = (not is_train) or self.track_running_stats
        return F.batch_norm(
            input,
            self.running_mean if pass_buffers else None,
            self.running_var if pass_buffers else None,
            self.weight, self.bias, use_batch_stats, factor, self.eps)
================================================
FILE: linear_util/crop.py
================================================
import math
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
    This may lead to results different with torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """

    @staticmethod
    def _image_size(img):
        """Return ``(width, height)`` of a PIL image or tensor.

        The previously used private ``F._get_image_size`` is not available in
        recent torchvision releases, so the public helpers are preferred when
        present; the private one is kept as a last-resort fallback.
        """
        if hasattr(F, "get_dimensions"):
            _, height, width = F.get_dimensions(img)
            return width, height
        if hasattr(F, "get_image_size"):
            width, height = F.get_image_size(img)
            return width, height
        width, height = F._get_image_size(img)
        return width, height

    @staticmethod
    def get_params(img, scale, ratio):
        """Sample a crop ``(top, left, height, width)`` for ``img`` in a single draw.

        The area fraction is uniform in ``scale`` and the aspect ratio is
        log-uniform in ``ratio``. The box is clipped to the image instead of
        being re-sampled, so the final crop may not match the sampled
        area/ratio.
        """
        width, height = RandomResizedCrop._image_size(img)
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        # Clip rather than retry: this is what makes the sampler loop-free.
        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w
================================================
FILE: linear_util/datasets.py
================================================
import os
import PIL
from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from furnace.masking_generator import MaskingGenerator, RandomMaskingGenerator
class DataAugmentationMySelf(object):
    """Image augmentation paired with a patch-mask generator.

    Calling an instance on an image returns ``(transformed_image, mask)``,
    where ``mask`` is whatever the configured masking generator returns.

    Args:
        args: must provide ``mask_generator`` ('block' or 'random'),
            ``window_size`` and ``mask_ratio``. 'block' also needs
            ``max_mask_patches_per_block`` / ``min_mask_patches_per_block``.
            ``input_size`` is needed only when ``patch_transform`` is omitted.
        patch_transform: optional image transform. When omitted, a default
            RandomResizedCrop(scale 0.2-1.0, bicubic) + horizontal flip +
            ToTensor + ImageNet Normalize pipeline is used. Accepting it lets
            callers such as ``build_transform`` pass their own transform
            instead of failing with a TypeError.

    Raises:
        NotImplementedError: for an unsupported ``args.mask_generator``.
    """

    def __init__(self, args, patch_transform=None):
        if patch_transform is None:
            patch_transform = transforms.Compose([
                transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        self.patch_transform = patch_transform
        if args.mask_generator == 'block':
            self.masked_position_generator = MaskingGenerator(
                args.window_size, num_masking_ratio=args.mask_ratio,
                max_num_patches=args.max_mask_patches_per_block,
                min_num_patches=args.min_mask_patches_per_block,
            )
        elif args.mask_generator == 'random':
            self.masked_position_generator = RandomMaskingGenerator(
                args.window_size, ratio_masking_patches=args.mask_ratio
            )
        else:
            # Fail at construction instead of with an AttributeError on first call.
            raise NotImplementedError(
                "Unsupported mask_generator: {!r}".format(args.mask_generator))

    def __call__(self, image):
        return self.patch_transform(image), self.masked_position_generator()

    def __repr__(self):
        text = "(DataAugmentationMySelf,\n"
        text += " patch_transform = %s,\n" % str(self.patch_transform)
        text += " Masked position generator = %s,\n" % str(self.masked_position_generator)
        text += ")"
        return text
def build_dataset(is_train, args):
    """Create an ``ImageFolder`` over ``<args.data_path>/train`` or ``/val`` with ``build_transform``.

    The dataset is printed before being returned.
    """
    image_transform = build_transform(is_train, args)
    split_dir = os.path.join(args.data_path, 'train' if is_train else 'val')
    folder = datasets.ImageFolder(split_dir, transform=image_transform)
    print(folder)
    return folder
def build_dataset_finetune(is_train, args):
    """Create an ``ImageFolder`` over the train/val split with ``build_transform_finetune``.

    The dataset is printed before being returned.
    """
    image_transform = build_transform_finetune(is_train, args)
    split_dir = os.path.join(args.data_path, 'train' if is_train else 'val')
    folder = datasets.ImageFolder(split_dir, transform=image_transform)
    print(folder)
    return folder
def build_transform_finetune(is_train, args):
    """Return the fine-tuning image transform.

    Training: timm ``create_transform`` (bicubic, with colour jitter,
    auto-augment and random-erasing settings taken from ``args``) using
    ImageNet mean/std.
    Evaluation: bicubic resize, centre crop to ``args.input_size``, ToTensor,
    Normalize. For ``input_size <= 224`` the resize side is
    ``input_size / (224/256)``; for larger inputs it equals ``input_size``.
    """
    mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        return create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    resize_to = int(args.input_size / crop_pct)
    return transforms.Compose([
        transforms.Resize(resize_to, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
def build_transform(is_train, args):
    """Return the image transform used by ``build_dataset``.

    Training: a timm ``create_transform`` pipeline (bicubic, colour jitter,
    auto-augment and random erasing from ``args``, ImageNet mean/std) that is
    handed to ``DataAugmentationMySelf(args, transform)``.
    NOTE(review): the timm transform is passed as a second positional argument;
    confirm the DataAugmentationMySelf constructor accepts it.
    Evaluation: bicubic resize, centre crop to ``args.input_size``, ToTensor,
    Normalize (resize side ``input_size / (224/256)`` when ``input_size <= 224``).
    """
    mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        timm_transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        return DataAugmentationMySelf(args, timm_transform)
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    resize_to = int(args.input_size / crop_pct)
    return transforms.Compose([
        transforms.Resize(resize_to, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
================================================
FILE: linear_util/engine_finetune.py
================================================
import math
import sys
from typing import Iterable, Optional
import torch
from timm.data import Mixup
from timm.utils import accuracy
import linear_util.misc as misc
import linear_util.lr_sched as lr_sched
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None, log_writer=None,
                    args=None):
    """Run one training epoch with a per-iteration lr schedule and gradient accumulation.

    Args:
        model: network being trained; put in train mode here.
        criterion: loss applied as ``criterion(outputs, targets)``.
        data_loader: yields ``(samples, targets)`` batches and supports ``len()``.
        optimizer: its param-group lrs are rewritten by ``lr_sched.adjust_learning_rate``
            every ``accum_iter`` iterations.
        device: device each batch is moved to.
        epoch: current epoch index; the lr schedule uses ``epoch + step / len(data_loader)``.
        loss_scaler: called as ``loss_scaler(loss, optimizer, clip_grad=..., parameters=...,
            create_graph=..., update_grad=...)``; presumably performs backward and,
            when ``update_grad`` is true, the optimizer step — TODO confirm.
        max_norm: forwarded as ``clip_grad``. NOTE(review): whether the default 0
            means "no clipping" depends on ``loss_scaler`` — confirm.
        mixup_fn: optional callable applied to ``(samples, targets)``.
        log_writer: optional writer with ``log_dir`` and ``add_scalar``.
        args: must provide ``accum_iter`` plus the fields read by ``adjust_learning_rate``.

    Returns:
        dict mapping each meter name ('lr', 'loss') to its global average after
        synchronising across processes.

    Calls ``sys.exit(1)`` if the loss becomes non-finite.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20
    accum_iter = args.accum_iter
    optimizer.zero_grad()
    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))
    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)
        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(outputs, targets)
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)
        # Scale the loss so accumulated gradients average over accum_iter steps.
        loss /= accum_iter
        # Gradients are only applied on the last micro-step of each accumulation window.
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()
        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)
        # NOTE(review): min_lr is computed but never used; only max_lr is logged.
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])
        metric_logger.update(lr=max_lr)
        # Presumably the loss averaged across ranks (misc.all_reduce_mean) — TODO confirm.
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', max_lr, epoch_1000x)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, device):
    """Evaluate ``model`` on ``data_loader`` without gradients and return averaged metrics.

    Each batch is indexed as ``batch[0]`` (images) and ``batch[-1]`` (targets).
    Under autocast it computes cross-entropy loss and timm top-1/top-5
    accuracy. Accuracies are weighted by batch size; the loss is weighted per
    batch. Stats are synchronised across processes, printed, and returned as
    ``{meter name: global average}``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = 'Test:'
    model.eval()
    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0].to(device, non_blocking=True)
        target = batch[-1].to(device, non_blocking=True)
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        num_samples = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=num_samples)
        metric_logger.meters['acc5'].update(acc5.item(), n=num_samples)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}
================================================
FILE: linear_util/lars.py
================================================
import torch
class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.

    For parameters with ``ndim > 1``, weight decay is added to the gradient and
    the result is scaled by ``trust_coefficient * ||p|| / ||update||`` (factor 1
    when either norm is zero). Every parameter then takes a heavy-ball momentum
    step ``mu = momentum * mu + update; p -= lr * mu``.
    """

    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, following the ``torch.optim.Optimizer.step`` API. It
                runs with gradients enabled before the update. Previously
                ``step`` accepted no arguments.

        Returns:
            The closure's result, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Trust ratio; falls back to 1 when either norm is zero.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])

        return loss
================================================
FILE: linear_util/lr_decay.py
================================================
import json
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=None, layer_decay=.75):
    """Build optimizer parameter groups with layer-wise learning-rate decay for a ViT.

    Each trainable parameter gets a layer id from ``get_layer_id_for_vit``
    (with ``num_layers = len(model.blocks) + 1``) and an
    ``lr_scale = layer_decay ** (num_layers - layer_id)``. 1-D parameters and
    names in ``no_weight_decay_list`` get weight decay 0; all others get
    ``weight_decay``.

    Args:
        model: module exposing ``blocks`` and ``named_parameters()``.
        weight_decay: decay for parameters that are decayed.
        no_weight_decay_list: parameter names exempt from decay (default: none).
        layer_decay: per-layer multiplicative lr decay.

    Returns:
        A list of dicts with keys ``lr_scale``, ``weight_decay`` and ``params``,
        one per (layer, decay / no-decay) combination that has parameters.
    """
    # None default instead of a shared mutable [] default.
    no_decay_names = () if no_weight_decay_list is None else no_weight_decay_list
    num_layers = len(model.blocks) + 1
    layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

    param_groups = {}
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if p.ndim == 1 or n in no_decay_names:
            g_decay, this_decay = "no_decay", 0.
        else:
            g_decay, this_decay = "decay", weight_decay

        layer_id = get_layer_id_for_vit(n, num_layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)
        if group_name not in param_groups:
            param_groups[group_name] = {
                "lr_scale": layer_scales[layer_id],
                "weight_decay": this_decay,
                "params": [],
            }
        param_groups[group_name]["params"].append(p)

    return list(param_groups.values())
def get_layer_id_for_vit(name, num_layers):
    """Map a ViT parameter name to its layer id for layer-wise lr decay.

    ``cls_token``, ``pos_embed`` and names starting with ``patch_embed`` map
    to 0. ``blocks.<i>.*`` maps to ``i + 1``. Every other name maps to
    ``num_layers``.
    """
    if name in ('cls_token', 'pos_embed') or name.startswith('patch_embed'):
        return 0
    if name.startswith('blocks'):
        block_index = int(name.split('.')[1])
        return block_index + 1
    return num_layers
================================================
FILE: linear_util/lr_sched.py
================================================
import math
def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup

    ``epoch`` may be fractional. During the first ``args.warmup_epochs`` the
    lr rises linearly from 0 to ``args.lr``. Afterwards it follows a half
    cosine from ``args.lr`` to ``args.min_lr`` at ``args.epochs``. Each param
    group gets ``lr * lr_scale`` when it has an ``lr_scale`` key, otherwise
    ``lr``. Returns the unscaled lr.
    """
    in_warmup = epoch < args.warmup_epochs
    if in_warmup:
        lr = args.lr * epoch / args.warmup_epochs
    else:
        cosine = math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1. + cosine)
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
    return lr
================================================
FILE: linear_util/misc.py
================================================
import builtins
import datetime
import os
import time
from collections import defaultdict, deque
from pathlib import Path
import torch
import torch.distributed as dist
from torch._six import inf
class SmoothedValue(object):
    """Running statistics over a stream of scalar values.

    A bounded deque keeps the most recent ``window_size`` values (backing
    ``median``, ``avg``, ``max`` and ``value``). A running count/total covers
    every value ever recorded (backing ``global_avg``). ``fmt`` is a
    ``str.format`` template over those five names, used by ``__str__``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.count = 0
        self.total = 0.0

    def update(self, value, n=1):
        """Record ``value``; it is weighted ``n`` times in the global average."""
        self.deque.append(value)
        self.total += value * n
        self.count += n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across all ranks.

        Only the global statistics are reduced; the windowed deque stays
        rank-local. Does nothing when torch.distributed is not initialised.
        """
        if is_dist_avail_and_initialized():
            packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(packed)
            reduced_count, reduced_total = packed.tolist()
            self.count = int(reduced_count)
            self.total = reduced_total

    @property
    def median(self):
        """Median of the current window (dtype inferred from the stored values)."""
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the current window, computed in float32."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every recorded value."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value in the current window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
class MetricLogger(object):
    """A collection of named ``SmoothedValue`` meters with periodic console logging."""

    def __init__(self, delimiter="\t"):
        # Meters are created lazily the first time a new name is updated.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Push one value per keyword into the meter of the same name.

        ``None`` values are skipped and tensors are converted with ``.item()``.
        Raises AssertionError if a value is not an int/float after conversion.
        """
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails: expose meters as attributes.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items())

    def synchronize_between_processes(self):
        """Reduce every meter's global statistics across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register a pre-configured meter (e.g. with a custom window or format)."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` and print progress every ``print_freq`` steps.

        Each progress line shows the step, ETA, all meters, per-iteration and
        data-loading time, and (when CUDA is available) the peak allocated GPU
        memory in MB. The last step is always logged. A total-time summary is
        printed when the iterable is exhausted. ``iterable`` must support ``len()``.
        """
        if not header:
            header = ''
        num_iters = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(num_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_iters - 1:
                eta_seconds = iter_time.global_avg * (num_iters - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1): an empty iterable (e.g. drop_last leaving no batches)
        # previously crashed here with ZeroDivisionError.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_iters, 1)))
def setup_for_distributed(is_master):
    """
    Replace ``builtins.print`` with a timestamped print that is silent outside
    the master process.

    A non-master process still prints when the call passes ``force=True`` or
    when ``get_world_size()`` exceeds 8.
    """
    original_print = builtins.print

    def _print_with_timestamp(*args, **kwargs):
        forced = kwargs.pop('force', False) or (get_world_size() > 8)
        if not (is_master or forced):
            return
        stamp = datetime.datetime.now().time()
        original_print('[{}] '.format(stamp), end='')  # print with time stamp
        original_print(*args, **kwargs)

    builtins.print = _print_with_timestamp
def is_dist_avail_and_initialized():
    """True only when torch.distributed is available and a process group is initialized."""
    return bool(dist.is_available() and dist.is_initialized())
def get_world_size():
    """World size of the default process group, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Rank of this process in the default group, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """True on rank 0, which is also the rank reported in a non-distributed run."""
    rank = get_rank()
    return rank == 0
def save_on_master(*args, **kwargs):
    """Forward to ``torch.save`` on the main process only; other ranks do nothing."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
def init_distributed_mode(args):
    """Initialise torch.distributed from the environment and record the result on ``args``.

    Sources are checked in this order:
      * ``args.dist_on_itp``: OMPI_COMM_WORLD_{RANK,SIZE,LOCAL_RANK} plus
        MASTER_ADDR/MASTER_PORT build ``args.dist_url``. RANK/WORLD_SIZE/LOCAL_RANK
        are then exported back into ``os.environ``.
      * RANK and WORLD_SIZE (with LOCAL_RANK) environment variables.
      * SLURM_PROCID. The world size comes from SLURM_NTASKS when present.
      * Otherwise the run is non-distributed: ``args.distributed = False``.

    In distributed mode it sets ``args.rank``, ``args.world_size``, ``args.gpu``,
    ``args.distributed`` and ``args.dist_backend`` ('nccl'). It initialises the
    NCCL process group at ``args.dist_url`` and then calls
    ``setup_for_distributed(args.rank == 0)`` so non-zero ranks stay quiet.
    """
    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
        # This branch used to leave args.world_size untouched, so
        # init_process_group below silently relied on whatever default the
        # caller had put on args. Take the real task count when SLURM exposes it.
        if 'SLURM_NTASKS' in os.environ:
            args.world_size = int(os.environ['SLURM_NTASKS'])
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return
    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
class NativeScalerWithGradNormCount:
    """Wrapper around ``torch.cuda.amp.GradScaler`` that also reports the gradient norm.

    ``state_dict``/``load_state_dict`` delegate to the wrapped scaler so it can be
    checkpointed together with the model and optimizer.
    """
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        """Backpropagate the scaled ``loss`` and, if ``update_grad``, step the optimizer.

        Args:
            loss: scalar loss tensor.
            optimizer: optimizer whose parameters receive the gradients.
            clip_grad: max gradient norm for clipping; ``None`` disables clipping.
            parameters: parameters to clip / measure. Required when ``clip_grad`` is set.
            create_graph: forwarded to ``backward``.
            update_grad: when False only the backward pass runs; the optimizer is
                not stepped and ``None`` is returned.

        Returns:
            The (unscaled) total gradient norm. Returns ``None`` when
            ``update_grad`` is False, or when no ``parameters`` were given and
            clipping is off (previously that case crashed inside ``get_grad_norm_``).
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            return None
        if clip_grad is not None:
            assert parameters is not None
            # unscale the gradients of optimizer's assigned params in-place so
            # clipping operates on true gradient magnitudes
            self._scaler.unscale_(optimizer)
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            self._scaler.unscale_(optimizer)
            norm = get_grad_norm_(parameters) if parameters is not None else None
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Total ``norm_type`` norm of the gradients of ``parameters``.

    Parameters without a gradient are ignored. If none have one, the result
    is ``tensor(0.)``. ``norm_type == inf`` gives the largest absolute
    gradient entry. Per-parameter norms are moved to the device of the first
    gradient before being combined.
    """
    params = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    with_grad = [p for p in params if p.grad is not None]
    norm_type = float(norm_type)
    if not with_grad:
        return torch.tensor(0.)
    target_device = with_grad[0].grad.device
    if norm_type == inf:
        return max(p.grad.detach().abs().max().to(target_device) for p in with_grad)
    per_param = [torch.norm(p.grad.detach(), norm_type).to(target_device) for p in with_grad]
    return torch.norm(torch.stack(per_param), norm_type)
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, exp_name=None):
    """Write a training checkpoint for ``epoch`` into ``args.output_dir``.

    With a ``loss_scaler``, a single ``.pth`` file holding the model,
    optimizer and scaler state dicts plus ``epoch`` and ``args`` is written
    via ``save_on_master``. Without one, saving is delegated to
    ``model.save_checkpoint`` with a tag-based name. This presumably expects
    a DeepSpeed-style engine; verify against the caller.
    File/tag names are prefixed with ``exp_name`` when it is given.
    """
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is None:
        client_state = {'epoch': epoch}
        if exp_name is None:
            tag = "checkpoint-%s" % epoch_name
        else:
            tag = "{}_checkpoint-{}".format(exp_name, epoch_name)
        model.save_checkpoint(save_dir=args.output_dir, tag=tag, client_state=client_state)
        return
    if exp_name is None:
        file_name = 'checkpoint-%s.pth' % epoch_name
    else:
        file_name = '{}_checkpoint-{}.pth'.format(exp_name, epoch_name)
    checkpoint = {
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(checkpoint, output_dir / file_name)
def load_model(args, model_without_ddp, optimizer, loss_scaler):
    """Resume from ``args.resume`` (local path or ``https`` URL) if it is set.

    Model weights are always loaded. The optimizer state and
    ``args.start_epoch`` (saved epoch + 1) are restored only when the
    checkpoint holds 'optimizer' and 'epoch' and ``args.eval`` is not truthy.
    The loss-scaler state is restored only when the checkpoint also has
    'scaler'.
    """
    if not args.resume:
        return
    if args.resume.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.resume, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(args.resume, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    print("Resume checkpoint %s" % args.resume)
    has_train_state = 'optimizer' in checkpoint and 'epoch' in checkpoint
    if not has_train_state or (hasattr(args, 'eval') and args.eval):
        return
    optimizer.load_state_dict(checkpoint['optimizer'])
    args.start_epoch = checkpoint['epoch'] + 1
    if 'scaler' in checkpoint:
        loss_scaler.load_state_dict(checkpoint['scaler'])
    print("With optim & sched!")
def all_reduce_mean(x):
    """Average the scalar ``x`` across all processes.

    With a single process, ``x`` is returned unchanged. Otherwise it is
    summed on CUDA with ``dist.all_reduce``, divided by the world size, and
    returned as a Python number.
    """
    world_size = get_world_size()
    if world_size <= 1:
        return x
    reduced = torch.tensor(x).cuda()
    dist.all_reduce(reduced)
    reduced /= world_size
    return reduced.item()
================================================
FILE: linear_util/pos_embed.py
================================================
import numpy as np
import torch
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    The optional cls_token row is all zeros and is prepended.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # np.meshgrid(w, h): the first plane varies along width, the second along height
    planes = np.stack(np.meshgrid(coords, coords), axis=0)
    planes = planes.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, planes)
    if not cls_token:
        return pos_embed
    return np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-plane coordinate grid into (H*W, embed_dim) sin-cos embeddings.

    The first embed_dim/2 columns encode grid[0] and the remaining columns
    encode grid[1]. ``embed_dim`` must be even.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    halves = [get_1d_sincos_pos_embed_from_grid(half_dim, plane) for plane in (grid[0], grid[1])]  # each (H*W, D/2)
    return np.concatenate(halves, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: positions to be encoded; a list or array of any shape, flattened to size (M,)
    out: (M, D) float64 array; the first D/2 columns are sin terms, the last D/2 cos terms
    """
    assert embed_dim % 2 == 0
    # `np.float` (an alias of the builtin float, i.e. float64) was removed in
    # NumPy 1.24; np.float64 is the same dtype and works on all versions.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)
    # np.asarray lets plain Python lists through, as the docstring promises
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model['pos_embed']`` in place to match ``model``'s patch grid.

    The leading extra tokens (class/dist tokens, counted as
    ``model.pos_embed.shape[-2] - num_patches``) are copied unchanged. The
    remaining square grid of position tokens is resized with bicubic
    interpolation. Nothing happens when the key is missing or the grid sizes
    already match.
    """
    if 'pos_embed' not in checkpoint_model:
        return
    ckpt_pos = checkpoint_model['pos_embed']
    dim = ckpt_pos.shape[-1]
    n_patches = model.patch_embed.num_patches
    n_extra = model.pos_embed.shape[-2] - n_patches
    # side lengths of the (assumed square) old and new position grids
    old_side = int((ckpt_pos.shape[-2] - n_extra) ** 0.5)
    new_side = int(n_patches ** 0.5)
    if old_side == new_side:
        return
    print("Position interpolate from %dx%d to %dx%d" % (old_side, old_side, new_side, new_side))
    extra = ckpt_pos[:, :n_extra]
    grid = ckpt_pos[:, n_extra:].reshape(-1, old_side, old_side, dim).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(
        grid, size=(new_side, new_side), mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra, grid), dim=1)
================================================
FILE: models/modeling_cae.py
================================================
import math
import time
import torch
import torch.nn as nn
from functools import partial
from models.modeling_finetune import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
from models.modeling_cae_helper import *
def trunc_normal_(tensor, mean=0., std=1.):
    """Fill ``tensor`` in place from N(mean, std**2), truncated to [-std, std].

    The truncation bounds are absolute values, not offsets from ``mean``.
    They equal one standard deviation only when ``mean == 0``.
    """
    lower, upper = -std, std
    __call_trunc_normal_(tensor, mean=mean, std=std, a=lower, b=upper)
class VisionTransformerForMaskedImageModeling(nn.Module):
    """CAE pre-training model: encoder, momentum teacher, latent regressor and decoder.

    ``self.encoder`` encodes the visible patches. ``self.teacher`` has the same
    architecture, is initialised from the encoder, is frozen, and is updated
    by momentum; it encodes the masked patches to produce alignment targets.
    ``self.pretext_neck`` predicts latents and logits for the masked patches.
    The hyper-parameters read from ``args`` are decoder_num_classes,
    decoder_embed_dim, regressor_depth, decoder_num_heads,
    decoder_layer_scale_init_value, fix_init_weight and base_momentum
    (and, through the neck, decoder_depth).
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, vocab_size=8192, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=None, init_values=None, attn_head_dim=None,
                 use_abs_pos_emb=True, init_std=0.02, args=None, **kwargs):
        super().__init__()
        self.encoder = VisionTransformerEncoder(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                                                vocab_size=vocab_size, embed_dim=embed_dim, depth=depth,
                                                num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                                drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
                                                norm_layer=norm_layer, init_values=init_values, attn_head_dim=attn_head_dim,
                                                use_abs_pos_emb=use_abs_pos_emb, init_std=init_std, args=args)
        # alignment constraint: teacher mirrors the encoder architecture
        self.teacher = VisionTransformerEncoder(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                                                vocab_size=vocab_size, embed_dim=embed_dim, depth=depth,
                                                num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                                drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
                                                norm_layer=norm_layer, init_values=init_values, attn_head_dim=attn_head_dim,
                                                use_abs_pos_emb=use_abs_pos_emb, init_std=init_std, args=args)
        self.init_std = init_std
        self.args = args
        self.num_patches = self.encoder.patch_embed.num_patches
        self.pretext_neck = VisionTransformerNeck(patch_size=patch_size, num_classes=args.decoder_num_classes, embed_dim=args.decoder_embed_dim, depth=args.regressor_depth,
                                                  num_heads=args.decoder_num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                                                  drop_path_rate=drop_path_rate, norm_layer=norm_layer, init_values=args.decoder_layer_scale_init_value, num_patches=self.num_patches, init_std=init_std, args=args)
        # encoder to decoder projection, borrowed from mae.
        if args.decoder_embed_dim != embed_dim:
            self.encoder_to_decoder = nn.Linear(embed_dim, args.decoder_embed_dim, bias=True)
            self.encoder_to_decoder_norm = norm_layer(args.decoder_embed_dim)
        else:
            self.encoder_to_decoder = None
        self.mask_token = nn.Parameter(torch.zeros(1, 1, args.decoder_embed_dim))
        trunc_normal_(self.mask_token, std=self.init_std)
        ### whether to use 'rescale' to init the weight, borrowed from beit.
        # Without fix_init_weight, every Linear/LayerNorm is re-initialised here
        # (xavier), overwriting the sub-modules' own init.
        if not args.fix_init_weight:
            self.apply(self._init_weights)
        self._init_teacher()

    def _init_teacher(self):
        """Copy the encoder's weights into the teacher and freeze the teacher."""
        # (the former `param_teacher.detach()` returned a discarded tensor: a no-op)
        for param_encoder, param_teacher in zip(self.encoder.parameters(), self.teacher.parameters()):
            param_teacher.data.copy_(param_encoder.data)
            param_teacher.requires_grad = False

    def momentum_update(self, base_momentum=0):
        """Momentum update of the teacher network.

        teacher <- base_momentum * teacher + (1 - base_momentum) * encoder.
        """
        for param_encoder, param_teacher in zip(self.encoder.parameters(),
                                                self.teacher.parameters()):
            param_teacher.data = param_teacher.data * base_momentum + \
                param_encoder.data * (1. - base_momentum)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, bool_masked_pos, return_all_tokens=None):
        """Forward pass of CAE pre-training.

        Input shape:
            x: [bs, 3, 224, 224]
            bool_masked_pos: [bs, num_patch * num_patch]; True marks masked patches.
            return_all_tokens: unused, kept for interface compatibility.

        Returns:
            logits: neck head output for the masked patches, flattened to
                [bs * num_masked, num_classes].
            latent_pred: regressor prediction for the masked patches.
            latent_target: teacher representation of the masked patches (no
                class token), projected to the decoder width when the widths differ.

        Side effect: momentum-updates the teacher with ``self.args.base_momentum``.
        """
        batch_size = x.size(0)
        # Encoder. Output shape: [bs, num_visible + 1, C]
        x_unmasked = self.encoder(x, bool_masked_pos=bool_masked_pos)
        # encoder to decoder projection
        if self.encoder_to_decoder is not None:
            x_unmasked = self.encoder_to_decoder(x_unmasked)
            x_unmasked = self.encoder_to_decoder_norm(x_unmasked)
        # Alignment constraint: the teacher "sees" exactly the masked patches.
        with torch.no_grad():
            latent_target = self.teacher(x, bool_masked_pos=(~bool_masked_pos))
            latent_target = latent_target[:, 1:, :]  # remove class token
            if self.encoder_to_decoder is not None:
                latent_target = self.encoder_to_decoder_norm(self.encoder_to_decoder(latent_target.detach()))
            self.momentum_update(self.args.base_momentum)
        # Latent contextual regressor and decoder
        b, num_visible_plus1, dim = x_unmasked.shape
        # remove class token
        x_unmasked = x_unmasked[:, 1:, :]
        num_masked_patches = self.num_patches - (num_visible_plus1 - 1)
        # generate position embeddings; `.to` (instead of `.cuda(device)`) also
        # works when the input lives on the CPU
        pos_embed = self.encoder.build_2d_sincos_position_embedding(dim, use_cls_token=True).expand(batch_size, self.num_patches + 1, dim).to(x_unmasked.device)
        # pos embed for masked patches
        pos_embed_masked = pos_embed[:, 1:][bool_masked_pos].reshape(batch_size, -1, dim)
        # pos embed for unmasked patches
        pos_embed_unmasked = pos_embed[:, 1:][~bool_masked_pos].reshape(batch_size, -1, dim)
        # masked embedding
        x_masked = self.mask_token.expand(batch_size, num_masked_patches, -1)
        logits, latent_pred = self.pretext_neck(x_masked, x_unmasked, pos_embed_masked, pos_embed_unmasked, bool_masked_pos)
        logits = logits.view(-1, logits.shape[2])
        return logits, latent_pred, latent_target
def _build_cae(embed_dim, depth, num_heads, pretrained=False, **kwargs):
    """Shared constructor for the registered CAE variants (patch 16, 8k vocab).

    When ``pretrained`` is true, the state dict stored under ``"model"`` in
    ``kwargs["init_ckpt"]`` is loaded (on CPU). ``kwargs`` is also forwarded to
    the model constructor, as before.
    """
    model = VisionTransformerForMaskedImageModeling(
        patch_size=16, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.load(
            kwargs["init_ckpt"], map_location="cpu"
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def cae_small_patch16_224_8k_vocab(pretrained=False, **kwargs):
    """CAE-Small: embed_dim 384, depth 12, 12 heads."""
    return _build_cae(384, 12, 12, pretrained=pretrained, **kwargs)


@register_model
def cae_base_patch16_224_8k_vocab(pretrained=False, **kwargs):
    """CAE-Base: embed_dim 768, depth 12, 12 heads."""
    return _build_cae(768, 12, 12, pretrained=pretrained, **kwargs)


@register_model
def cae_large_patch16_224_8k_vocab(pretrained=False, **kwargs):
    """CAE-Large: embed_dim 1024, depth 24, 16 heads."""
    return _build_cae(1024, 24, 16, pretrained=pretrained, **kwargs)
================================================
FILE: models/modeling_cae_helper.py
================================================
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from models.modeling_finetune import PatchEmbed, DropPath, Mlp
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
def trunc_normal_(tensor, mean=0., std=1.):
    """Fill ``tensor`` in place from N(mean, std**2), truncated to [-std, std].

    The truncation bounds are absolute values, not offsets from ``mean``.
    They equal one standard deviation only when ``mean == 0``.
    """
    lower, upper = -std, std
    __call_trunc_normal_(tensor, mean=mean, std=std, a=lower, b=upper)
class Attention(nn.Module):
    """Multi-head self-attention with optional learnable query/value biases.

    A single bias-free Linear produces q, k and v. When ``qkv_bias`` is set,
    the fused bias is ``[q_bias, zeros, v_bias]``, so keys never get a bias.
    ``window_size`` and the ``bool_masked_pos`` argument of ``forward`` are
    accepted but not used.
    """
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        if attn_head_dim is not None:
            per_head = attn_head_dim
        inner_dim = per_head * self.num_heads
        self.scale = qk_scale or per_head ** -0.5
        self.qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.q_bias = nn.Parameter(torch.zeros(inner_dim)) if qkv_bias else None
        self.v_bias = nn.Parameter(torch.zeros(inner_dim)) if qkv_bias else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, bool_masked_pos=None):
        batch, tokens, _ = x.shape
        if self.q_bias is None:
            fused_bias = None
        else:
            key_zero = torch.zeros_like(self.v_bias, requires_grad=False)
            fused_bias = torch.cat((self.q_bias, key_zero, self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=fused_bias)
        # (B, N, 3*H*d) -> (3, B, H, N, d)
        qkv = qkv.reshape(batch, tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        scores = (q * self.scale) @ k.transpose(-2, -1)  # (B, N_head, N, N)
        weights = self.attn_drop(scores.softmax(dim=-1))
        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, -1)
        return self.proj_drop(self.proj(out))
'''
Modified from Attention()
'''
class CrossAttention(nn.Module):
    """Multi-head attention with queries from ``x`` and keys/values from ``k``/``v``.

    q, k and v use separate bias-free Linear layers. When ``qkv_bias`` is set,
    learnable biases are added to queries and values only. ``self.k_bias`` is
    always None; previously it existed only when ``qkv_bias`` was False.
    ``window_size`` and ``bool_masked_pos`` are accepted but not used.
    """
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None
        # Keys never receive a bias (the original passed an all-zero tensor,
        # which is equivalent); define the attribute in both configurations.
        self.k_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _split_heads(self, t):
        """(B, N, H*d) -> (B, H, N, d), using t's own batch and sequence sizes."""
        return t.reshape(t.shape[0], t.shape[1], self.num_heads, -1).transpose(1, 2)

    def forward(self, x, bool_masked_pos=None, k=None, v=None):
        """Attend from ``x`` (B, N_q, C) to ``k``/``v`` (B or 1, N_k, C); returns (B, N_q, C).

        ``k`` and ``v`` are reshaped with their own batch size, so a batch of 1
        broadcasts against the queries instead of failing in ``reshape``.
        """
        q = self._split_heads(F.linear(input=x, weight=self.q.weight, bias=self.q_bias))  # (B, N_head, N_q, d)
        k = self._split_heads(F.linear(input=k, weight=self.k.weight, bias=self.k_bias))
        v = self._split_heads(F.linear(input=v, weight=self.v.weight, bias=self.v_bias))
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, N_head, N_q, N_k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = (attn @ v).transpose(1, 2).flatten(2)  # (B, N_q, N_head * d)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out
class Block(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each residual.

    x = x + drop_path(gamma_1 * attn(norm1(x)))
    x = x + drop_path(gamma_2 * mlp(norm2(x)))
    The layer-scale vectors gamma_1/gamma_2 (initialised to ``init_values``)
    exist only when ``init_values`` is a positive number.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        # `init_values > 0` alone raised TypeError for the default init_values=None
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, bool_masked_pos=None):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x), bool_masked_pos))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), bool_masked_pos))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
class RegressorBlock(nn.Module):
    """Cross-attention block of the latent contextual regressor.

    Queries are ``x_q + pos_q``, keys are ``x_kv + pos_k``, and values are
    ``x_kv`` without a position embedding; each goes through its own
    LayerNorm. After the residual cross-attention, the stream is normalised
    (``norm2_cross``) and then passed through a residual MLP. When
    ``init_values`` is not a positive number, the layer-scale vectors are
    frozen ones, so the block is unscaled.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1_q = norm_layer(dim)
        self.norm1_k = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.norm2_cross = norm_layer(dim)
        self.cross_attn = CrossAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_cross = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        # `init_values > 0` alone raised TypeError for the default init_values=None
        if init_values is not None and init_values > 0:
            self.gamma_1_cross = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2_cross = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1_cross = nn.Parameter(torch.ones((dim)), requires_grad=False)
            self.gamma_2_cross = nn.Parameter(torch.ones((dim)), requires_grad=False)

    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos):
        x = x_q + self.drop_path(self.gamma_1_cross * self.cross_attn(self.norm1_q(x_q + pos_q),
                                                                      bool_masked_pos, k=self.norm1_k(x_kv + pos_k), v=self.norm1_v(x_kv)))
        x = self.norm2_cross(x)
        x = x + self.drop_path(self.gamma_2_cross * self.mlp_cross(x))
        return x
'''
Encoder that extracts representations
'''
class VisionTransformerEncoder(nn.Module):
    """ViT encoder that runs only on the visible (unmasked) patches plus a class token.

    The position embedding is a fixed 2D sin-cos table built by
    ``build_2d_sincos_position_embedding`` (an nn.Parameter with
    ``requires_grad = False``).
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, vocab_size=8192, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=None, init_values=None, attn_head_dim=None,
                 use_abs_pos_emb=True, init_std=0.02, args=None, **kwargs):
        # NOTE(review): vocab_size, use_abs_pos_emb, args and **kwargs are accepted
        # but not read by this constructor.
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches
        # generate class token and pos embed
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim, use_cls_token=True)
        self.pos_drop = nn.Dropout(p=drop_rate)
        # stochastic depth rate grows linearly from 0 to drop_path_rate over the blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=None,
                attn_head_dim=attn_head_dim,
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.init_std = init_std
        # init the model
        trunc_normal_(self.cls_token, std=self.init_std)
        self.apply(self._init_weights)
        # rescale init function from beit
        # if it is not activated, it will be overwritten
        # (the owning model re-applies its own _init_weights when args.fix_init_weight is off)
        self.fix_init_weight()

    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., use_cls_token=False):
        """Return a frozen nn.Parameter of 2D sin-cos position embeddings.

        Shape is (1, h*w, embed_dim), or (1, 1 + h*w, embed_dim) with an all-zero
        leading row when ``use_cls_token`` is set; (h, w) is ``patch_embed.patch_shape``.
        Each quarter of the channels holds sin(w), cos(w), sin(h), cos(h) at
        frequencies 1 / temperature**(i / (embed_dim // 4)).

        Raises:
            AssertionError: if ``embed_dim`` is not divisible by 4.
        """
        h, w = self.patch_embed.patch_shape
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        # NOTE(review): torch.meshgrid's default indexing is used; which axis each
        # plane varies along only matters for non-square grids — confirm if needed.
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        if not use_cls_token:
            pos_embed = nn.Parameter(pos_emb)
        else:
            pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
            pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        pos_embed.requires_grad = False
        return pos_embed

    def fix_init_weight(self):
        """Divide each block's attn.proj and mlp.fc2 weights by sqrt(2 * (block index + 1))."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))
        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        # Linear/Conv2d: truncated normal (std=init_std), zero bias; LayerNorm: weight 1, bias 0.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_num_layers(self):
        return len(self.blocks)

    def forward_features(self, x, bool_masked_pos):
        """Encode only the patches where ``bool_masked_pos`` is False.

        Returns the normed tokens [class token, visible patches...], shape
        (batch, 1 + num_visible, dim). ``reshape(batch_size, -1, dim)`` assumes
        every sample in the batch has the same number of masked patches.
        """
        x = self.patch_embed(x, bool_masked_pos=bool_masked_pos)
        batch_size, seq_len, dim = x.size()
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # unmasked embeddings
        x_unmasked = x[~bool_masked_pos].reshape(batch_size, -1, dim)
        x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1)
        if self.pos_embed is not None:
            # select the position embeddings of the visible patches, keeping the cls slot first
            pos_embed = self.pos_embed.expand(batch_size, self.num_patches+1, dim)
            pos_embed_unmasked = pos_embed[:,1:][~bool_masked_pos].reshape(batch_size, -1, dim)
            pos_embed_unmasked = torch.cat((pos_embed[:,:1], pos_embed_unmasked),dim=1)
            x_unmasked = x_unmasked + pos_embed_unmasked
        x_unmasked = self.pos_drop(x_unmasked)
        for blk in self.blocks:
            x_unmasked = blk(x_unmasked, bool_masked_pos)
        x_unmasked = self.norm(x_unmasked)
        return x_unmasked

    def forward(self, x, bool_masked_pos, return_all_tokens=False):
        # return_all_tokens is accepted but not used
        x = self.forward_features(x, bool_masked_pos=bool_masked_pos)
        return x
'''
Latent context regressor + decoder that solves the pretext task.
'''
class VisionTransformerNeck(nn.Module):
    """Latent contextual regressor followed by the decoder for the pretext task.

    ``regressor_blocks`` (``depth`` RegressorBlocks) let the masked-patch
    queries cross-attend to [visible, masked] tokens and produce
    ``latent_pred``. ``decoder_blocks`` (``args.decoder_depth`` Blocks) then run
    self-attention over the masked-patch predictions only, and ``head`` maps
    them to ``num_classes`` logits.
    """
    def __init__(self, patch_size=16, num_classes=8192, embed_dim=768, depth=6,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=None, init_values=None, num_patches=196, init_std=0.02, args=None, patch_shape=(14,14)):
        # NOTE(review): num_patches and patch_shape are accepted but not read here.
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.args = args
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        # context regressor
        self.regressor_blocks = nn.ModuleList([
            RegressorBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values)
            for i in range(depth)])
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, args.decoder_depth)]
        self.decoder_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values)
            for i in range(args.decoder_depth)])
        self.norm = norm_layer(embed_dim)
        self.norm2 = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.init_std = init_std
        # init the model
        # nn.Identity (num_classes <= 0) has no weight; initialising it raised AttributeError
        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=self.init_std)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        """Divide each block's output projections by sqrt(2 * (block index + 1))."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))
        for layer_id, layer in enumerate(self.regressor_blocks):
            rescale(layer.cross_attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp_cross.fc2.weight.data, layer_id + 1)
        for layer_id, layer in enumerate(self.decoder_blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # NOTE(review): this module defines no 'pos_embed'/'cls_token' parameters.
        return {'pos_embed', 'cls_token'}

    def forward(self, x_masked, x_unmasked, pos_embed_masked, pos_embed_unmasked, bool_masked_pos):
        """Return ``(logits, latent_pred)`` for the masked patches.

        ``latent_pred`` is the normed regressor output. ``logits`` is the head
        output after the decoder blocks.
        """
        # latent contextual regressor: masked queries attend to [visible, masked] tokens
        for blk in self.regressor_blocks:
            x_masked = blk(x_masked, torch.cat([x_unmasked, x_masked], dim=1), pos_embed_masked, torch.cat([pos_embed_unmasked, pos_embed_masked], dim=1), bool_masked_pos)
        x_masked = self.norm(x_masked)
        latent_pred = x_masked
        # decoder
        x_masked = x_masked + pos_embed_masked  # add pos embed, like encoder
        for blk in self.decoder_blocks:
            x_masked = blk(x_masked)
        x_masked = self.norm2(x_masked)
        logits = self.head(x_masked)
        return logits, latent_pred
================================================
FILE: models/modeling_discrete_vae.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on OpenAI DALL-E and lucidrains' DALLE-pytorch code bases
# https://github.com/openai/DALL-E
# https://github.com/lucidrains/DALLE-pytorch
# --------------------------------------------------------
from math import sqrt
import os
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
def top_k(logits, thres = 0.5):
    """Keep the largest ``max(int((1 - thres) * C), 1)`` logits per row; set the rest to -inf.

    ``C`` is the size of the last dimension. The scatter along dim 1 assumes
    2-D ``(batch, C)`` logits.
    """
    keep = max(int((1 - thres) * logits.shape[-1]), 1)
    values, indices = logits.topk(keep)
    filtered = torch.full_like(logits, float('-inf'))
    return filtered.scatter(1, indices, values)
def exists(val):
    """True for any value other than None (falsy values such as 0 or '' count as existing)."""
    return not (val is None)
def default(val, d):
    """Return ``val`` unless it is None, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d
def eval_decorator(fn):
    """Run ``fn(model, ...)`` with ``model`` in eval mode, then restore its training flag.

    The previous ``model.training`` state is restored even if ``fn`` raises
    (the old version left the model stuck in eval mode on error). The wrapper
    keeps ``fn``'s name and docstring via ``functools.wraps``.
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            model.train(was_training)
    return inner
class BasicVAE(nn.Module):
    """Minimal interface for the image tokenizers in this module.

    Subclasses override what they support. The base implementations raise
    NotImplementedError for the codebook/decoding methods and report None
    for the size queries.
    """

    def get_codebook_indices(self, images):
        """Map images to discrete codebook indices (must be overridden)."""
        raise NotImplementedError()

    def decode(self, img_seq):
        """Reconstruct images from codebook indices (must be overridden)."""
        raise NotImplementedError()

    def get_codebook_probs(self, img_seq):
        """Return codebook probabilities (must be overridden)."""
        raise NotImplementedError()

    def get_image_tokens_size(self):
        """Token-grid size; the base class reports None."""
        return None

    def get_image_size(self):
        """Expected input image size; the base class reports None."""
        return None
class ResBlock(nn.Module):
    """Residual block: x + conv1x1(relu(conv3x3(relu(conv3x3(x))))), channel count unchanged."""
    def __init__(self, chan):
        super().__init__()
        layers = [
            nn.Conv2d(chan, chan, 3, padding=1), nn.ReLU(),
            nn.Conv2d(chan, chan, 3, padding=1), nn.ReLU(),
            nn.Conv2d(chan, chan, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.net(x)
        return residual + x
class DiscreteVAE(BasicVAE):
    """Convolutional discrete VAE image tokenizer.

    The encoder applies ``num_layers`` stride-2 4x4 convolutions (each halves
    the spatial size), optional residual blocks, and a 1x1 conv to
    ``num_tokens`` logits per position. The decoder maps (soft) one-hot codes,
    embedded through ``codebook``, back to ``channels`` image channels.
    """
    def __init__(
        self,
        image_size = 256,
        num_tokens = 512,
        codebook_dim = 512,
        num_layers = 3,
        num_resnet_blocks = 2,
        hidden_dim = 64,
        channels = 3,
        smooth_l1_loss = False,
        temperature = 0.9,
        straight_through = False,
        kl_div_loss_weight = 0.,
    ):
        super().__init__()
        # assert log2(image_size).is_integer(), 'image size must be a power of 2'
        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
        has_resblocks = num_resnet_blocks > 0
        self.image_size = image_size
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.temperature = temperature
        self.straight_through = straight_through
        self.codebook = nn.Embedding(num_tokens, codebook_dim)
        enc_chans = [hidden_dim] * num_layers
        dec_chans = list(reversed(enc_chans))
        enc_chans = [channels, *enc_chans]
        # without resblocks the first transposed conv consumes codebook vectors directly
        dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
        dec_chans = [dec_init_chan, *dec_chans]
        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
        enc_layers = []
        dec_layers = []
        for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
            enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
            dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))
        for _ in range(num_resnet_blocks):
            dec_layers.insert(0, ResBlock(dec_chans[1]))
            enc_layers.append(ResBlock(enc_chans[-1]))
        if num_resnet_blocks > 0:
            # project codebook vectors to the decoder width in front of the resblocks
            dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))
        enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))
        dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))
        self.encoder = nn.Sequential(*enc_layers)
        self.decoder = nn.Sequential(*dec_layers)
        self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
        # NOTE(review): stored for API compatibility; forward() does not read it
        self.kl_div_loss_weight = kl_div_loss_weight

    def get_image_size(self):
        return self.image_size

    def get_image_tokens_size(self):
        """Side length of the token grid.

        Each encoder layer halves the resolution, so the factor is
        2 ** num_layers. The old hard-coded ``// 8`` was only right for
        num_layers == 3 (the default).
        """
        return self.image_size // (2 ** self.num_layers)

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, images):
        """Argmax code index per position, flattened to (batch, h*w)."""
        logits = self.forward(images, return_logits = True)
        codebook_indices = logits.argmax(dim = 1).flatten(1)
        return codebook_indices

    @torch.no_grad()
    @eval_decorator
    def get_codebook_probs(self, images, temp):
        """Softmax over the code dimension of logits / temp."""
        logits = self.forward(images, return_logits = True)
        return nn.Softmax(dim=1)(logits / temp)

    def decode(
        self,
        img_seq
    ):
        """Decode code indices (batch, n) into images; n must be a square number."""
        image_embeds = self.codebook(img_seq)
        b, n, d = image_embeds.shape
        h = w = int(sqrt(n))
        image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w)
        images = self.decoder(image_embeds)
        return images

    def forward(
        self,
        img,
        return_loss = False,
        return_recons = False,
        return_logits = False,
        temp = None
    ):
        """Encode ``img`` and, depending on the flags, decode it and compute losses.

        Returns:
            * ``return_logits``: encoder logits (batch, num_tokens, h, w).
            * otherwise, without ``return_loss``: the reconstruction, decoded
              from a Gumbel-softmax sample of the logits (temperature ``temp``,
              default ``self.temperature``).
            * with ``return_loss``: ``(recon_loss, diversity_loss)``, plus the
              reconstruction when ``return_recons`` is set.
        """
        image_size = self.image_size
        assert img.shape[-1] == image_size and img.shape[-2] == image_size, f'input must have the correct image size {image_size}'
        logits = self.encoder(img)
        if return_logits:
            return logits # return logits for getting hard image indices for DALL-E training
        temp = default(temp, self.temperature)
        soft_one_hot = F.gumbel_softmax(logits.float(), tau = temp, dim = 1, hard = self.straight_through)
        sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight).type_as(logits)
        out = self.decoder(sampled)
        if not return_loss:
            return out
        # reconstruction loss
        recon_loss = self.loss_fn(img, out)
        # codebook-usage term: sum(p * log p) of the code distribution averaged
        # over all positions in the batch (i.e. its negative entropy)
        logits = rearrange(logits, 'b n h w -> b (h w) n')
        _C = logits.size(-1)
        avg_probs = F.softmax(logits.contiguous().view(-1, _C), dim=-1, dtype=torch.float32).mean(0)
        diversity_loss = torch.sum(avg_probs * torch.log(avg_probs + 1e-6), dim=-1).mean()
        if not return_recons:
            return recon_loss, diversity_loss
        return recon_loss, diversity_loss, out
from dall_e import load_model
class Dalle_VAE(BasicVAE):
    """Adapter around the pretrained DALL-E dVAE (``encoder.pkl`` / ``decoder.pkl``).

    Token grids are handled as ``(image_size // 8) x (image_size // 8)``, matching
    the reshapes in ``decode`` and ``forward``. ``encoder`` and ``decoder`` start
    as ``None``; call ``load_model`` before using any other method.
    """

    def __init__(self, image_size):
        super().__init__()
        self.encoder = None
        self.decoder = None
        self.image_size = image_size

    def load_model(self, model_dir, device):
        """Load the encoder and decoder files from ``model_dir`` onto ``device``."""
        # NOTE(review): if dall_e.load_model unpickles these files, only use trusted weights.
        encoder_file = os.path.join(model_dir, "encoder.pkl")
        decoder_file = os.path.join(model_dir, "decoder.pkl")
        self.encoder = load_model(encoder_file, device)
        self.decoder = load_model(decoder_file, device)

    def decode(self, img_seq):
        """Decode a batch of token ids (one id per grid cell) into images."""
        side = self.image_size // 8
        token_grid = img_seq.view(img_seq.size(0), side, side)
        one_hot = F.one_hot(token_grid, num_classes=self.encoder.vocab_size)
        return self.decoder(one_hot.permute(0, 3, 1, 2).float()).float()

    def get_codebook_indices(self, images):
        """Return the argmax token id over the encoder's class dimension (dim 1)."""
        return self.encoder(images).argmax(dim=1)

    def get_codebook_probs(self, images):
        """Return the softmax over the encoder's class dimension (dim 1)."""
        return F.softmax(self.encoder(images), dim=1)

    def forward(self, img_seq_prob, no_process=False):
        """Decode token probabilities into images.

        With ``no_process`` the input is passed to the decoder as-is; otherwise a
        ``(B, seq_len, vocab)`` tensor is reshaped into a channels-first grid first.
        """
        if no_process:
            return self.decoder(img_seq_prob.float()).float()
        bsz, _seq_len, _num_class = img_seq_prob.size()
        side = self.image_size // 8
        grid = img_seq_prob.view(bsz, side, side, self.encoder.vocab_size)
        return self.decoder(grid.permute(0, 3, 1, 2).float()).float()
class VGGAN(BasicVAE):
    """Adapter exposing a pretrained VQGAN through the ``BasicVAE`` interface.

    Only ``get_codebook_indices`` is provided. ``encoder``/``decoder`` stay ``None``,
    and ``self.vqgan`` exists only after ``load_model`` has been called.
    """

    def __init__(self, image_size):
        super().__init__()
        self.encoder = None
        self.decoder = None
        self.image_size = image_size

    def load_model(self, weight_path, device):
        """Load the pickled VQGAN object from ``weight_path`` onto ``device``."""
        # torch.load unpickles arbitrary objects: only load trusted checkpoints.
        loaded = torch.load(weight_path, map_location=device)
        self.vqgan = loaded

    def get_codebook_indices(self, images):
        """Return the code indices reported by ``self.vqgan.encode``."""
        # encode() must return a 3-tuple whose last item is a 3-sequence ending in
        # the indices (shape [b, h//8, w//8] according to the original note).
        _first, _second, (_, _, indices) = self.vqgan.encode(images)
        return indices
================================================
FILE: models/modeling_finetune.py
================================================
import math
import numpy as np
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from furnace.utils import LP_BatchNorm
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from timm.models.registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin module wrapper around timm's ``drop_path``; ``self.training`` is
    forwarded so the function can tell training from evaluation.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f'p={self.drop_prob}'
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> Linear -> Dropout.

    Args:
        in_features: input width.
        hidden_features: hidden width (defaults to ``in_features`` when falsy).
        out_features: output width (defaults to ``in_features`` when falsy).
        act_layer: activation class, instantiated without arguments.
        drop: dropout probability applied after the second linear layer only.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_hidden = hidden_features or in_features
        width_out = out_features or in_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # No dropout between the activation and fc2: it was intentionally disabled
        # to follow the original BERT implementation.
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with optional q/v biases and an optional relative position bias.

    Args:
        dim: width of the input and output tokens.
        num_heads: number of attention heads.
        qkv_bias: if True, learn ``q_bias`` and ``v_bias``; keys always get a zero bias.
        qk_scale: overrides the query scale (default ``head_dim ** -0.5``).
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        window_size: ``(Wh, Ww)`` patch grid. When truthy, this layer owns a
            relative-position-bias table for ``Wh * Ww`` patch tokens plus one
            extra token at index 0 (the cls token).
        attn_head_dim: per-head width; overrides ``dim // num_heads``.
    """
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # One fused projection yields q, k and v. Its bias (if any) is assembled in
        # forward() as [q_bias, zeros, v_bias].
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            # (2*Wh-1)*(2*Ww-1) patch-to-patch offsets plus 3 extra entries.
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            # Scale the row offset so each (row, col) offset pair maps to a unique flat index.
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            # Row 0 / column 0 belong to the extra (cls) token and use the 3 trailing
            # table entries. These assignments overlap at [0, 0], so their order matters.
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        """Apply self-attention to ``x``.

        Args:
            x: tokens unpacked as ``(B, N, C)``. When this layer owns a bias table,
                N must equal ``Wh * Ww + 1`` for the bias ``view`` to succeed.
            rel_pos_bias: optional tensor added to the attention logits
                ``(B, nH, N, N)``; presumably the ``(nH, N, N)`` output of a shared
                ``RelativePositionBias`` — TODO confirm with callers.

        Returns:
            Tensor of shape ``(B, N, dim)``.
        """
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # Keys get a constant zero bias that is not trained.
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from ``x``, keys/values from ``k``/``v``.

    Args:
        dim: width of the query, key and value tokens (and of the output).
        num_heads: number of attention heads.
        qkv_bias: if True, learn ``q_bias`` and ``v_bias``. Keys always use a zero
            bias, so ``k_bias`` is ``None`` in every configuration.
        qk_scale: overrides the query scale (default ``head_dim ** -0.5``).
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        window_size: accepted for signature parity with ``Attention``; unused.
        attn_head_dim: per-head width; overrides ``dim // num_heads``.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        # Keys never have a learned bias (forward uses zeros), so k_bias is defined
        # as None in both branches instead of only when qkv_bias is False.
        self.k_bias = None
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, bool_masked_pos=None, k=None, v=None):
        """Attend from ``x`` to ``k``/``v``.

        Args:
            x: query tokens unpacked as ``(B, N, C)``.
            bool_masked_pos: unused; kept for interface compatibility.
            k: key tokens ``(B, N_k, C)``; required.
            v: value tokens ``(B, N_v, C)`` with ``N_v == N_k``; required.

        Returns:
            Tensor of shape ``(B, N, dim)``.

        Raises:
            ValueError: if ``k`` or ``v`` is missing.
        """
        if k is None or v is None:
            raise ValueError("CrossAttention.forward requires both k and v")

        B, N, _ = x.shape
        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            v_bias = self.v_bias

        # Project and split heads: (B, L, H * D) -> (B, H, L, D).
        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
        q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
        k = k.reshape(B, N_k, self.num_heads, -1).permute(0, 2, 1, 3)
        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
        v = v.reshape(B, N_v, self.num_heads, -1).permute(0, 2, 1, 3)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, H, N, N_k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))`` then ``x + mlp(norm2(x))``.

    When ``init_values`` is a positive number, each residual branch is scaled by a
    learnable per-channel vector (``gamma_1`` / ``gamma_2``) initialised to
    ``init_values``. ``None`` or ``0`` disables this scaling.

    Args:
        dim: token width.
        num_heads: attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop, window_size, attn_head_dim: forwarded to ``Attention``.
        drop: dropout used for the attention projection and the MLP.
        drop_path: stochastic-depth rate (``DropPath`` only when > 0).
        init_values: layer-scale initial value, or ``None`` to disable.
        act_layer: MLP activation class.
        norm_layer: normalisation class used for ``norm1`` / ``norm2``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # The default init_values=None means "no layer scale"; comparing None > 0
        # would raise TypeError on Python 3, so check for None explicitly.
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None):
        """Run the block on ``x``; ``rel_pos_bias`` is passed through to the attention."""
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
class AttentiveBlock(nn.Module):
    """Cross-attention pooling block (used by ``VisionTransformer`` for attentive probing).

    The query (``x_q + pos_q``) and key (``x_kv + pos_k``) inputs are each
    layer-normed. The value input ``x_kv`` is layer-normed without positions.
    The ``CrossAttention`` output is returned directly: no residual and no MLP.

    ``mlp_ratio``, ``init_values`` and ``act_layer`` are accepted but unused.
    ``norm2_cross`` and ``drop_path`` are built but never applied in ``forward``,
    presumably to keep parameter names checkpoint-compatible — TODO confirm.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1_q, self.norm1_k, self.norm1_v = norm_layer(dim), norm_layer(dim), norm_layer(dim)
        self.norm2_cross = norm_layer(dim)
        self.cross_attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
        )
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
        query = self.norm1_q(x_q + pos_q)
        key = self.norm1_k(x_kv + pos_k)
        value = self.norm1_v(x_kv)
        return self.cross_attn(query, k=key, v=value)
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    A ``Conv2d`` whose kernel and stride both equal ``patch_size`` turns each
    non-overlapping patch into one ``embed_dim``-wide token. ``forward`` returns
    ``(B, num_patches, embed_dim)`` and requires the input height/width to match
    ``img_size`` exactly.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]
        self.patch_shape = (grid_h, grid_w)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_w * grid_h

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        _, _, height, width = x.shape
        expected_h, expected_w = self.img_size[0], self.img_size[1]
        # FIXME look at relaxing size constraints
        assert height == expected_h and width == expected_w, \
            f"Input image size ({height}*{width}) doesn't match model ({expected_h}*{expected_w})."
        patches = self.proj(x).flatten(2)  # (B, embed_dim, grid_h * grid_w)
        return patches.transpose(1, 2)
class RelativePositionBias(nn.Module):
    """Shared relative position bias for a ``(Wh, Ww)`` patch grid plus one leading extra token.

    The learnable table has ``(2*Wh-1)*(2*Ww-1)`` rows for patch-to-patch offsets
    plus 3 rows for the extra (cls) token: cls to token, token to cls, cls to cls.
    Calling the module returns the gathered bias of shape ``(nH, Wh*Ww+1, Wh*Ww+1)``.
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        win_h, win_w = window_size[0], window_size[1]
        span_h, span_w = 2 * win_h - 1, 2 * win_w - 1
        self.num_relative_distance = span_h * span_w + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # Pairwise (row, col) offsets between every two patch positions.
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        offsets[:, :, 0] += win_h - 1  # shift to start from 0
        offsets[:, :, 1] += win_w - 1
        offsets[:, :, 0] *= span_w  # unique flat index per (row, col) offset

        num_tokens = win_h * win_w + 1
        index = torch.zeros(size=(num_tokens, num_tokens), dtype=offsets.dtype)
        index[1:, 1:] = offsets.sum(-1)
        # Extra-token rows/cols; the last write deliberately overrides [0, 0].
        index[0, 0:] = self.num_relative_distance - 3
        index[0:, 0] = self.num_relative_distance - 2
        index[0, 0] = self.num_relative_distance - 1
        self.register_buffer("relative_position_index", index)

    def forward(self):
        num_tokens = self.window_size[0] * self.window_size[1] + 1
        gathered = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        gathered = gathered.view(num_tokens, num_tokens, -1)  # Wh*Ww+1, Wh*Ww+1, nH
        return gathered.permute(2, 0, 1).contiguous()  # nH, Wh*Ww+1, Wh*Ww+1
def get_sinusoid_encoding_table(n_position, d_hid, token=False):
    ''' Sinusoid position encoding table

    Entry ``[pos, j]`` is ``sin(pos / 10000 ** (2 * (j // 2) / d_hid))`` for even
    ``j`` and the matching ``cos`` for odd ``j``.

    Args:
        n_position: number of positions (rows); 0 yields an empty table.
        d_hid: encoding width (columns).
        token: if True, append one extra all-zero row.

    Returns:
        ``torch.FloatTensor`` of shape ``(1, n_position + int(token), d_hid)``.
    '''
    positions = np.arange(n_position, dtype=np.float64)[:, None]
    dims = np.arange(d_hid)
    denominators = np.power(10000, 2 * (dims // 2) / d_hid)
    sinusoid_table = positions / denominators[None, :]
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    if token:
        # numpy's concatenation keyword is ``axis`` (``dim`` is a TypeError).
        sinusoid_table = np.concatenate([sinusoid_table, np.zeros([1, d_hid])], axis=0)
    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    CAE backbone for fine-tuning and for linear / attentive probing.

    Position embedding:
        * ``use_abs_pos_emb=True``: a ``(1, num_patches + 1, embed_dim)`` parameter
          initialised with ``trunc_normal_``;
        * otherwise, if ``args.sin_pos_emb`` is truthy: a fixed 2-D sin-cos
          embedding with ``requires_grad=False``;
        * otherwise none. ``args=None`` (the default) counts as "no sin-cos".

    Output features (``forward_features``):
        * ``lin_probe`` + ``linear_type='standard'``: the cls token;
        * ``lin_probe`` + ``linear_type='attentive'``: a learned query token
          cross-attends to all tokens, followed by ``LP_BatchNorm``;
        * fine-tuning: ``fc_norm`` of the mean patch token when
          ``use_mean_pooling``, else the normed cls token.

    Selected args:
        init_values: layer-scale init forwarded to every ``Block`` (``None`` disables it).
        init_scale: factor applied to the classifier weight/bias after init
            (only when ``num_classes > 0``, i.e. the head is an ``nn.Linear``).
        linear_type: ``'standard'`` or ``'attentive'``; only read when ``lin_probe``.
        args: optional namespace; only ``args.sin_pos_emb`` is read.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
                 use_mean_pooling=True, init_scale=0.001, lin_probe=False, linear_type='standard', args=None):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.use_mean_pooling = use_mean_pooling

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        elif getattr(args, 'sin_pos_emb', False):
            # Fixed sine-cosine positional embeddings. getattr tolerates args=None.
            self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
            for i in range(depth)])
        # With mean pooling the final norm is skipped here (fine-tuning applies
        # fc_norm after pooling instead).
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)

        self.lin_probe = lin_probe
        self.linear_type = linear_type
        if lin_probe:
            if self.linear_type == 'standard':
                self.fc_norm = None
            elif self.linear_type == 'attentive':
                # A single learned query token cross-attends to all output tokens.
                self.query_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
                self.attentive_blocks = nn.ModuleList([
                    AttentiveBlock(
                        dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=0, norm_layer=norm_layer,
                        init_values=0)
                    for _ in range(1)])
                self.fc_norm = LP_BatchNorm(embed_dim, affine=False)
            else:
                # Fail at construction instead of with an AttributeError in forward.
                raise ValueError(
                    f"linear_type must be 'standard' or 'attentive' when lin_probe=True, got {linear_type!r}")
        else:
            if use_mean_pooling:
                self.fc_norm = norm_layer(embed_dim)
            else:
                self.fc_norm = None

        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        # With num_classes == 0 the head is nn.Identity and has no weight/bias.
        has_linear_head = isinstance(self.head, nn.Linear)

        if self.pos_embed is not None and use_abs_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        if has_linear_head:
            trunc_normal_(self.head.weight, std=.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

        if has_linear_head:
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)

    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000.):
        """Return a frozen ``(1, 1 + h * w, embed_dim)`` 2-D sin-cos embedding.

        The first (cls) position is all zeros; the patch grid comes from
        ``self.patch_embed.patch_shape``. ``embed_dim`` must be divisible by 4.
        """
        h, w = self.patch_embed.patch_shape
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]

        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        pos_embed.requires_grad = False
        return pos_embed

    def fix_init_weight(self):
        """Divide each block's attention-proj and fc2 weights by sqrt(2 * layer_index)."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; LayerNorm weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, is_train=True):
        """Return pooled features for images ``x``; ``is_train`` is passed to ``LP_BatchNorm``."""
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            # NOTE(review): the embedding is detached for both the learnable absolute
            # and the fixed sin-cos variant, so an absolute pos_embed never receives
            # gradients. Behaviour kept as-is; confirm this is intended.
            x = x + self.pos_embed.expand(batch_size, -1, -1).type_as(x).to(x.device).clone().detach()
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            x = blk(x, rel_pos_bias=rel_pos_bias)
        x = self.norm(x)

        # linear probing or attentive probing
        if self.lin_probe:
            if self.linear_type == 'standard':
                return x[:, 0]
            query_tokens = self.query_token.expand(batch_size, -1, -1)
            for blk in self.attentive_blocks:
                query_tokens = blk(query_tokens, x, 0, 0, bool_masked_pos=None, rel_pos_bias=None)
            return self.fc_norm(query_tokens[:, 0, :], is_train=is_train)

        # finetune
        if self.fc_norm is not None:  # use mean pooling
            t = x[:, 1:, :]
            return self.fc_norm(t.mean(1))
        return x[:, 0]

    def forward(self, x, is_train=True):
        x = self.forward_features(x, is_train)
        x = self.head(x)
        return x
def _create_cae_vit(embed_dim, depth, num_heads, **kwargs):
    """Build a CAE ``VisionTransformer`` with the settings shared by all registered variants.

    Shared settings: patch size 16, MLP ratio 4, qkv biases, LayerNorm(eps=1e-6).
    ``kwargs`` (including ``img_size``) are forwarded to ``VisionTransformer``.
    The result gets ``default_cfg = _cfg()``.
    """
    model = VisionTransformer(
        patch_size=16, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


@register_model
def cae_small_patch16_224(pretrained=False, **kwargs):
    """CAE ViT-S/16 at 224px. ``pretrained`` is accepted but ignored."""
    return _create_cae_vit(384, 12, 12, **kwargs)


@register_model
def cae_base_patch16_224(pretrained=False, **kwargs):
    """CAE ViT-B/16 at 224px. ``pretrained`` is accepted but ignored."""
    return _create_cae_vit(768, 12, 12, **kwargs)


@register_model
def cae_base_patch16_384(pretrained=False, **kwargs):
    """CAE ViT-B/16 at 384px. ``pretrained`` is accepted but ignored."""
    return _create_cae_vit(768, 12, 12, img_size=384, **kwargs)


@register_model
def cae_large_patch16_224(pretrained=False, **kwargs):
    """CAE ViT-L/16 at 224px. ``pretrained`` is accepted but ignored."""
    return _create_cae_vit(1024, 24, 16, **kwargs)


@register_model
def cae_large_patch16_384(pretrained=False, **kwargs):
    """CAE ViT-L/16 at 384px. ``pretrained`` is accepted but ignored."""
    return _create_cae_vit(1024, 24, 16, img_size=384, **kwargs)


@register_model
def cae_large_patch16_512(pretrained=False, **kwargs):
    """CAE ViT-L/16 at 512px. ``pretrained`` is accepted but ignored."""
    return _create_cae_vit(1024, 24, 16, img_size=512, **kwargs)
================================================
FILE: requirements.txt
================================================
torch==1.7.1
torchvision==0.8.2
timm==0.3.2
Pillow
blobfile
mypy
numpy
pytest
requests
einops
tensorboardX
deepspeed==0.4.0
scipy
pytorch-lightning==1.0.8
omegaconf==2.0.0
================================================
FILE: scripts/cae_base_800e.sh
================================================
# CAE-Base: 800-epoch pretraining, then linear probing and attentive probing.
# Replace the /path/to/... placeholders, ADDR_FOR_THIS_MACHINE and
# RANK_FOR_THIS_MACHINE before running. Every stage launches NNODES nodes x 8
# processes, and the stages run one after another.
tmp_my_name=${0##*/}
my_name=${tmp_my_name%.*}  # script name without directory/extension; used as the experiment name
OUTPUT_DIR="./output/${my_name}"
DATA_PATH=/path/to/imagenet1k/train
TOKENIZER_PATH=./tokenizer-weights
ADDRESS=ADDR_FOR_THIS_MACHINE
NNODES=4
RANK=RANK_FOR_THIS_MACHINE

# ============================ pretraining ============================
OMP_NUM_THREADS=1 python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=$NNODES \
    --node_rank=$RANK \
    --master_addr=$ADDRESS \
    --master_port=8899 \
    tools/run_pretraining.py \
    --data_path "${DATA_PATH}" \
    --output_dir "${OUTPUT_DIR}" \
    --model cae_base_patch16_224_8k_vocab --discrete_vae_weight_path "${TOKENIZER_PATH}" \
    --batch_size 64 --lr 1.5e-3 --warmup_epochs 20 --epochs 800 \
    --clip_grad 3.0 --layer_scale_init_value 0.1 \
    --imagenet_default_mean_and_std \
    --color_jitter 0 \
    --drop_path 0.1 \
    --sincos_pos_emb \
    --mask_generator block \
    --num_mask_patches 98 \
    --decoder_layer_scale_init_value 0.1 \
    --no_auto_resume \
    --save_ckpt_freq 100 \
    --exp_name "$my_name" \
    --regressor_depth 4 \
    --decoder_depth 4 \
    --align_loss_weight 2

# ============================ linear probing ============================
# Probing reads the dataset root (not .../train).
DATA_PATH=/path/to/imagenet1k/
MODEL_PATH=/path/to/pretrained/model
OMP_NUM_THREADS=1 python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=$NNODES \
    --node_rank=$RANK \
    --master_addr=$ADDRESS \
    --master_port=8899 \
    tools/run_linear.py \
    --model cae_base_patch16_224 --data_path "$DATA_PATH" \
    --finetune "$MODEL_PATH" \
    --nb_classes 1000 \
    --batch_size 512 \
    --epochs 90 \
    --blr 0.1 \
    --weight_decay 0.0 \
    --dist_eval \
    --output_dir "$OUTPUT_DIR" \
    --log_dir "$OUTPUT_DIR" \
    --enable_linear_eval \
    --use_cls \
    --save_freq 50 \
    --disable_rel_pos_bias \
    --linear_type standard \
    --exp_name "$my_name"

# ============================ attentive probing ============================
DATA_PATH=/path/to/imagenet1k/
MODEL_PATH=/path/to/pretrained/model
OMP_NUM_THREADS=1 python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=$NNODES \
    --node_rank=$RANK \
    --master_addr=$ADDRESS \
    --master_port=8899 \
    tools/run_attentive.py \
    --model cae_base_patch16_224 --data_path "$DATA_PATH" \
    --finetune "$MODEL_PATH" \
    --nb_classes 1000 --data_set IMNET --imagenet_default_mean_and_std \
    --output_dir "$OUTPUT_DIR" --batch_size 256 --lr 0.4 --update_freq 1 \
    --warmup_epochs 10 --epochs 90 \
    --weight_decay 0 --smoothing 0.0 --layer_decay 1.0 --drop_path 0.0 \
    --color_jitter 0.0 --mixup 0.0 --cutmix 0.0 --reprob 0.0 \
    --opt sgd --momentum 0.9 \
    --enable_linear_eval \
    --use_cls \
    --dist_eval \
    --no_auto_resume \
    --save_ckpt_freq 50 \
    --linear_type attentive \
    --exp_name "$my_name"
================================================
FILE: scripts/cae_base_finetune.sh
================================================
# CAE-Base end-to-end fine-tuning on ImageNet-1k (100 epochs, layer decay 0.65).
# Replace the /path/to/... placeholders, ADDR_FOR_THIS_MACHINE and
# RANK_FOR_THIS_MACHINE before running. Launches NNODES nodes x 8 processes
# (4 x 8 x batch 128 = 4096 images per step, assuming --batch_size is per process).
tmp_my_name=${0##*/}
my_name=${tmp_my_name%.*}  # script name without directory/extension; used as the experiment name
OUTPUT_DIR='./output/'$my_name
# NOTE(review): the probing stages in cae_base_800e.sh pass the dataset root
# (/path/to/imagenet1k/) rather than .../train; confirm which --data_set IMNET expects.
DATA_PATH=/path/to/imagenet1k/train
TOKENIZER_PATH=./tokenizer-weights  # not referenced by the command below
ADDRESS=ADDR_FOR_THIS_MACHINE
NNODES=4
RANK=RANK_FOR_THIS_MACHINE
MODEL_PATH=/path/to/pretrained/model  # pretrained CAE checkpoint, passed to --finetune
OMP_NUM_THREADS=1 python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=$NNODES \
    --node_rank=$RANK \
    --master_addr=$ADDRESS \
    --master_port=8899 \
    tools/run_class_finetuning.py \
    --model cae_base_patch16_224 --data_path $DATA_PATH \
    --finetune $MODEL_PATH \
    --nb_classes 1000 --data_set IMNET \
    --output_dir $OUTPUT_DIR \
    --batch_size 128 \
    --lr 8e-3 --update_freq 1 \
    --warmup_epochs 5 --epochs 100 --layer_decay 0.65 --drop_path 0.1 \
    --weight_decay 0.05 --mixup 0.8 --cutmix 1.0 \
    --sin_pos_emb \
    --dist_eval \
    --no_auto_resume \
    --exp_name $my_name \
    --imagenet_default_mean_and_std
================================================
FILE: scripts/cae_large_1600e.sh
================================================
# CAE-Large: 1600-epoch pretraining, then linear probing and attentive probing.
# Replace the /path/to/... placeholders, ADDR_FOR_THIS_MACHINE and
# RANK_FOR_THIS_MACHINE before running. Every stage launches NNODES nodes x 8
# processes, and the stages run one after another.
tmp_my_name=${0##*/}
my_name=${tmp_my_name%.*}  # script name without directory/extension; used as the experiment name
OUTPUT_DIR="./output/${my_name}"
DATA_PATH=/path/to/imagenet1k/train
TOKENIZER_PATH=./tokenizer-weights
ADDRESS=ADDR_FOR_THIS_MACHINE
NNODES=4
RANK=RANK_FOR_THIS_MACHINE

# ============================ pretraining ============================
# Every flag below, including the decoder width/heads and --fix_init_weight,
# must be joined by trailing backslashes into this single command.
OMP_NUM_THREADS=1 python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=$NNODES \
    --node_rank=$RANK \
    --master_addr=$ADDRESS \
    --master_port=8899 \
    tools/run_pretraining.py \
    --data_path "${DATA_PATH}" \
    --output_dir "${OUTPUT_DIR}" \
    --model cae_large_patch16_224_8k_vocab --discrete_vae_weight_path "${TOKENIZER_PATH}" \
    --batch_size 64 --lr 1.5e-3 --warmup_epochs 40 --epochs 1600 \
    --clip_grad 3.0 --layer_scale_init_value 1e-5 \
    --imagenet_default_mean_and_std \
    --color_jitter 0 \
    --drop_path 0.1 \
    --sincos_pos_emb \
    --mask_generator block \
    --num_mask_patches 98 \
    --decoder_layer_scale_init_value 1e-5 \
    --no_auto_resume \
    --save_ckpt_freq 100 \
    --exp_name "$my_name" \
    --regressor_depth 4 \
    --decoder_depth 4 \
    --align_loss_weight 2 \
    --decoder_embed_dim 1024 \
    --decoder_num_heads 16 \
    --fix_init_weight

# ============================ linear probing ============================
# Probing reads the dataset root (not .../train).
DATA_PATH=/path/to/imagenet1k/
MODEL_PATH=/path/to/pretrained/model
OMP_NUM_THREADS=1 python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=$NNODES \
    --node_rank=$RANK \
    --master_addr=$ADDRESS \
    --master_port=8899 \
    tools/run_linear.py \
    --model cae_large_patch16_224 --data_path "$DATA_PATH" \
    --finetune "$MODEL_PATH" \
    --nb_classes 1000 \
    --batch_size 512 \
    --epochs 90 \
    --blr 0.1 \
    --weight_decay 0.0 \
    --dist_eval \
    --output_dir "$OUTPUT_DIR" \
    --log_dir "$OUTPUT_DIR" \
    --enable_linear_eval \
    --use_cls \
    --save_freq 90 \
    --disable_rel_pos_bias \
    --linear_type standard \
    --exp_name "$my_name"

# ============================ attentive probing ============================
DATA_PATH=/path/to/imagenet1k/
MODEL_PATH=/path/to/pretrained/model
OMP_NUM_THREADS=1 python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=$NNODES \
    --node_rank=$RANK \
    --master_addr=$ADDRESS \
    --master_port=8899 \
    tools/run_attentive.py \
    --model cae_large_patch16_224 --data_path "$DATA_PATH" \
    --finetune "$MODEL_PATH" \
    --nb_classes 1000 --data_set IMNET --imagenet_default_mean_and_std \
    --output_dir "$OUTPUT_DIR" --batch_size 256 --lr 0.4 --update_freq 1 \
    --warmup_epochs 10 --epochs 90 \
    --weight_decay 0 --smoothing 0.0 --layer_decay 1.0 --drop_path 0.0 \
    --color_jitter 0.0 --mixup 0.0 --cutmix 0.0 --reprob 0.0 \
    --opt sgd --momentum 0.9 \
    --enable_linear_eval \
    --use_cls \
    --dist_eval \
    --no_auto_resume \
    --save_ckpt_freq 50 \
    --linear_type attentive \
    --exp_name "$my_name"
================================================
FILE: scripts/cae_large_finetune.sh
================================================
# CAE-Large end-to-end fine-tuning on ImageNet-1k (50 epochs, layer decay 0.75).
# Replace the /path/to/... placeholders, ADDR_FOR_THIS_MACHINE and
# RANK_FOR_THIS_MACHINE before running. Launches NNODES nodes x 8 processes
# (4 x 8 x batch 64 x update_freq 2 = 4096 images per optimizer step, assuming
# --batch_size is per process and --update_freq accumulates gradients).
tmp_my_name=${0##*/}
my_name=${tmp_my_name%.*}  # script name without directory/extension; used as the experiment name
OUTPUT_DIR='./output/'$my_name
# NOTE(review): the probing stages in cae_large_1600e.sh pass the dataset root
# (/path/to/imagenet1k/) rather than .../train; confirm which --data_set IMNET expects.
DATA_PATH=/path/to/imagenet1k/train
TOKENIZER_PATH=./tokenizer-weights  # not referenced by the command below
ADDRESS=ADDR_FOR_THIS_MACHINE
NNODES=4
RANK=RANK_FOR_THIS_MACHINE
MODEL_PATH=/path/to/pretrained/model  # pretrained CAE checkpoint, passed to --finetune
OMP_NUM_THREADS=1 python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=$NNODES \
    --node_rank=$RANK \
    --master_addr=$ADDRESS \
    --master_port=8899 \
    tools/run_class_finetuning.py \
    --model cae_large_patch16_224 --data_path $DATA_PATH \
    --finetune $MODEL_PATH \
    --nb_classes 1000 --data_set IMNET \
    --output_dir $OUTPUT_DIR \
    --batch_size 64 \
    --lr 2e-3 --update_freq 2 \
    --warmup_epochs 5 --epochs 50 --layer_decay 0.75 --drop_path 0.2 \
    --weight_decay 0.05 --mixup 0.8 --cutmix 1.0 \
    --sin_pos_emb \
    --dist_eval \
    --no_auto_resume \
    --exp_name $my_name \
    --imagenet_default_mean_and_std
================================================
FILE: tokenizer-weights/README
================================================
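-- tokenizers: put the tokenizer weights in this directory (the scripts/*.sh pretraining commands pass ./tokenizer-weights as --discrete_vae_weight_path).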
================================================
FILE: tools/run_attentive.py
================================================
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import os
from pathlib import Path
from timm.data.mixup import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import ModelEma
from furnace.optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner
from furnace.datasets import build_dataset
from furnace.engine_for_finetuning import train_one_epoch, evaluate
from furnace.utils import NativeScalerWithGradNormCount as NativeScaler
import furnace.utils as utils
from scipy import interpolate
import models.modeling_finetune
def get_args():
    """Parse the command line for the linear / attentive probing script.

    The arguments are first parsed with ``parse_known_args`` only to find out
    whether ``--enable_deepspeed`` was given; in that case DeepSpeed's own
    config arguments are added to the parser before the final ``parse_args``.

    Returns:
        tuple: ``(args, ds_init)`` -- the parsed ``argparse.Namespace`` and
        ``deepspeed.initialize`` when ``--enable_deepspeed`` is set, otherwise
        ``None``.

    Raises:
        SystemExit: with status 1 when ``--enable_deepspeed`` is set but the
            ``deepspeed`` package cannot be imported (argparse also raises it
            for invalid arguments).
    """
    parser = argparse.ArgumentParser('fine-tuning and evaluation script for image classification', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=30, type=int)
    parser.add_argument('--update_freq', default=1, type=int)
    parser.add_argument('--save_ckpt_freq', default=5, type=int)

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--rel_pos_bias', action='store_true')
    parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')
    parser.set_defaults(rel_pos_bias=True)
    parser.add_argument('--abs_pos_emb', action='store_true')
    parser.set_defaults(abs_pos_emb=False)
    # sin_pos_emb already defaults to True; --sin_pos_emb only lets scripts
    # state it explicitly, --disable_sin_pos_emb turns it off.
    parser.add_argument('--sin_pos_emb', action='store_true')
    parser.set_defaults(sin_pos_emb=True)
    parser.add_argument('--disable_sin_pos_emb', action='store_false', dest='sin_pos_emb')
    parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
                        help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")
    parser.add_argument('--input_size', default=224, type=int,
                        help='images input size')
    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
                        help='Attention dropout rate (default: 0.)')
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)
    parser.add_argument('--model_ema', action='store_true', default=False)
    parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
    parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
        weight decay. We use a cosine schedule for WD and using a larger decay by
        the end of training improves performance for ViTs.""")
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--layer_decay', type=float, default=0.9)
    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
                        help='num of steps to warmup LR, will overload warmup_epochs if set > 0')

    # Augmentation parameters
    parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default='', metavar='NAME',
                        help='AutoAugment policy, e.g. "v0", "original" or "rand-m9-mstd0.5-inc1". '
                             'main() overrides it with rand-m9-mstd0.5-inc1 unless '
                             '--enable_linear_eval is set (default: "")')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--train_interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    # Evaluation parameters
    parser.add_argument('--crop_pct', type=float, default=None)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0.')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0.')
    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup_prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup_mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # * Finetuning params
    parser.add_argument('--finetune', default='',
                        help='finetune from checkpoint')
    parser.add_argument('--model_key', default='model|module|state_dict', type=str)
    parser.add_argument('--model_prefix', default='', type=str)
    parser.add_argument('--init_scale', default=0.001, type=float)
    # Mean pooling is the default; --use_cls switches to the class token.
    parser.add_argument('--use_mean_pooling', action='store_true')
    parser.set_defaults(use_mean_pooling=True)
    parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
    parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False)

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--eval_data_path', default=None, type=str,
                        help='dataset path for evaluation')
    parser.add_argument('--nb_classes', default=0, type=int,
                        help='number of the classification types')
    parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')
    parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'IMNET100', 'image_folder'],
                        type=str, help='which dataset to build')
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default=None,
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=True)

    parser.add_argument('--save_ckpt', action='store_true')
    parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
    parser.set_defaults(save_ckpt=True)

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--dist_eval', action='store_true', default=False,
                        help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    parser.add_argument('--enable_deepspeed', action='store_true', default=False)
    parser.add_argument('--enable_linear_eval', action='store_true', default=False)
    parser.add_argument('--enable_multi_print', action='store_true', default=False, help='allow each gpu prints something')
    parser.add_argument('--linear_type', default='standard', type=str, help='standard, attentive')
    parser.add_argument('--exp_name', default='', type=str,
                        help='name of exp. it is helpful when save the checkpoint')

    # Pre-parse only to decide whether DeepSpeed's extra arguments are needed.
    known_args, _ = parser.parse_known_args()

    ds_init = None
    if known_args.enable_deepspeed:
        try:
            import deepspeed
        except ImportError:
            print("Please 'pip install deepspeed==0.4.0'")
            # DeepSpeed was explicitly requested, so its absence is a failure:
            # exit non-zero instead of reporting success.
            raise SystemExit(1)
        parser = deepspeed.add_config_arguments(parser)
        ds_init = deepspeed.initialize

    return parser.parse_args(), ds_init
def _interpolate_rel_pos_bias(rel_pos_bias, src_size, dst_size, num_extra_tokens):
    """Resample one relative-position-bias table from ``src_size`` to ``dst_size``.

    ``rel_pos_bias`` is 2-D: its first ``src_size ** 2`` rows are viewed per
    head as a ``src_size x src_size`` grid, and its last ``num_extra_tokens``
    rows are copied through unchanged (``num_extra_tokens`` may be 0).

    Source offsets are laid out on a geometric-progression grid. The ratio
    ``q`` is found by bisection so that the outermost source offset lands at
    about ``dst_size // 2``. Each head is fitted with a bicubic interpolating
    spline and evaluated at the integer offsets ``-t .. t`` with
    ``t = dst_size // 2``.

    Returns:
        Tensor of shape ``((2 * t + 1) ** 2 + num_extra_tokens, num_heads)``,
        i.e. ``dst_size ** 2`` grid rows when ``dst_size`` is odd.
    """
    if num_extra_tokens > 0:
        extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
        rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
    else:
        # [-0:] / [:-0] would select the whole table and an empty one respectively.
        extra_tokens = rel_pos_bias[:0, :]
    num_attn_heads = rel_pos_bias.size(1)

    def geometric_progression(a, r, n):
        return a * (1.0 - r ** n) / (1.0 - r)

    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        gp = geometric_progression(1, q, src_size // 2)
        if gp > dst_size // 2:
            right = q
        else:
            left = q

    dis = []
    cur = 1
    for i in range(src_size // 2):
        dis.append(cur)
        cur += q ** (i + 1)

    r_ids = [-d for d in reversed(dis)]
    x = r_ids + [0] + dis
    y = r_ids + [0] + dis

    t = dst_size // 2.0
    dx = np.arange(-t, t + 0.1, 1.0)
    dy = np.arange(-t, t + 0.1, 1.0)

    print("Original positions = %s" % str(x))
    print("Target positions = %s" % str(dx))

    all_rel_pos_bias = []
    for i in range(num_attn_heads):
        z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
        # scipy.interpolate.interp2d was removed in SciPy 1.14. On a regular
        # grid, interp2d(x, y, z, kind='cubic')(dx, dy) is equivalent to
        # RectBivariateSpline(x, y, z.T, kx=3, ky=3)(dx, dy).T (both fit an
        # s=0 FITPACK spline).
        spline = interpolate.RectBivariateSpline(x, y, z.T, kx=3, ky=3)
        resampled = np.ascontiguousarray(spline(dx, dy).T)
        all_rel_pos_bias.append(
            torch.Tensor(resampled).contiguous().view(-1, 1).to(rel_pos_bias.device))

    rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
    return torch.cat((rel_pos_bias, extra_tokens), dim=0)


def _freeze_for_linear_eval(model):
    """Freeze the backbone for linear / attentive probing and re-initialise ``model.head``.

    A parameter stays trainable only if its name contains ``query_token``,
    ``attentive_blocks`` or ``fc_norm``, or is exactly ``head.weight`` /
    ``head.bias``. A name containing ``gamma`` is always frozen, even inside
    the attentive blocks (presumably the layer-scale factors). The head
    weights are redrawn from N(0, 0.01) and its bias is zeroed.
    """
    linear_keyword = 'head'
    head_norm = 'fc_norm'
    head_param_names = {'%s.weight' % linear_keyword, '%s.bias' % linear_keyword}
    parameters_requires_grad = []
    for name, param in model.named_parameters():
        trainable = 'gamma' not in name and (
            'query_token' in name
            or 'attentive_blocks' in name
            or name in head_param_names
            or head_norm in name
        )
        param.requires_grad = trainable
        if trainable:
            parameters_requires_grad.append(name)
    print('parameters that need grad: ', parameters_requires_grad)
    head = getattr(model, linear_keyword)
    head.weight.data.normal_(mean=0.0, std=0.01)
    head.bias.data.zero_()


def _load_finetune_checkpoint(model, args):
    """Load the ``--finetune`` checkpoint into ``model`` after adapting its keys and tensors.

    Steps:
      * load from an ``https`` URL (hash-checked) or a local path onto CPU;
      * pick the first entry of ``args.model_key`` ('|'-separated) present in
        the checkpoint, else use the checkpoint itself as the state dict;
      * strip the ``encoder.`` prefix (CAE pretraining) and the
        ``module.base_encoder.`` prefix (MoCo-v3), then drop ``decoder.*`` and
        ``teacher.*`` entries;
      * drop ``head.weight`` / ``head.bias`` when their shape differs from the model's;
      * expand a shared ``rel_pos_bias`` table to every block, drop
        ``relative_position_index`` buffers and resample per-block bias tables
        whose window size differs;
      * bicubically resize an absolute ``pos_embed`` when ``--abs_pos_emb`` is set;
      * finally call ``utils.load_state_dict`` with ``args.model_prefix``.
    """
    if args.finetune.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.finetune, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(args.finetune, map_location='cpu')
    print("Load ckpt from %s" % args.finetune)

    checkpoint_model = None
    for model_key in args.model_key.split('|'):
        if model_key in checkpoint:
            checkpoint_model = checkpoint[model_key]
            print("Load state_dict by model_key = %s" % model_key)
            break
    if checkpoint_model is None:
        checkpoint_model = checkpoint

    state_dict = model.state_dict()
    original_all_keys = list(checkpoint_model.keys())
    print("##########origin keys:", len(original_all_keys), original_all_keys)

    # Strip the wrapper prefixes: 'encoder.' (CAE pretraining) and
    # 'module.base_encoder.' (moco-v3). Only the leading prefix is removed;
    # str.replace would also delete later occurrences inside the key.
    for prefix in ('encoder.', 'module.base_encoder.'):
        all_keys = [key for key in original_all_keys if key.startswith(prefix)]
        print("all keys:", all_keys)
        for key in all_keys:
            checkpoint_model[key[len(prefix):]] = checkpoint_model.pop(key)

    # The pretraining-only decoder and teacher branches are not loaded.
    for key in list(checkpoint_model.keys()):
        if key.startswith(('decoder.', 'teacher.')):
            checkpoint_model.pop(key)

    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]

    if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model:
        print("Expand the shared relative position embedding to each transformer block. ")
        num_layers = model.get_num_layers()
        rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"]
        for i in range(num_layers):
            checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone()
        checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")

    for key in list(checkpoint_model.keys()):
        if "relative_position_index" in key:
            checkpoint_model.pop(key)
            continue
        if "relative_position_bias_table" in key and args.rel_pos_bias:
            rel_pos_bias = checkpoint_model[key]
            src_num_pos, _ = rel_pos_bias.size()
            dst_num_pos, _ = state_dict[key].size()
            dst_patch_shape = model.patch_embed.patch_shape
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()
            # Rows beyond the (2H-1) x (2W-1) offset grid are extra tokens.
            num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
            if src_size != dst_size:
                print("Position interpolate for %s from %dx%d to %dx%d" % (
                    key, src_size, src_size, dst_size, dst_size))
                checkpoint_model[key] = _interpolate_rel_pos_bias(
                    rel_pos_bias, src_size, dst_size, num_extra_tokens)
    print("##############new keys:", len(checkpoint_model), checkpoint_model.keys())

    # interpolate position embedding
    if 'pos_embed' in checkpoint_model and args.abs_pos_emb:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            checkpoint_model['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)

    utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)


def main(args, ds_init):
    """Probe or fine-tune a classifier, evaluating and checkpointing as it trains.

    Builds the train/val datasets and loaders, creates ``args.model`` via
    ``create_model``, optionally freezes all but the probing head
    (``--enable_linear_eval``), loads ``--finetune`` weights, builds the
    optimizer (DeepSpeed, or ``create_optimizer`` with optional DDP) and the
    cosine LR / weight-decay schedules, then trains for epochs
    ``args.start_epoch .. args.epochs - 1``. A checkpoint is saved every
    ``--save_ckpt_freq`` epochs, on the last epoch, and whenever validation
    top-1 improves.

    Args:
        args: namespace from ``get_args()``. It is mutated in place: ``aa``,
            ``nb_classes``, ``window_size``, ``patch_size`` and
            ``weight_decay_end`` are assigned here, and
            ``utils.init_distributed_mode`` presumably sets the distributed
            fields (``distributed``, ``gpu``) read below.
        ds_init: ``deepspeed.initialize`` when DeepSpeed is enabled, else ``None``.

    Raises:
        ValueError: if ``--eval`` is combined with
            ``--disable_eval_during_finetuning``, which leaves nothing to evaluate.
        SystemExit: with status 0 after an ``--eval``-only run.
    """
    if not args.enable_linear_eval:
        # Full fine-tuning always uses RandAugment. Probing keeps --aa as given
        # (default '').
        args.aa = 'rand-m9-mstd0.5-inc1'
    utils.init_distributed_mode(args)
    if ds_init is not None:
        utils.create_ds_config(args)
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    if args.disable_eval_during_finetuning:
        dataset_val = None
    else:
        dataset_val, _ = build_dataset(is_train=False, args=args)
    if args.eval and dataset_val is None:
        raise ValueError("--eval needs an evaluation dataset; "
                         "do not combine it with --disable_eval_during_finetuning")

    # A DistributedSampler is used even in single-process runs (world size 1).
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))
    if dataset_val is None:
        sampler_val = None
    elif args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    if dataset_val is not None:
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=int(1.5 * args.batch_size),
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False
        )
    else:
        data_loader_val = None

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        print("Mixup is activated!")
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        attn_drop_rate=args.attn_drop_rate,
        drop_block_rate=None,
        use_mean_pooling=args.use_mean_pooling,
        init_scale=args.init_scale,
        use_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
        lin_probe=args.enable_linear_eval,
        linear_type=args.linear_type,
        args=args,
    )

    if args.enable_linear_eval:
        _freeze_for_linear_eval(model)

    patch_size = model.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
    args.patch_size = patch_size

    if args.finetune:
        _load_finetune_checkpoint(model, args)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')
        print("Using EMA with decay = %.8f" % args.model_ema_decay)

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params:', n_parameters)

    total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
    num_training_steps_per_epoch = len(dataset_train) // total_batch_size
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Update frequent = %d" % args.update_freq)
    print("Number of training examples = %d" % len(dataset_train))
    print("Number of training training per epoch = %d" % num_training_steps_per_epoch)

    num_layers = model_without_ddp.get_num_layers()
    if args.layer_decay < 1.0:
        # One scale per layer id 0 .. num_layers + 1, decaying towards the input.
        assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
    else:
        assigner = None

    if assigner is not None:
        print("Assigned values = %s" % str(assigner.values))
    get_layer_id = assigner.get_layer_id if assigner is not None else None
    get_layer_scale = assigner.get_scale if assigner is not None else None

    skip_weight_decay_list = model.no_weight_decay()
    print("Skip weight decay list: ", skip_weight_decay_list)
    if args.disable_weight_decay_on_rel_pos_bias:
        for i in range(num_layers):
            skip_weight_decay_list.add("blocks.%d.attn.relative_position_bias_table" % i)

    if args.enable_deepspeed:
        loss_scaler = None
        optimizer_params = get_parameter_groups(
            model, args.weight_decay, skip_weight_decay_list,
            get_layer_id, get_layer_scale)
        model, optimizer, _, _ = ds_init(
            args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed,
        )

        print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
        assert model.gradient_accumulation_steps() == args.update_freq
    else:
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
            model_without_ddp = model.module

        optimizer = create_optimizer(
            args, model_without_ddp, skip_list=skip_weight_decay_list,
            get_num_layer=get_layer_id,
            get_layer_scale=get_layer_scale)
        loss_scaler = NativeScaler()

    print("Use step level LR scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))

    if mixup_fn is not None:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print("criterion = %s" % str(criterion))

    utils.auto_load_model(
        args=args, model=model, model_without_ddp=model_without_ddp,
        optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        raise SystemExit(0)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        # The train sampler is always a DistributedSampler, which reseeds its
        # shuffle from the epoch. Without set_epoch every epoch would reuse the
        # same order, so call it even when not running distributed.
        data_loader_train.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train, optimizer,
            device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
            log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
            num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
        )
        if args.output_dir and args.save_ckpt:
            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                utils.save_model(
                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch, exp_name=args.exp_name, model_ema=model_ema)

        if data_loader_val is not None:
            test_stats = evaluate(data_loader_val, model, device)
            print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            if max_accuracy < test_stats["acc1"]:
                max_accuracy = test_stats["acc1"]
                if args.output_dir and args.save_ckpt:
                    utils.save_model(
                        args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                        loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)

            print(f'Max accuracy: {max_accuracy:.2f}%')
            if log_writer is not None:
                log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch)
                log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch)
                log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch)

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}
        else:
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
    # Parse the CLI (plus DeepSpeed's options when requested), make sure the
    # output directory exists, then hand everything to main().
    cli_args, deepspeed_init = get_args()
    output_dir = cli_args.output_dir
    if output_dir:
        Path(output_dir).mkdir(exist_ok=True, parents=True)
    main(cli_args, deepspeed_init)
================================================
FILE: tools/run_class_finetuning.py
================================================
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import os
import shutil
from pathlib import Path
from timm.data.mixup import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import ModelEma
from furnace.optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner
from furnace.datasets import build_dataset
from furnace.engine_for_finetuning import train_one_epoch, evaluate
from furnace.utils import NativeScalerWithGradNormCount as NativeScaler
import furnace.utils as utils
from scipy import interpolate
import models.modeling_finetune
def get_args():
    """Parse the command line for the fine-tuning / evaluation script.

    The arguments are first parsed with ``parse_known_args`` only to find out
    whether ``--enable_deepspeed`` was given; in that case DeepSpeed's own
    config arguments are added to the parser before the final ``parse_args``.

    Returns:
        tuple: ``(args, ds_init)`` -- the parsed ``argparse.Namespace`` and
        ``deepspeed.initialize`` when ``--enable_deepspeed`` is set, otherwise
        ``None``.

    Raises:
        SystemExit: with status 1 when ``--enable_deepspeed`` is set but the
            ``deepspeed`` package cannot be imported (argparse also raises it
            for invalid arguments).
    """
    parser = argparse.ArgumentParser('fine-tuning and evaluation script for image classification', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=30, type=int)
    parser.add_argument('--update_freq', default=1, type=int)
    parser.add_argument('--save_ckpt_freq', default=5, type=int)

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--rel_pos_bias', action='store_true')
    parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')
    parser.set_defaults(rel_pos_bias=True)
    parser.add_argument('--abs_pos_emb', action='store_true')
    parser.set_defaults(abs_pos_emb=False)
    # sin_pos_emb already defaults to True; --sin_pos_emb only lets scripts
    # state it explicitly, --disable_sin_pos_emb turns it off.
    parser.add_argument('--sin_pos_emb', action='store_true')
    parser.set_defaults(sin_pos_emb=True)
    parser.add_argument('--disable_sin_pos_emb', action='store_false', dest='sin_pos_emb')
    parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
                        help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")
    parser.add_argument('--input_size', default=224, type=int,
                        help='images input size')
    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
                        help='Attention dropout rate (default: 0.)')
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)
    parser.add_argument('--model_ema', action='store_true', default=False)
    parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
    parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
        weight decay. We use a cosine schedule for WD and using a larger decay by
        the end of training improves performance for ViTs.""")
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--layer_decay', type=float, default=0.9)
    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
                        help='num of steps to warmup LR, will overload warmup_epochs if set > 0')

    # Augmentation parameters
    parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default='', metavar='NAME',
                        help='AutoAugment policy, e.g. "v0", "original" or "rand-m9-mstd0.5-inc1". '
                             'main() overrides it with rand-m9-mstd0.5-inc1 unless '
                             '--enable_linear_eval is set (default: "")')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--train_interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    # Evaluation parameters
    parser.add_argument('--crop_pct', type=float, default=None)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0.')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0.')
    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup_prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup_mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # * Finetuning params
    parser.add_argument('--finetune', default='',
                        help='finetune from checkpoint')
    parser.add_argument('--model_key', default='model|module|state_dict', type=str)
    parser.add_argument('--model_prefix', default='', type=str)
    parser.add_argument('--init_scale', default=0.001, type=float)
    # Mean pooling is the default; --use_cls switches to the class token.
    parser.add_argument('--use_mean_pooling', action='store_true')
    parser.set_defaults(use_mean_pooling=True)
    parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
    parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False)

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--eval_data_path', default=None, type=str,
                        help='dataset path for evaluation')
    parser.add_argument('--nb_classes', default=0, type=int,
                        help='number of the classification types')
    parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')
    parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'IMNET100', 'image_folder'],
                        type=str, help='which dataset to build')
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default=None,
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=True)

    parser.add_argument('--save_ckpt', action='store_true')
    parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
    parser.set_defaults(save_ckpt=True)

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--dist_eval', action='store_true', default=False,
                        help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    parser.add_argument('--enable_deepspeed', action='store_true', default=False)
    parser.add_argument('--enable_linear_eval', action='store_true', default=False)
    parser.add_argument('--enable_multi_print', action='store_true', default=False, help='allow each gpu prints something')
    parser.add_argument('--exp_name', default='', type=str,
                        help='name of exp. it is helpful when save the checkpoint')

    # Pre-parse only to decide whether DeepSpeed's extra arguments are needed.
    known_args, _ = parser.parse_known_args()

    ds_init = None
    if known_args.enable_deepspeed:
        try:
            import deepspeed
        except ImportError:
            print("Please 'pip install deepspeed==0.4.0'")
            # DeepSpeed was explicitly requested, so its absence is a failure:
            # exit non-zero instead of reporting success.
            raise SystemExit(1)
        parser = deepspeed.add_config_arguments(parser)
        ds_init = deepspeed.initialize

    return parser.parse_args(), ds_init
def _remap_cae_keys(checkpoint_model):
    """Rename CAE pre-training keys in place so they match the classifier.

    Three passes run in this order:

    1. every key starting with ``encoder.`` has all ``encoder.`` occurrences
       removed;
    2. keys starting with ``decoder.`` or ``teacher.`` are dropped;
    3. keys starting with ``norm.`` are renamed to ``fc_norm.``.

    Returns the same (mutated) dict.
    """
    all_keys = list(checkpoint_model.keys())
    print("##########origin keys:", len(all_keys), all_keys)
    # NOTE: remove all decoder keys
    encoder_keys = [key for key in all_keys if key.startswith('encoder.')]
    print("all keys:", encoder_keys)
    for key in encoder_keys:
        new_key = key.replace('encoder.', '')
        checkpoint_model[new_key] = checkpoint_model[key]
        checkpoint_model.pop(key)
    for key in list(checkpoint_model.keys()):
        if key.startswith(('decoder.', 'teacher.')):
            checkpoint_model.pop(key)
    # NOTE: replace norm with fc_norm
    for key in list(checkpoint_model.keys()):
        if key.startswith('norm.'):
            new_key = key.replace('norm.', 'fc_norm.')
            checkpoint_model[new_key] = checkpoint_model[key]
            checkpoint_model.pop(key)
    return checkpoint_model


def _adapt_rel_pos_bias(model, checkpoint_model, args):
    """Fit the relative-position-bias entries of ``checkpoint_model`` to ``model``.

    * If ``model.use_rel_pos_bias`` is set and the checkpoint holds a shared
      ``rel_pos_bias.relative_position_bias_table``, that table is cloned into
      ``blocks.<i>.attn.relative_position_bias_table`` for every layer and the
      shared entry is removed.
    * Every key containing ``relative_position_index`` is removed.
    * If ``args.rel_pos_bias`` is set, each ``relative_position_bias_table``
      whose grid side differs from the model's is resampled per attention head
      with 2-D cubic interpolation over geometrically spaced source positions.
      The trailing ``num_extra_tokens`` rows are copied unchanged.

    Raises:
        NotImplementedError: if the model's patch grid is not square.
    """
    shared_key = "rel_pos_bias.relative_position_bias_table"
    if model.use_rel_pos_bias and shared_key in checkpoint_model:
        print("Expand the shared relative position embedding to each transformer block. ")
        rel_pos_bias = checkpoint_model[shared_key]
        for i in range(model.get_num_layers()):
            checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone()
        checkpoint_model.pop(shared_key)

    def geometric_progression(a, r, n):
        return a * (1.0 - r ** n) / (1.0 - r)

    # The model's own tensors do not change here; build the state dict once.
    model_state = model.state_dict()
    for key in list(checkpoint_model.keys()):
        if "relative_position_index" in key:
            checkpoint_model.pop(key)
        if "relative_position_bias_table" not in key or not args.rel_pos_bias:
            continue
        rel_pos_bias = checkpoint_model[key]
        src_num_pos, num_attn_heads = rel_pos_bias.size()
        dst_num_pos, _ = model_state[key].size()
        dst_patch_shape = model.patch_embed.patch_shape
        if dst_patch_shape[0] != dst_patch_shape[1]:
            raise NotImplementedError()
        num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
        src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
        dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
        if src_size == dst_size:
            continue

        print("Position interpolate for %s from %dx%d to %dx%d" % (
            key, src_size, src_size, dst_size, dst_size))
        # Split at an explicit index: with num_extra_tokens == 0, ``[-0:]``
        # would select every row and ``[:-0]`` none of them.
        split = src_num_pos - num_extra_tokens
        extra_tokens = rel_pos_bias[split:, :]
        rel_pos_bias = rel_pos_bias[:split, :]

        # Bisect for the ratio q whose geometric progression over
        # src_size // 2 terms just reaches dst_size // 2.
        left, right = 1.01, 1.5
        while right - left > 1e-6:
            q = (left + right) / 2.0
            if geometric_progression(1, q, src_size // 2) > dst_size // 2:
                right = q
            else:
                left = q

        dis = []
        cur = 1
        for i in range(src_size // 2):
            dis.append(cur)
            cur += q ** (i + 1)
        r_ids = [-d for d in reversed(dis)]
        x = r_ids + [0] + dis
        y = r_ids + [0] + dis
        t = dst_size // 2.0
        dx = np.arange(-t, t + 0.1, 1.0)
        dy = np.arange(-t, t + 0.1, 1.0)
        print("Original positions = %s" % str(x))
        print("Target positions = %s" % str(dx))

        all_rel_pos_bias = []
        for i in range(num_attn_heads):
            z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
            # NOTE(review): scipy.interpolate.interp2d is deprecated and was
            # removed in SciPy 1.14; this path needs an older SciPy.
            f = interpolate.interp2d(x, y, z, kind='cubic')
            all_rel_pos_bias.append(
                torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
        rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
        checkpoint_model[key] = torch.cat((rel_pos_bias, extra_tokens), dim=0)


def _load_finetune_checkpoint(model, args):
    """Load ``args.finetune`` (an ``https`` URL or a local path) into ``model``.

    The state dict is taken from the first key of ``args.model_key``
    ('|'-separated) present in the checkpoint, or from the checkpoint itself.
    Before loading with ``utils.load_state_dict(..., prefix=args.model_prefix)``:

    * CAE keys are remapped;
    * a ``head.*`` entry whose shape differs from the model's is dropped;
    * relative position biases are adapted;
    * when ``args.abs_pos_emb`` is set, ``pos_embed`` is bicubically resized.
    """
    if args.finetune.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.finetune, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(args.finetune, map_location='cpu')
    print("Load ckpt from %s" % args.finetune)

    checkpoint_model = None
    for model_key in args.model_key.split('|'):
        if model_key in checkpoint:
            checkpoint_model = checkpoint[model_key]
            print("Load state_dict by model_key = %s" % model_key)
            break
    if checkpoint_model is None:
        checkpoint_model = checkpoint

    state_dict = model.state_dict()
    _remap_cae_keys(checkpoint_model)

    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]

    _adapt_rel_pos_bias(model, checkpoint_model, args)
    print("##############new keys:", len(checkpoint_model), checkpoint_model.keys())

    # interpolate position embedding
    if 'pos_embed' in checkpoint_model and args.abs_pos_emb:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed

    utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)


def main(args, ds_init):
    """Fine-tune (or linearly evaluate) a ViT classifier, optionally with DeepSpeed.

    Args:
        args: namespace from ``get_args``. The fields ``aa`` (only when
            ``enable_linear_eval`` is off), ``nb_classes``, ``window_size``,
            ``patch_size`` and, when unset, ``weight_decay_end`` are written
            back onto it.
        ds_init: ``deepspeed.initialize`` when DeepSpeed is enabled, else None.

    Side effects: may write TensorBoard logs under ``args.log_dir``, and
    checkpoints plus a JSON-lines ``log.txt`` under ``args.output_dir``. With
    ``args.eval`` set, it evaluates once and exits the process.

    Raises:
        ValueError: if ``args.eval`` is set while
            ``args.disable_eval_during_finetuning`` is also set.
    """
    if not args.enable_linear_eval:
        args.aa = 'rand-m9-mstd0.5-inc1'
    utils.init_distributed_mode(args)
    if ds_init is not None:
        utils.create_ds_config(args)
    print(args)
    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    if args.disable_eval_during_finetuning:
        dataset_val = None
    else:
        dataset_val, _ = build_dataset(is_train=False, args=args)

    # Distributed samplers are used unconditionally, also for a single process.
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))
    # Only build a validation sampler when a validation set exists; otherwise
    # --dist_eval would call len(None).
    sampler_val = None
    if dataset_val is not None:
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    if dataset_val is not None:
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=int(1.5 * args.batch_size),
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False
        )
    else:
        data_loader_val = None

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        print("Mixup is activated!")
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        attn_drop_rate=args.attn_drop_rate,
        drop_block_rate=None,
        use_mean_pooling=args.use_mean_pooling,
        init_scale=args.init_scale,
        use_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
        lin_probe=args.enable_linear_eval,
        args=args,
    )

    if args.enable_linear_eval:
        # freeze everything except head.weight / head.bias and fc_norm parameters
        linear_keyword = 'head'
        head_norm = 'fc_norm'
        head_params = {'%s.weight' % linear_keyword, '%s.bias' % linear_keyword}
        requires_grad = []
        for name, param in model.named_parameters():
            if name not in head_params and head_norm not in name:
                param.requires_grad = False
            else:
                requires_grad.append(name)
        print('require grad parameter: ', requires_grad)
        # init the fc layer
        head = getattr(model, linear_keyword)
        head.weight.data.normal_(mean=0.0, std=0.01)
        head.bias.data.zero_()

    patch_size = model.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
    args.patch_size = patch_size

    if args.finetune:
        _load_finetune_checkpoint(model, args)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')
        print("Using EMA with decay = %.8f" % args.model_ema_decay)

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Model = %s" % str(model_without_ddp))
    print('number of params:', n_parameters)

    total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
    num_training_steps_per_epoch = len(dataset_train) // total_batch_size
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Update frequent = %d" % args.update_freq)
    print("Number of training examples = %d" % len(dataset_train))
    print("Number of training training per epoch = %d" % num_training_steps_per_epoch)

    num_layers = model_without_ddp.get_num_layers()
    if args.layer_decay < 1.0:
        assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
    else:
        assigner = None
    if assigner is not None:
        print("Assigned values = %s" % str(assigner.values))

    skip_weight_decay_list = model.no_weight_decay()
    print("Skip weight decay list: ", skip_weight_decay_list)
    if args.disable_weight_decay_on_rel_pos_bias:
        for i in range(num_layers):
            skip_weight_decay_list.add("blocks.%d.attn.relative_position_bias_table" % i)

    if args.enable_deepspeed:
        loss_scaler = None
        optimizer_params = get_parameter_groups(
            model, args.weight_decay, skip_weight_decay_list,
            assigner.get_layer_id if assigner is not None else None,
            assigner.get_scale if assigner is not None else None)
        model, optimizer, _, _ = ds_init(
            args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed,
        )
        print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
        assert model.gradient_accumulation_steps() == args.update_freq
    else:
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
            model_without_ddp = model.module
        optimizer = create_optimizer(
            args, model_without_ddp, skip_list=skip_weight_decay_list,
            get_num_layer=assigner.get_layer_id if assigner is not None else None,
            get_layer_scale=assigner.get_scale if assigner is not None else None)
        loss_scaler = NativeScaler()

    print("Use step level LR scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))

    if mixup_fn is not None:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()
    print("criterion = %s" % str(criterion))

    utils.auto_load_model(
        args=args, model=model, model_without_ddp=model_without_ddp,
        optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)

    if args.eval:
        if data_loader_val is None:
            raise ValueError(
                "args.eval requires a validation set, but "
                "args.disable_eval_during_finetuning is set")
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        exit(0)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train, optimizer,
            device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
            log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
            num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
        )
        if args.output_dir and args.save_ckpt:
            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                utils.save_model(
                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
        if data_loader_val is not None:
            test_stats = evaluate(data_loader_val, model, device)
            print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            if max_accuracy < test_stats["acc1"]:
                max_accuracy = test_stats["acc1"]
                if args.output_dir and args.save_ckpt:
                    utils.save_model(
                        args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                        loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
            print(f'Max accuracy: {max_accuracy:.2f}%')
            if log_writer is not None:
                log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch)
                log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch)
                log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch)
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}
        else:
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}
        if args.output_dir and utils.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
    # Parse CLI options (plus the optional DeepSpeed initializer), make sure
    # the output directory exists, then hand control to main().
    opts, ds_init = get_args()
    out_dir = opts.output_dir
    if out_dir:
        Path(out_dir).mkdir(parents=True, exist_ok=True)
    main(opts, ds_init)
================================================
FILE: tools/run_linear.py
================================================
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import timm
assert timm.__version__ == "0.3.2" # version check
from timm.models.layers import trunc_normal_
import linear_util.misc as misc
from linear_util.pos_embed import interpolate_pos_embed
from linear_util.misc import NativeScalerWithGradNormCount as NativeScaler
from linear_util.lars import LARS
from linear_util.crop import RandomResizedCrop
import models.modeling_finetune as models_vit
from linear_util.engine_finetune import train_one_epoch, evaluate
def setup_for_distributed(rank):
    """
    Mute ``print`` on every process whose rank is not 0.

    Replaces ``builtins.print`` with a wrapper that forwards to the print that
    was installed at call time, but only when ``rank == 0``.
    """
    import builtins

    forward = builtins.print
    is_master = rank == 0

    def rank_zero_print(*args, **kwargs):
        if not is_master:
            return
        forward(*args, **kwargs)

    builtins.print = rank_zero_print
def get_args_parser():
    """Build the argument parser for the linear-probing script.

    The parser is created with ``add_help=False``. Several boolean options come
    in on/off pairs sharing one ``dest`` with an explicit default:
    ``--pin_mem`` / ``--no_pin_mem`` (default True),
    ``--use_mean_pooling`` / ``--use_cls`` (default True),
    ``--rel_pos_bias`` / ``--disable_rel_pos_bias`` (default True) and
    ``--sin_pos_emb`` / ``--disable_sin_pos_emb`` (default True).
    """
    parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False)
    parser.add_argument('--batch_size', default=512, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=90, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay (default: 0 for linear probe following MoCo v1)')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=0.1, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR')

    # * Finetuning params
    parser.add_argument('--finetune', default='',
                        help='finetune from checkpoint')

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--nb_classes', default=1000, type=int,
                        help='number of the classification types')
    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--dist_eval', action='store_true', default=False,
                        help='Enabling distributed evaluation (recommended during training for faster monitor)')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # Backbone regularisation / architecture
    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
                        help='Attention dropout rate (default: 0.)')
    parser.add_argument('--drop_path', type=float, default=0, metavar='PCT',
                        help='Drop path rate (default: 0)')
    parser.add_argument('--init_scale', default=0.001, type=float)
    # --use_mean_pooling is already the default; --use_cls switches it off.
    parser.add_argument('--use_mean_pooling', action='store_true')
    parser.set_defaults(use_mean_pooling=True)
    parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
    parser.add_argument('--rel_pos_bias', action='store_true')
    parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')
    parser.set_defaults(rel_pos_bias=True)
    parser.add_argument('--abs_pos_emb', action='store_true')
    parser.set_defaults(abs_pos_emb=False)
    parser.add_argument('--sin_pos_emb', action='store_true')
    parser.set_defaults(sin_pos_emb=True)
    parser.add_argument('--disable_sin_pos_emb', action='store_false', dest='sin_pos_emb')
    parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
                        help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")
    parser.add_argument('--enable_linear_eval', action='store_true', default=False)
    parser.add_argument('--exp_name', default='', type=str,
                        help='name of exp. it is helpful when save the checkpoint')
    parser.add_argument('--save_freq', default=50, type=int,
                        help='freq of saving models')
    parser.add_argument('--linear_type', default='standard', type=str,
                        help='standard or attentive')
    return parser
def _clean_checkpoint_keys(checkpoint_model, state_dict):
    """Strip CAE pre-training entries and mismatched heads from a checkpoint.

    Mutates ``checkpoint_model`` in place and returns it:

    * keys containing ``teacher`` or ``decoder`` are dropped;
    * remaining keys containing ``encoder.`` have every ``encoder.`` removed;
    * ``head.weight`` / ``head.bias`` are dropped when their shape differs from
      ``state_dict``. This check runs after renaming, so ``encoder.head.*``
      entries are covered too.
    """
    for key in list(checkpoint_model.keys()):
        # Drop before renaming, so a key matching both rules (e.g.
        # 'teacher.encoder.x') is removed once instead of popped twice.
        if 'teacher' in key or 'decoder' in key:
            checkpoint_model.pop(key)
        elif 'encoder.' in key:
            new_key = key.replace('encoder.', '')
            checkpoint_model[new_key] = checkpoint_model.pop(key)
    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]
    return checkpoint_model


def _resize_rel_pos_bias(model, checkpoint_model, use_rel_pos_bias):
    """Fit the relative-position-bias entries of ``checkpoint_model`` to ``model``.

    * If ``use_rel_pos_bias`` is set and the checkpoint holds a shared
      ``rel_pos_bias.relative_position_bias_table``, that table is cloned into
      every ``blocks.<i>.attn.relative_position_bias_table`` and the shared
      entry is removed.
    * Every key containing ``relative_position_index`` is removed.
    * If ``use_rel_pos_bias`` is set, each ``relative_position_bias_table``
      whose grid side differs from the model's is resampled per head with 2-D
      cubic interpolation over geometrically spaced source positions. The
      trailing ``num_extra_tokens`` rows are kept unchanged.

    Raises:
        NotImplementedError: if the model's patch grid is not square.
    """
    shared_key = "rel_pos_bias.relative_position_bias_table"
    if use_rel_pos_bias and shared_key in checkpoint_model:
        print("Expand the shared relative position embedding to each transformer block. ")
        rel_pos_bias = checkpoint_model[shared_key]
        for i in range(model.get_num_layers()):
            checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone()
        checkpoint_model.pop(shared_key)

    def geometric_progression(a, r, n):
        return a * (1.0 - r ** n) / (1.0 - r)

    model_state = model.state_dict()
    for key in list(checkpoint_model.keys()):
        if "relative_position_index" in key:
            checkpoint_model.pop(key)
        if "relative_position_bias_table" not in key or not use_rel_pos_bias:
            continue
        rel_pos_bias = checkpoint_model[key]
        src_num_pos, num_attn_heads = rel_pos_bias.size()
        dst_num_pos, _ = model_state[key].size()
        dst_patch_shape = model.patch_embed.patch_shape
        if dst_patch_shape[0] != dst_patch_shape[1]:
            raise NotImplementedError()
        num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
        src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
        dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
        if src_size == dst_size:
            continue

        # scipy is not imported at file level; it is only needed on this path.
        # NOTE(review): interp2d was removed in SciPy 1.14; pin an older SciPy.
        from scipy import interpolate

        print("Position interpolate for %s from %dx%d to %dx%d" % (
            key, src_size, src_size, dst_size, dst_size))
        # Split at an explicit index: with num_extra_tokens == 0, ``[-0:]``
        # would select every row and ``[:-0]`` none of them.
        split = src_num_pos - num_extra_tokens
        extra_tokens = rel_pos_bias[split:, :]
        rel_pos_bias = rel_pos_bias[:split, :]

        # Bisect for the ratio q whose geometric progression over
        # src_size // 2 terms just reaches dst_size // 2.
        left, right = 1.01, 1.5
        while right - left > 1e-6:
            q = (left + right) / 2.0
            if geometric_progression(1, q, src_size // 2) > dst_size // 2:
                right = q
            else:
                left = q

        dis = []
        cur = 1
        for i in range(src_size // 2):
            dis.append(cur)
            cur += q ** (i + 1)
        r_ids = [-d for d in reversed(dis)]
        x = r_ids + [0] + dis
        y = r_ids + [0] + dis
        t = dst_size // 2.0
        dx = np.arange(-t, t + 0.1, 1.0)
        dy = np.arange(-t, t + 0.1, 1.0)
        print("Original positions = %s" % str(x))
        print("Target positions = %s" % str(dx))

        all_rel_pos_bias = []
        for i in range(num_attn_heads):
            z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
            f = interpolate.interp2d(x, y, z, kind='cubic')
            all_rel_pos_bias.append(
                torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
        rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
        checkpoint_model[key] = torch.cat((rel_pos_bias, extra_tokens), dim=0)


def main(args):
    """Linear-probe a pre-trained ViT on an ImageFolder dataset.

    Reads ``train`` and ``val`` sub-folders of ``args.data_path``. Unless
    ``args.eval`` is set, weights from ``args.finetune`` (its ``'model'``
    entry) are loaded. The head is then replaced by a non-affine
    ``BatchNorm1d`` followed by the original linear head, and only the head is
    trained, using LARS. Checkpoints and a JSON-lines ``log.txt`` go to
    ``args.output_dir``. With ``args.eval`` set, it evaluates once and exits.
    """
    misc.init_distributed_mode(args)
    # Gate printing on the global rank. args.local_rank keeps its -1 default
    # unless a launcher passes --local_rank, which silenced every process.
    setup_for_distributed(misc.get_rank())

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # linear probe: weak augmentation
    transform_train = transforms.Compose([
        RandomResizedCrop(224, interpolation=3),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    transform_val = transforms.Compose([
        transforms.Resize(256, interpolation=3),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
    dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val)
    print(dataset_train)
    print(dataset_val)

    # Distributed samplers are used unconditionally, also for a single process.
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    if global_rank == 0 and args.log_dir is not None and not args.eval:
        os.makedirs(args.log_dir, exist_ok=True)
    # TensorBoard logging is not wired up in this script.
    log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    model = models_vit.__dict__[args.model](
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        attn_drop_rate=args.attn_drop_rate,
        use_mean_pooling=args.use_mean_pooling,
        init_scale=args.init_scale,
        use_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
        lin_probe=args.enable_linear_eval,
        args=args,
    )

    if args.finetune and not args.eval:
        checkpoint = torch.load(args.finetune, map_location='cpu')
        print("Load pre-trained checkpoint from: %s" % args.finetune)
        checkpoint_model = checkpoint['model']
        _clean_checkpoint_keys(checkpoint_model, model.state_dict())
        _resize_rel_pos_bias(model, checkpoint_model, args.rel_pos_bias)
        # interpolate position embedding
        interpolate_pos_embed(model, checkpoint_model)
        # load pre-trained model
        msg = model.load_state_dict(checkpoint_model, strict=False)
        print(msg)
        trunc_normal_(model.head.weight, std=0.01)

    # Prepend a non-affine BatchNorm to the head and train only the head.
    model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head)
    for p in model.parameters():
        p.requires_grad = False
    requires_grad = []
    for nname, p in model.head.named_parameters():
        p.requires_grad = True
        requires_grad.append(nname)
    print('require grad parameter: ', requires_grad)

    model.to(device)

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Model = %s" % str(model_without_ddp))
    print('number of params (M): %.2f' % (n_parameters / 1.e6))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256
    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)
    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    print(optimizer)
    loss_scaler = NativeScaler()
    criterion = torch.nn.CrossEntropyLoss()
    print("criterion = %s" % str(criterion))

    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        exit(0)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            max_norm=None,
            log_writer=log_writer,
            args=args
        )
        if args.output_dir and (epoch % args.save_freq == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)

        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        if log_writer is not None:
            log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
            log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
            log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
    # Build the CLI, parse it, ensure the output directory exists, then run.
    args = get_args_parser().parse_args()
    if args.output_dir:
        out_path = Path(args.output_dir)
        out_path.mkdir(parents=True, exist_ok=True)
    main(args)
================================================
FILE: tools/run_pretraining.py
================================================
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import os
import shutil
from pathlib import Path
from timm.models import create_model
from furnace.optim_factory import create_optimizer
from furnace.datasets import build_cae_pretraining_dataset
from furnace.engine_for_pretraining import train_one_epoch
from furnace.utils import NativeScalerWithGradNormCount as NativeScaler
import furnace.utils as utils
from models import modeling_cae
import torch.distributed as dist
def get_args():
    """Build the CAE pre-training argument parser and parse ``sys.argv``.

    The parser is created with ``add_help=False``, so ``-h/--help`` is not
    registered. ``--auto_resume`` and ``--pin_mem`` default to True and are
    switched off with ``--no_auto_resume`` / ``--no_pin_mem``.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser('pre-training script', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--save_ckpt_freq', default=50, type=int)
    parser.add_argument("--discrete_vae_weight_path", type=str)
    parser.add_argument("--discrete_vae_type", type=str, default="dall-e", help='[dall-e, vqgan_gumbel_f8_8192, customized]')
    parser.add_argument('--dvae_num_layers', default=3, type=int)

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--rel_pos_bias', action='store_true', default=False)
    parser.add_argument('--abs_pos_emb', action='store_true', default=False)
    parser.add_argument('--sincos_pos_emb', action='store_true', default=False)
    parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
                        help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")
    parser.add_argument('--input_size', default=224, type=int,
                        help='images input size for backbone')
    parser.add_argument('--second_input_size', default=112, type=int,
                        help='images input size for discrete vae')
    parser.add_argument('--drop_path', type=float, default=0, metavar='PCT',
                        help='Drop path rate (default: 0)')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
        weight decay. We use a cosine schedule for WD.
        (Set the same value with args.weight_decay to keep weight decay no change)""")
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    # NOTE: this option counts optimizer steps, not epochs (the old help text
    # was copy-pasted from --warmup_epochs).
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
                        help='steps to warmup LR, if scheduler supports')

    # Augmentation parameters
    parser.add_argument('--train_interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
    parser.add_argument('--second_interpolation', type=str, default='lanczos',
                        help='Interpolation for discrete vae (random, bilinear, bicubic default: "lanczos")')

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default=None,
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    # --auto_resume / --no_auto_resume share one destination; default is on.
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=True)
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    # --pin_mem / --no_pin_mem share one destination; default is on.
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--exp_name', default='', type=str, help='it is used when save the checkpoint')
    parser.add_argument('--enable_multi_print', action='store_true', default=False, help='allow each gpu to print something')

    # ---------------------------------------------------------------------
    # Data augmentation
    # ---------------------------------------------------------------------
    # crop size
    parser.add_argument('--crop_min_size', type=float, default=0.08, help='min size of crop')
    parser.add_argument('--crop_max_size', type=float, default=1.0, help='max size of crop')
    # color jitter
    parser.add_argument('--color_jitter', type=float, default=0, metavar='PCT', help='Color jitter factor (default: 0)')

    # ---------------------------------------------------------------------
    # Mask strategy
    # ---------------------------------------------------------------------
    parser.add_argument('--mask_generator', default='block', type=str,
                        help='block or random')
    # 1. if use block mask, set the num_mask_patches
    parser.add_argument('--num_mask_patches', default=98, type=int,
                        help='number of the visual tokens/patches to be masked')
    parser.add_argument('--max_mask_patches_per_block', type=int, default=None)
    parser.add_argument('--min_mask_patches_per_block', type=int, default=16)
    # 2. if use random mask, set the mask ratio
    parser.add_argument('--ratio_mask_patches', default=None, type=float, help="mask ratio")

    # ---------------------------------------------------------------------
    # CAE hyper-parameters
    # ---------------------------------------------------------------------
    parser.add_argument('--regressor_depth', default=4, type=int, help='depth of the regressor')
    parser.add_argument('--decoder_depth', default=4, type=int, help='depth of the decoder')
    parser.add_argument('--decoder_embed_dim', default=768, type=int,
                        help='dimensionality of embeddings for decoder')
    parser.add_argument('--decoder_num_heads', default=12, type=int,
                        help='Number of heads for decoder')
    parser.add_argument('--decoder_num_classes', default=8192, type=int,
                        help='Number of classes for decoder')
    parser.add_argument('--decoder_layer_scale_init_value', default=0.1, type=float,
                        help='decoder layer scale init value')
    # alignment constraint
    parser.add_argument('--align_loss_weight', type=float, default=2, help='loss weight for the alignment constraint')
    parser.add_argument('--base_momentum', type=float, default=0, help='ema weight for the dual path network')
    # init func, borrowed from BEiT
    parser.add_argument('--fix_init_weight', action='store_true', default=False, help='if true, the fix_init_weight() func will be activated')

    return parser.parse_args()
def get_model(args):
    """Instantiate the pre-training model registered under ``args.model``.

    The model is built from scratch (``pretrained=False``). Drop-path rate,
    the absolute-position-embedding flag and the layer-scale init value are
    taken from ``args``. The whole ``args`` namespace is also forwarded,
    presumably so the registered constructor can read further options.
    """
    model_name = args.model
    print(f"Creating model: {model_name}")
    build_kwargs = dict(
        pretrained=False,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
        args=args,
    )
    return create_model(model_name, **build_kwargs)
def main(args):
    """Run CAE pre-training end to end.

    Steps:
    - Sets up (optionally distributed) execution.
    - Builds the model, the pre-training dataset and loader, the discrete VAE,
      the optimizer, the loss scaler and the per-step LR / WD schedules.
    - Optionally resumes from a checkpoint.
    - Trains epochs ``args.start_epoch .. args.epochs - 1``. When
      ``args.output_dir`` is set, a checkpoint is saved every
      ``args.save_ckpt_freq`` epochs and at the last epoch. The main process
      also appends per-epoch stats as JSON lines to ``<output_dir>/log.txt``.

    Args:
        args: namespace from ``get_args()``. This function writes
            ``args.window_size`` and ``args.patch_size`` before the dataset is
            built, and fills ``args.weight_decay_end`` when it is None.
            ``args.distributed`` / ``args.gpu`` / ``args.start_epoch`` are
            presumably set by ``utils.init_distributed_mode`` /
            ``utils.auto_load_model`` -- TODO confirm.
    """
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so ranks differ)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    model = get_model(args)
    patch_size = model.encoder.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
    args.patch_size = patch_size

    # get dataset
    dataset_train = build_cae_pretraining_dataset(args)

    # prepare discrete vae
    d_vae = utils.create_d_vae(
        weight_path=args.discrete_vae_weight_path, d_vae_type=args.discrete_vae_type,
        device=device, image_size=args.second_input_size, args=args)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    # A DistributedSampler is used even for single-process runs. Because
    # num_replicas and rank are passed explicitly, it does not need an
    # initialized process group. (The old `if True: ... else: RandomSampler`
    # branch was dead code.)
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    # Take the step count from the loader itself. The sampler yields
    # ceil(len(dataset) / num_tasks) samples per rank, and drop_last=True
    # floors that by batch_size. The former formula
    # len(dataset) // batch_size // num_tasks can be smaller than this, which
    # would make the per-step LR/WD schedules shorter than the real number of
    # iterations.
    num_training_steps_per_epoch = len(data_loader_train)

    model.to(device)
    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params:', n_parameters)

    total_batch_size = args.batch_size * utils.get_world_size()
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Number of training steps = %d" % num_training_steps_per_epoch)
    print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch))

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    optimizer = create_optimizer(
        args, model_without_ddp)
    loss_scaler = NativeScaler()

    print("Use step level LR & WD scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))

    utils.auto_load_model(
        args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        # The sampler is always a DistributedSampler (see above), so the epoch
        # must be set on every run, not only distributed ones. Otherwise the
        # shuffle uses the same seed, and hence the same order, every epoch.
        data_loader_train.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch)

        train_stats = train_one_epoch(
            model, d_vae, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, log_writer=log_writer,
            start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values,
            wd_schedule_values=wd_schedule_values,
            args=args,
        )
        if args.output_dir:
            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                utils.save_model(
                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch, exp_name=args.exp_name)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
    # Script entry point: parse options, create the output directory if one
    # was requested, then launch pre-training.
    parsed_opts = get_args()
    out_dir = parsed_opts.output_dir
    if out_dir:
        Path(out_dir).mkdir(exist_ok=True, parents=True)
    main(parsed_opts)