Repository: lxtGH/CAE Branch: master Commit: d72597143e48 Files: 236 Total size: 1.2 MB Directory structure: gitextract_wj4p2u9n/ ├── .gitignore ├── .gitignore.swp ├── README.md ├── dall_e/ │ ├── __init__.py │ ├── decoder.py │ ├── encoder.py │ └── utils.py ├── downstream_tasks/ │ ├── detection/ │ │ ├── README.md │ │ ├── evaluation/ │ │ │ └── object_detection/ │ │ │ ├── configs/ │ │ │ │ ├── _base_/ │ │ │ │ │ ├── datasets/ │ │ │ │ │ │ └── coco_instance.py │ │ │ │ │ ├── default_runtime.py │ │ │ │ │ ├── models/ │ │ │ │ │ │ ├── cascade_mask_rcnn_r50_fpn.py │ │ │ │ │ │ ├── cascade_mask_rcnn_swin_fpn.py │ │ │ │ │ │ ├── cascade_mask_rcnn_vit_fpn.py │ │ │ │ │ │ ├── mask_rcnn_r50_fpn.py │ │ │ │ │ │ └── mask_rcnn_vit_fpn.py │ │ │ │ │ └── schedules/ │ │ │ │ │ └── schedule_1x.py │ │ │ │ └── mask_rcnn/ │ │ │ │ ├── vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py │ │ │ │ └── vit_large_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00002_lrdr0.85_dp0.2.py │ │ │ ├── mmcv_custom/ │ │ │ │ ├── __init__.py │ │ │ │ ├── checkpoint.py │ │ │ │ ├── layer_decay_optimizer_constructor.py │ │ │ │ ├── prepare_rpe.py │ │ │ │ ├── register_backbone.py │ │ │ │ └── runner/ │ │ │ │ ├── __init__.py │ │ │ │ ├── checkpoint.py │ │ │ │ └── epoch_based_runner.py │ │ │ ├── test.py │ │ │ └── train.py │ │ ├── loader.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── head.py │ │ │ ├── swin_transformer.py │ │ │ └── vision_transformer.py │ │ ├── scripts/ │ │ │ ├── run_eval.sh │ │ │ ├── run_train_maskrcnn_vit_base.sh │ │ │ └── run_train_maskrcnn_vit_large.sh │ │ └── utils.py │ └── semantic_segmentation/ │ ├── README.md │ ├── backbone/ │ │ ├── beit.py │ │ ├── beit_fapn.py │ │ ├── cae.py │ │ ├── fapn.py │ │ └── mae.py │ ├── configs_local/ │ │ ├── _base_/ │ │ │ ├── datasets/ │ │ │ │ ├── ade20k.py │ │ │ │ ├── ade20k_640x640.py │ │ │ │ ├── chase_db1.py │ │ │ │ ├── cityscapes.py │ │ │ │ ├── cityscapes_769x769.py │ │ │ │ ├── coco-stuff10k.py │ │ │ │ ├── drive.py │ │ │ │ ├── hrf.py │ │ │ │ ├── pascal_context.py │ │ │ │ ├── pascal_voc12.py │ │ │ │ ├── pascal_voc12_aug.py │ │ │ │ └── stare.py │ │ │ ├── default_runtime.py │ │ │ ├── models/ │ │ │ │ ├── ann_r50-d8.py │ │ │ │ ├── apcnet_r50-d8.py │ │ │ │ ├── ccnet_r50-d8.py │ │ │ │ ├── cgnet.py │ │ │ │ ├── danet_r50-d8.py │ │ │ │ ├── deeplabv3_r50-d8.py │ │ │ │ ├── deeplabv3_unet_s5-d16.py │ │ │ │ ├── deeplabv3plus_r50-d8.py │ │ │ │ ├── dmnet_r50-d8.py │ │ │ │ ├── dnl_r50-d8.py │ │ │ │ ├── emanet_r50-d8.py │ │ │ │ ├── encnet_r50-d8.py │ │ │ │ ├── fast_scnn.py │ │ │ │ ├── fcn_hr18.py │ │ │ │ ├── fcn_r50-d8.py │ │ │ │ ├── fcn_unet_s5-d16.py │ │ │ │ ├── fpn_r50.py │ │ │ │ ├── gcnet_r50-d8.py │ │ │ │ ├── lraspp_m-v3-d8.py │ │ │ │ ├── nonlocal_r50-d8.py │ │ │ │ ├── ocrnet_hr18.py │ │ │ │ ├── ocrnet_r50-d8.py │ │ │ │ ├── pointrend_r50.py │ │ │ │ ├── psanet_r50-d8.py │ │ │ │ ├── pspnet_r50-d8.py │ │ │ │ ├── pspnet_unet_s5-d16.py │ │ │ │ ├── upernet_cae.py │ │ │ │ └── upernet_r50.py │ │ │ └── schedules/ │ │ │ ├── schedule_160k.py │ │ │ ├── schedule_20k.py │ │ │ ├── schedule_320k.py │ │ │ ├── schedule_40k.py │ │ │ └── schedule_80k.py │ │ ├── beit/ │ │ │ └── upernet_beit_base_12_512_slide_160k_ade20k_pt_4e-4.py │ │ ├── cae/ │ │ │ └── upernet/ │ │ │ ├── upernet_cae_base_12_512_slide_160k_ade20k_pt_1e-4.py │ │ │ ├── upernet_cae_base_12_512_slide_160k_ade20k_pt_2e-4.py │ │ │ ├── upernet_cae_base_12_512_slide_160k_ade20k_pt_3e-4.py │ │ │ └── upernet_cae_large_24_512_slide_160k_ade20k_pt_decay095_4e-5_dp015.py │ │ └── mae/ │ │ └── upernet_mae_large_12_512_slide_160k_ade20k_pt_4e-4.py │ ├── 
mmcv_custom/ │ │ ├── __init__.py │ │ ├── apex_runner/ │ │ │ ├── __init__.py │ │ │ ├── apex_iter_based_runner.py │ │ │ ├── checkpoint.py │ │ │ └── optimizer.py │ │ ├── checkpoint.py │ │ ├── checkpoint_beit.py │ │ ├── layer_decay_optimizer_constructor.py │ │ ├── resize_transform.py │ │ └── train_api.py │ ├── mmseg/ │ │ ├── __init__.py │ │ ├── apis/ │ │ │ ├── __init__.py │ │ │ ├── inference.py │ │ │ ├── test.py │ │ │ └── train.py │ │ ├── core/ │ │ │ ├── __init__.py │ │ │ ├── evaluation/ │ │ │ │ ├── __init__.py │ │ │ │ ├── class_names.py │ │ │ │ ├── eval_hooks.py │ │ │ │ └── metrics.py │ │ │ ├── seg/ │ │ │ │ ├── __init__.py │ │ │ │ ├── builder.py │ │ │ │ └── sampler/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base_pixel_sampler.py │ │ │ │ └── ohem_pixel_sampler.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ └── misc.py │ │ ├── datasets/ │ │ │ ├── __init__.py │ │ │ ├── ade.py │ │ │ ├── builder.py │ │ │ ├── chase_db1.py │ │ │ ├── cityscapes.py │ │ │ ├── coco_stuff.py │ │ │ ├── custom.py │ │ │ ├── dataset_wrappers.py │ │ │ ├── drive.py │ │ │ ├── hrf.py │ │ │ ├── pascal_context.py │ │ │ ├── pipelines/ │ │ │ │ ├── __init__.py │ │ │ │ ├── compose.py │ │ │ │ ├── formating.py │ │ │ │ ├── loading.py │ │ │ │ ├── test_time_aug.py │ │ │ │ └── transforms.py │ │ │ ├── stare.py │ │ │ └── voc.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── backbones/ │ │ │ │ ├── __init__.py │ │ │ │ ├── cgnet.py │ │ │ │ ├── fast_scnn.py │ │ │ │ ├── hrnet.py │ │ │ │ ├── mobilenet_v2.py │ │ │ │ ├── mobilenet_v3.py │ │ │ │ ├── resnest.py │ │ │ │ ├── resnet.py │ │ │ │ ├── resnext.py │ │ │ │ └── unet.py │ │ │ ├── builder.py │ │ │ ├── decode_heads/ │ │ │ │ ├── __init__.py │ │ │ │ ├── ann_head.py │ │ │ │ ├── apc_head.py │ │ │ │ ├── aspp_head.py │ │ │ │ ├── cascade_decode_head.py │ │ │ │ ├── cc_head.py │ │ │ │ ├── da_head.py │ │ │ │ ├── decode_head.py │ │ │ │ ├── dm_head.py │ │ │ │ ├── dnl_head.py │ │ │ │ ├── ema_head.py │ │ │ │ ├── enc_head.py │ │ │ │ ├── fcn_head.py │ │ │ │ ├── fpn_head.py │ │ │ │ ├── gc_head.py │ │ │ │ ├── lraspp_head.py │ │ │ │ ├── nl_head.py │ │ │ │ ├── ocr_head.py │ │ │ │ ├── point_head.py │ │ │ │ ├── psa_head.py │ │ │ │ ├── psp_head.py │ │ │ │ ├── sep_aspp_head.py │ │ │ │ ├── sep_fcn_head.py │ │ │ │ └── uper_head.py │ │ │ ├── losses/ │ │ │ │ ├── __init__.py │ │ │ │ ├── accuracy.py │ │ │ │ ├── cross_entropy_loss.py │ │ │ │ ├── lovasz_loss.py │ │ │ │ └── utils.py │ │ │ ├── necks/ │ │ │ │ ├── __init__.py │ │ │ │ └── fpn.py │ │ │ ├── segmentors/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── cascade_encoder_decoder.py │ │ │ │ └── encoder_decoder.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── inverted_residual.py │ │ │ ├── make_divisible.py │ │ │ ├── res_layer.py │ │ │ ├── se_layer.py │ │ │ ├── self_attention_block.py │ │ │ └── up_conv_block.py │ │ ├── ops/ │ │ │ ├── __init__.py │ │ │ ├── encoding.py │ │ │ └── wrappers.py │ │ ├── utils/ │ │ │ ├── __init__.py │ │ │ ├── collect_env.py │ │ │ └── logger.py │ │ └── version.py │ └── tools/ │ ├── dist_test.sh │ ├── dist_train.sh │ ├── test.py │ └── train.py ├── furnace/ │ ├── dataset_folder.py │ ├── datasets.py │ ├── engine_for_finetuning.py │ ├── engine_for_pretraining.py │ ├── masking_generator.py │ ├── optim_factory.py │ ├── transforms.py │ └── utils.py ├── linear_util/ │ ├── crop.py │ ├── datasets.py │ ├── engine_finetune.py │ ├── lars.py │ ├── lr_decay.py │ ├── lr_sched.py │ ├── misc.py │ └── pos_embed.py ├── models/ │ ├── modeling_cae.py │ ├── modeling_cae_helper.py │ ├── modeling_discrete_vae.py │ └── modeling_finetune.py ├── requirements.txt ├── scripts/ 
│ ├── cae_base_800e.sh │ ├── cae_base_finetune.sh │ ├── cae_large_1600e.sh │ └── cae_large_finetune.sh ├── tokenizer-weights/ │ └── README └── tools/ ├── run_attentive.py ├── run_class_finetuning.py ├── run_linear.py └── run_pretraining.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .DS_Store ================================================ FILE: README.md ================================================ # CAE: Context AutoEncoder for Self-Supervised Representation Learning

This is a PyTorch implementation of [CAE: Context AutoEncoder for Self-Supervised Representation Learning](https://arxiv.org/abs/2202.03026).

## Highlights

- State-of-the-art MIM performance. The results in the paper are successfully reproduced.

## Installation

Clone the repo and install the required packages.

```bash
pip install -r requirements.txt

# install apex
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```

## Data Preparation

First, download ImageNet-1k from http://image-net.org/. The directory structure is the standard layout of torchvision's `datasets.ImageFolder`. The training and validation data are expected to be in the `train/` folder and `val/` folder, respectively:

```
/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
```

Second, download the pretrained tokenizer.

```bash
TOKENIZER_PATH=/path/to/save/dall_e_tokenizer_weight
mkdir -p $TOKENIZER_PATH
wget -O $TOKENIZER_PATH/encoder.pkl https://cdn.openai.com/dall-e/encoder.pkl
wget -O $TOKENIZER_PATH/decoder.pkl https://cdn.openai.com/dall-e/decoder.pkl
```

## Pretraining

Here is an example that pretrains CAE-base on ImageNet-1K with 32 GPUs. Please see [scripts/cae_base_800e.sh](scripts/cae_base_800e.sh) for the complete script.

```bash
OMP_NUM_THREADS=1 $PYTHON -m torch.distributed.launch \
    --nproc_per_node=8 \
    tools/run_pretraining.py \
    --data_path ${DATA_PATH} \
    --output_dir ${OUTPUT_DIR} \
    --model cae_base_patch16_224_8k_vocab --discrete_vae_weight_path ${TOKENIZER_PATH} \
    --batch_size 64 --lr 1.5e-3 --warmup_epochs 20 --epochs 800 \
    --clip_grad 3.0 --layer_scale_init_value 0.1 \
    --imagenet_default_mean_and_std \
    --color_jitter 0 \
    --drop_path 0.1 \
    --sincos_pos_emb \
    --mask_generator block \
    --num_mask_patches 98 \
    --decoder_layer_scale_init_value 0.1 \
    --no_auto_resume \
    --save_ckpt_freq 100 \
    --exp_name $my_name \
    --regressor_depth 4 \
    --decoder_depth 4 \
    --align_loss_weight 2
```

- `--num_mask_patches`: number of input patches to be masked (see the block-masking sketch at the end of this section).
- `--batch_size`: batch size per GPU. The effective batch size is `number of GPUs` * `--batch_size`, so in the above example it is `64*32 = 2048`.
- `--lr`: learning rate.
- `--warmup_epochs`: learning rate warmup epochs. Warm up for [10, 20, 40] epochs for [300, 800, 1600] pretraining epochs, respectively.
- `--epochs`: total pretraining epochs.
- `--clip_grad`: clip gradient norm.
- `--drop_path`: stochastic depth rate.
- `--imagenet_default_mean_and_std`: enable this for ImageNet-1k pretraining, i.e., `(0.485, 0.456, 0.406)` for mean and `(0.229, 0.224, 0.225)` for std. For other pretraining data, use `(0.5, 0.5, 0.5)` for both mean and std by default.
- `--layer_scale_init_value`: 0.1 for base, 1e-5 for large; set 0 to disable LayerScale. We set `--decoder_layer_scale_init_value` to the same value.
- `--sincos_pos_emb`: adopt the sin-cos positional embedding during pretraining.
- `--regressor_depth`: depth (number of layers) of the regressor.
- `--decoder_depth`: depth (number of layers) of the decoder.
- `--align_loss_weight`: weight of the alignment loss, 2 by default.

For CAE-large, please refer to [scripts/cae_large_1600e.sh](scripts/cae_large_1600e.sh).
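With `--mask_generator block` and `--num_mask_patches 98`, roughly half of the 14x14 = 196 patch positions of a 224x224 image are masked by stamping random rectangular blocks, in the spirit of BEiT. The following is a minimal, self-contained sketch of that idea; the function name `block_mask` and its defaults (`min_block`, the aspect-ratio range) are illustrative assumptions, and the actual implementation lives in `furnace/masking_generator.py`.

```python
import math
import random

import numpy as np


def block_mask(grid=14, num_mask=98, min_block=16, max_aspect=1 / 0.3):
    """Hypothetical sketch: mask a grid x grid patch grid by stamping
    random rectangles until at least num_mask patches are masked."""
    mask = np.zeros((grid, grid), dtype=bool)
    while mask.sum() < num_mask:
        # sample a block area (in patches) and a log-uniform aspect ratio
        area = random.uniform(min_block, num_mask - mask.sum() + min_block)
        log_a = random.uniform(math.log(1 / max_aspect), math.log(max_aspect))
        aspect = math.exp(log_a)
        h = min(grid, max(1, int(round(math.sqrt(area * aspect)))))
        w = min(grid, max(1, int(round(math.sqrt(area / aspect)))))
        top = random.randint(0, grid - h)
        left = random.randint(0, grid - w)
        mask[top:top + h, left:left + w] = True  # blocks may overlap
    return mask  # the total can slightly exceed num_mask; real code trims it


if __name__ == "__main__":
    m = block_mask()
    print(m.sum(), "of", m.size, "patches masked")  # roughly 98 of 196 (~50%)
```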
## Results

Below are the results of CAE-base/CAE-large on the following evaluation tasks:

- Linear probing
- Attentive probing
- Fine-tuning
- Semantic segmentation
- Object detection and instance segmentation

Pretrained weights and logs are available ([Google Drive](https://drive.google.com/drive/folders/1wwhg7nj2GQuU9uthVuQLkEEXEjx90G7g?usp=sharing), [Baidu Cloud [Code: 4kil]](https://pan.baidu.com/s/15eZGoI72iLupLrOHqmOM9w)). *: from the CAE paper.

| Model | Pretraining data | #Epoch | Linear | Attentive | Fine-tuning | ADE Seg | COCO Det | COCO InstSeg |
| ---------- | ---------------- | ------ | ------ | --------- | ----------- | ------- | -------- | ------------ |
| MAE-base* | ImageNet-1K | 1600 | 67.8 | 74.2 | 83.6 | 48.1 | 48.4 | 42.6 |
| MAE-large* | ImageNet-1K | 1600 | 76.0 | 78.8 | 86.0 | 53.6 | 54.0 | 47.1 |
| CAE-base | ImageNet-1K | 300 | 64.5 | 74.0 | 83.6 | 48.1 | 48.3 | 42.7 |
| CAE-base | ImageNet-1K | 800 | 68.9 | 75.9 | 83.8 | 49.7 | 49.9 | 43.9 |
| CAE-base | ImageNet-1K | 1600 | 70.3 | 77.2 | 83.9 | 50.3 | 50.3 | 44.2 |
| CAE-large | ImageNet-1K | 1600 | 77.8 | 81.2 | 86.2 | 54.9 | 54.5 | 47.5 |

### Linear Probing

- Please refer to [scripts/cae_base_800e.sh](scripts/cae_base_800e.sh) (32 GPUs).
- For CAE-large, just replace `--model cae_base_patch16_224` with `--model cae_large_patch16_224`.

### Attentive Probing

- Please refer to [scripts/cae_base_800e.sh](scripts/cae_base_800e.sh) (32 GPUs).
- For CAE-large, just replace `--model cae_base_patch16_224` with `--model cae_large_patch16_224`.

### Fine-tuning

- Please refer to [scripts/cae_base_finetune.sh](scripts/cae_base_finetune.sh) (32 GPUs).
- For CAE-large, please refer to [scripts/cae_large_finetune.sh](scripts/cae_large_finetune.sh) (32 GPUs).

### Segmentation & Detection

- Please refer to the [downstream_tasks](./downstream_tasks) directory to get started.

## Acknowledgement

This repository is built on [BEiT](https://github.com/microsoft/unilm/tree/master/beit) and [MMSelfSup](https://github.com/open-mmlab/mmselfsup); thanks for their open-source code! Thanks also to the CAE authors for their excellent work!
## Citation

```bibtex
@article{ContextAutoencoder2022,
  title={Context Autoencoder for Self-Supervised Representation Learning},
  author={Chen, Xiaokang and Ding, Mingyu and Wang, Xiaodi and Xin, Ying and Mo, Shentong and Wang, Yunhao and Han, Shumin and Luo, Ping and Zeng, Gang and Wang, Jingdong},
  journal={arXiv preprint arXiv:2202.03026},
  year={2022}
}
```

================================================ FILE: dall_e/__init__.py ================================================ import io, requests import torch import torch.nn as nn from dall_e.encoder import Encoder from dall_e.decoder import Decoder from dall_e.utils import map_pixels, unmap_pixels def load_model(path: str, device: torch.device = None) -> nn.Module: if path.startswith('http://') or path.startswith('https://'): resp = requests.get(path) resp.raise_for_status() with io.BytesIO(resp.content) as buf: return torch.load(buf, map_location=device) else: with open(path, 'rb') as f: return torch.load(f, map_location=device) ================================================ FILE: dall_e/decoder.py ================================================ import attr import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from collections import OrderedDict from functools import partial from dall_e.utils import Conv2d @attr.s(eq=False, repr=False) class DecoderBlock(nn.Module): n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 ==0) n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1) device: torch.device = attr.ib(default=None) requires_grad: bool = attr.ib(default=False) def __attrs_post_init__(self) -> None: super().__init__() self.n_hid = self.n_out // 4 self.post_gain = 1 / (self.n_layers ** 2) make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad) self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity() self.res_path = nn.Sequential(OrderedDict([ ('relu_1', nn.ReLU()), ('conv_1', make_conv(self.n_in, self.n_hid, 1)), ('relu_2', nn.ReLU()), ('conv_2', make_conv(self.n_hid, self.n_hid, 3)), ('relu_3', nn.ReLU()), ('conv_3', make_conv(self.n_hid, self.n_hid, 3)), ('relu_4', nn.ReLU()), ('conv_4', make_conv(self.n_hid, self.n_out, 3)),])) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.id_path(x) + self.post_gain * self.res_path(x) @attr.s(eq=False, repr=False) class Decoder(nn.Module): group_count: int = 4 n_init: int = attr.ib(default=128, validator=lambda i, a, x: x >= 8) n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64) n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1) output_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1) vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512) device: torch.device = attr.ib(default=torch.device('cpu')) requires_grad: bool = attr.ib(default=False) use_mixed_precision: bool = attr.ib(default=True) def __attrs_post_init__(self) -> None: super().__init__() blk_range = range(self.n_blk_per_group) n_layers = self.group_count * self.n_blk_per_group make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad) make_blk = partial(DecoderBlock, n_layers=n_layers, device=self.device, requires_grad=self.requires_grad) self.blocks = nn.Sequential(OrderedDict([ ('input', make_conv(self.vocab_size, self.n_init, 1, use_float16=False)), ('group_1', nn.Sequential(OrderedDict([ *[(f'block_{i + 1}', make_blk(self.n_init if i == 0 else 8
* self.n_hid, 8 * self.n_hid)) for i in blk_range], ('upsample', nn.Upsample(scale_factor=2, mode='nearest')), ]))), ('group_2', nn.Sequential(OrderedDict([ *[(f'block_{i + 1}', make_blk(8 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range], ('upsample', nn.Upsample(scale_factor=2, mode='nearest')), ]))), ('group_3', nn.Sequential(OrderedDict([ *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range], ('upsample', nn.Upsample(scale_factor=2, mode='nearest')), ]))), ('group_4', nn.Sequential(OrderedDict([ *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 1 * self.n_hid, 1 * self.n_hid)) for i in blk_range], ]))), ('output', nn.Sequential(OrderedDict([ ('relu', nn.ReLU()), ('conv', make_conv(1 * self.n_hid, 2 * self.output_channels, 1)), ]))), ])) def forward(self, x: torch.Tensor) -> torch.Tensor: if len(x.shape) != 4: raise ValueError(f'input shape {x.shape} is not 4d') if x.shape[1] != self.vocab_size: raise ValueError(f'input has {x.shape[1]} channels but model built for {self.vocab_size}') if x.dtype != torch.float32: raise ValueError('input must have dtype torch.float32') return self.blocks(x) ================================================ FILE: dall_e/encoder.py ================================================ import attr import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from collections import OrderedDict from functools import partial from dall_e.utils import Conv2d @attr.s(eq=False, repr=False) class EncoderBlock(nn.Module): n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 ==0) n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1) device: torch.device = attr.ib(default=None) requires_grad: bool = attr.ib(default=False) def __attrs_post_init__(self) -> None: super().__init__() self.n_hid = self.n_out // 4 self.post_gain = 1 / (self.n_layers ** 2) make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad) self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity() self.res_path = nn.Sequential(OrderedDict([ ('relu_1', nn.ReLU()), ('conv_1', make_conv(self.n_in, self.n_hid, 3)), ('relu_2', nn.ReLU()), ('conv_2', make_conv(self.n_hid, self.n_hid, 3)), ('relu_3', nn.ReLU()), ('conv_3', make_conv(self.n_hid, self.n_hid, 3)), ('relu_4', nn.ReLU()), ('conv_4', make_conv(self.n_hid, self.n_out, 1)),])) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.id_path(x) + self.post_gain * self.res_path(x) @attr.s(eq=False, repr=False) class Encoder(nn.Module): group_count: int = 4 n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64) n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1) input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1) vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512) device: torch.device = attr.ib(default=torch.device('cpu')) requires_grad: bool = attr.ib(default=False) use_mixed_precision: bool = attr.ib(default=True) def __attrs_post_init__(self) -> None: super().__init__() blk_range = range(self.n_blk_per_group) n_layers = self.group_count * self.n_blk_per_group make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad) make_blk = partial(EncoderBlock, n_layers=n_layers, device=self.device, requires_grad=self.requires_grad) self.blocks = nn.Sequential(OrderedDict([ ('input', 
make_conv(self.input_channels, 1 * self.n_hid, 7)), ('group_1', nn.Sequential(OrderedDict([ *[(f'block_{i + 1}', make_blk(1 * self.n_hid, 1 * self.n_hid)) for i in blk_range], ('pool', nn.MaxPool2d(kernel_size=2)), ]))), ('group_2', nn.Sequential(OrderedDict([ *[(f'block_{i + 1}', make_blk(1 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range], ('pool', nn.MaxPool2d(kernel_size=2)), ]))), ('group_3', nn.Sequential(OrderedDict([ *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range], ('pool', nn.MaxPool2d(kernel_size=2)), ]))), ('group_4', nn.Sequential(OrderedDict([ *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range], ]))), ('output', nn.Sequential(OrderedDict([ ('relu', nn.ReLU()), ('conv', make_conv(8 * self.n_hid, self.vocab_size, 1, use_float16=False)), ]))), ])) def forward(self, x: torch.Tensor) -> torch.Tensor: if len(x.shape) != 4: raise ValueError(f'input shape {x.shape} is not 4d') if x.shape[1] != self.input_channels: raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}') if x.dtype != torch.float32: raise ValueError('input must have dtype torch.float32') return self.blocks(x) ================================================ FILE: dall_e/utils.py ================================================ import attr import math import torch import torch.nn as nn import torch.nn.functional as F logit_laplace_eps: float = 0.1 @attr.s(eq=False) class Conv2d(nn.Module): n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) n_out: int = attr.ib(validator=lambda i, a, x: x >= 1) kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1) use_float16: bool = attr.ib(default=True) device: torch.device = attr.ib(default=torch.device('cpu')) requires_grad: bool = attr.ib(default=False) def __attrs_post_init__(self) -> None: super().__init__() w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32, device=self.device, requires_grad=self.requires_grad) w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2)) b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device, requires_grad=self.requires_grad) self.w, self.b = nn.Parameter(w), nn.Parameter(b) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_float16 and 'cuda' in self.w.device.type: if x.dtype != torch.float16: x = x.half() w, b = self.w.half(), self.b.half() else: if x.dtype != torch.float32: x = x.float() w, b = self.w, self.b return F.conv2d(x, w, b, padding=(self.kw - 1) // 2) def map_pixels(x: torch.Tensor) -> torch.Tensor: if x.dtype != torch.float: raise ValueError('expected input to have type float') return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps def unmap_pixels(x: torch.Tensor) -> torch.Tensor: if len(x.shape) != 4: raise ValueError('expected input to be 4d') if x.dtype != torch.float: raise ValueError('expected input to have type float') return torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1) ================================================ FILE: downstream_tasks/detection/README.md ================================================ # COCO Detection and Instance segmentation with CAE # Installation Please install [PyTorch](https://pytorch.org/). This codebase has been developed with python version 3.6, PyTorch version 1.7.1, CUDA 11.0 and torchvision 0.8.2. 
To get the full dependencies, please run:

```bash
pip3 install -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.1/index.html mmcv-full==1.3.9
pip3 install pytest-runner scipy tensorboardX faiss-gpu==1.6.1 tqdm lmdb sklearn pyarrow==2.0.0 timm DALL-E munkres six einops

# install apex
pip3 install git+https://github.com/NVIDIA/apex \
    --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext"

# install mmdetection for object detection & instance segmentation
git clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
cd Swin-Transformer-Object-Detection
pip3 install -r requirements/build.txt
pip3 install -v -e .
cd ..
```

## Fine-tuning with Mask R-CNN

#### We use 16 GPUs for these experiments ($NNODES = 2).

- To train ViT-B/16 with Mask R-CNN as the task layer, run:

```bash
python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=$NNODES \
    --node_rank=$RANK \
    --master_addr=$ADDRESS \
    --master_port=$PORT \
    evaluation/object_detection/train.py evaluation/object_detection/configs/mask_rcnn/vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py \
    --launcher pytorch \
    --work-dir $OUTPUT_DIR \
    --no-validate \
    --deterministic \
    --cfg-options model.backbone.use_checkpoint=True \
    model.pretrained=$PRETRAINED \
    ${@:6}
```

- To train ViT-L/16 with Mask R-CNN as the task layer, run:

```bash
python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=$NNODES \
    --node_rank=$RANK \
    --master_addr=$ADDRESS \
    --master_port=$PORT \
    evaluation/object_detection/train.py evaluation/object_detection/configs/mask_rcnn/vit_large_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00002_lrdr0.85_dp0.2.py \
    --launcher pytorch \
    --work-dir $OUTPUT_DIR \
    --no-validate \
    --deterministic \
    --cfg-options model.backbone.use_checkpoint=True \
    model.pretrained=$PRETRAINED \
    ${@:6}
```

- To evaluate Mask R-CNN, run:

```bash
python -m torch.distributed.launch --nproc_per_node=8 \
    evaluation/object_detection/test.py \
    $CONFIG \
    $MODEL \
    --launcher pytorch \
    --eval bbox segm \
    --cfg-options model.backbone.use_checkpoint=True \
    ${@:6}
```

See the sketch at the end of this README for how the `--cfg-options` overrides are applied.

## Results (pretrained models are trained on ImageNet-1K without labels)

| Backbone | #Pretrained Epoch | Object Det | Instance Seg |
| -------- | ----------------- | ---------- | ------------ |
| ViT-B | 300 | 48.3 | 42.7 |
| ViT-B | 800 | 49.9 | 43.9 |
| ViT-B | 1600 | 50.3 | 44.2 |
| ViT-L | 1600 | 54.5 | 47.5 |

## Acknowledgement

This repository is built on the [iBOT repository](https://github.com/bytedance/ibot). Thanks for their open-source code!
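The dotted `--cfg-options` pairs used above override fields of the loaded config. As a rough illustration, here is a minimal sketch assuming mmcv's `Config` API (this is not the repo's `train.py`, and the checkpoint path is a placeholder):

```python
from mmcv import Config

# load one of the configs shipped in this directory
cfg = Config.fromfile(
    'evaluation/object_detection/configs/mask_rcnn/'
    'vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py')

# equivalent to:
#   --cfg-options model.backbone.use_checkpoint=True model.pretrained=$PRETRAINED
cfg.merge_from_dict({
    'model.backbone.use_checkpoint': True,
    'model.pretrained': '/path/to/cae_pretrained.pth',  # placeholder path
})
print(cfg.model.backbone.use_checkpoint)  # -> True
```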
================================================ FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/datasets/coco_instance.py ================================================ dataset_type = 'CocoDataset' data_root = '/path/to/coco/' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), dict(type='RandomFlip', flip_ratio=0.5), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size_divisor=32), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), ] test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=(1333, 800), flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size_divisor=32), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ]) ] data = dict( samples_per_gpu=2, workers_per_gpu=2, train=dict( type=dataset_type, ann_file=data_root + 'annotations/instances_train2017.json', img_prefix=data_root + 'train2017/', pipeline=train_pipeline), val=dict( type=dataset_type, ann_file=data_root + 'annotations/instances_val2017.json', img_prefix=data_root + 'val2017/', pipeline=test_pipeline), test=dict( type=dataset_type, ann_file=data_root + 'annotations/instances_val2017.json', img_prefix=data_root + 'val2017/', pipeline=test_pipeline)) evaluation = dict(metric=['bbox', 'segm']) ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/default_runtime.py ================================================ checkpoint_config = dict(interval=1) # yapf:disable log_config = dict( interval=50, hooks=[ dict(type='TextLoggerHook'), # dict(type='TensorboardLoggerHook') ]) # yapf:enable custom_hooks = [dict(type='NumClassCheckHook')] dist_params = dict(backend='nccl') log_level = 'INFO' load_from = None resume_from = None workflow = [('train', 1)] ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py ================================================ # model settings model = dict( type='CascadeRCNN', backbone=dict( type='ResNet', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), frozen_stages=1, norm_cfg=dict(type='BN', requires_grad=True), norm_eval=True, style='pytorch', init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), neck=dict( type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5), rpn_head=dict( type='RPNHead', in_channels=256, feat_channels=256, anchor_generator=dict( type='AnchorGenerator', scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]), bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[.0, .0, .0, .0], target_stds=[1.0, 1.0, 1.0, 1.0]), loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), roi_head=dict( type='CascadeRoIHead', num_stages=3, stage_loss_weights=[1, 0.5, 0.25], bbox_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), bbox_head=[ dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024,
roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.1, 0.1, 0.2, 0.2]), reg_class_agnostic=True, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.05, 0.05, 0.1, 0.1]), reg_class_agnostic=True, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.033, 0.033, 0.067, 0.067]), reg_class_agnostic=True, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) ], mask_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), mask_head=dict( type='FCNMaskHead', num_convs=4, in_channels=256, conv_out_channels=256, num_classes=80, loss_mask=dict( type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), # model training and testing settings train_cfg=dict( rpn=dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, match_low_quality=True, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, debug=False), rpn_proposal=dict( nms_pre=2000, max_per_img=2000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0), rcnn=[ dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, match_low_quality=False, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False), dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.6, neg_iou_thr=0.6, min_pos_iou=0.6, match_low_quality=False, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False), dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.7, min_pos_iou=0.7, match_low_quality=False, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False) ]), test_cfg=dict( rpn=dict( nms_pre=1000, max_per_img=1000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0), rcnn=dict( score_thr=0.05, nms=dict(type='nms', iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5))) ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/cascade_mask_rcnn_swin_fpn.py ================================================ # model settings model = dict( type='CascadeRCNN', pretrained=None, backbone=dict( type='SwinTransformer', embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, ape=False, patch_norm=True, out_indices=(0, 1, 2, 3), 
use_checkpoint=False), neck=dict( type='FPN', in_channels=[96, 192, 384, 768], out_channels=256, num_outs=5), rpn_head=dict( type='RPNHead', in_channels=256, feat_channels=256, anchor_generator=dict( type='AnchorGenerator', scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]), bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[.0, .0, .0, .0], target_stds=[1.0, 1.0, 1.0, 1.0]), loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), roi_head=dict( type='CascadeRoIHead', num_stages=3, stage_loss_weights=[1, 0.5, 0.25], bbox_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), bbox_head=[ dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.1, 0.1, 0.2, 0.2]), reg_class_agnostic=True, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.05, 0.05, 0.1, 0.1]), reg_class_agnostic=True, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.033, 0.033, 0.067, 0.067]), reg_class_agnostic=True, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) ], mask_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), mask_head=dict( type='FCNMaskHead', num_convs=4, in_channels=256, conv_out_channels=256, num_classes=80, loss_mask=dict( type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), # model training and testing settings train_cfg = dict( rpn=dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, match_low_quality=True, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, debug=False), rpn_proposal=dict( nms_across_levels=False, nms_pre=2000, nms_post=2000, max_per_img=2000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0), rcnn=[ dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, match_low_quality=False, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False), dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.6, neg_iou_thr=0.6, min_pos_iou=0.6, match_low_quality=False, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False), dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.7, min_pos_iou=0.7, match_low_quality=False, ignore_iof_thr=-1), 
sampler=dict( type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False) ]), test_cfg = dict( rpn=dict( nms_across_levels=False, nms_pre=1000, nms_post=1000, max_per_img=1000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0), rcnn=dict( score_thr=0.05, nms=dict(type='nms', iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5))) ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/cascade_mask_rcnn_vit_fpn.py ================================================ # model settings model = dict( type='CascadeRCNN', pretrained=None, backbone=dict( type='VisionTransformer', img_size=[672, 1092], patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., qkv_bias=True, drop_path_rate=0.1, out_indices=(3, 5, 7, 11), use_checkpoint=False), neck=dict( type='FPN', in_channels=[384, 384, 384, 384], out_channels=256, num_outs=5), rpn_head=dict( type='RPNHead', in_channels=256, feat_channels=256, anchor_generator=dict( type='AnchorGenerator', scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]), bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[.0, .0, .0, .0], target_stds=[1.0, 1.0, 1.0, 1.0]), loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), roi_head=dict( type='CascadeRoIHead', num_stages=3, stage_loss_weights=[1, 0.5, 0.25], bbox_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), bbox_head=[ dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.1, 0.1, 0.2, 0.2]), reg_class_agnostic=True, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.05, 0.05, 0.1, 0.1]), reg_class_agnostic=True, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.033, 0.033, 0.067, 0.067]), reg_class_agnostic=True, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) ], mask_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), mask_head=dict( type='FCNMaskHead', num_convs=4, in_channels=256, conv_out_channels=256, num_classes=80, loss_mask=dict( type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), # model training and testing settings train_cfg = dict( rpn=dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, match_low_quality=True, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), allowed_border=0, pos_weight=-1, 
debug=False), rpn_proposal=dict( nms_across_levels=False, nms_pre=2000, nms_post=2000, max_per_img=2000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0), rcnn=[ dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, match_low_quality=False, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False), dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.6, neg_iou_thr=0.6, min_pos_iou=0.6, match_low_quality=False, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False), dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.7, min_pos_iou=0.7, match_low_quality=False, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False) ]), test_cfg = dict( rpn=dict( nms_across_levels=False, nms_pre=1000, nms_post=1000, max_per_img=1000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0), rcnn=dict( score_thr=0.05, nms=dict(type='nms', iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5))) ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/mask_rcnn_r50_fpn.py ================================================ # model settings model = dict( type='MaskRCNN', backbone=dict( type='ResNet', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), frozen_stages=1, norm_cfg=dict(type='BN', requires_grad=True), norm_eval=True, style='pytorch', init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), neck=dict( type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5), rpn_head=dict( type='RPNHead', in_channels=256, feat_channels=256, anchor_generator=dict( type='AnchorGenerator', scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]), bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[.0, .0, .0, .0], target_stds=[1.0, 1.0, 1.0, 1.0]), loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), loss_bbox=dict(type='L1Loss', loss_weight=1.0)), roi_head=dict( type='StandardRoIHead', bbox_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), bbox_head=dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.1, 0.1, 0.2, 0.2]), reg_class_agnostic=False, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='L1Loss', loss_weight=1.0)), mask_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), mask_head=dict( type='FCNMaskHead', num_convs=4, in_channels=256, conv_out_channels=256, num_classes=80, loss_mask=dict( type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), # model training and testing settings train_cfg=dict( rpn=dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, match_low_quality=True, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), 
allowed_border=-1, pos_weight=-1, debug=False), rpn_proposal=dict( nms_pre=2000, max_per_img=1000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0), rcnn=dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, match_low_quality=True, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False)), test_cfg=dict( rpn=dict( nms_pre=1000, max_per_img=1000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0), rcnn=dict( score_thr=0.05, nms=dict(type='nms', iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5))) ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/models/mask_rcnn_vit_fpn.py ================================================ # model settings model = dict( type='MaskRCNN', pretrained=None, backbone=dict( type='VisionTransformer', img_size=[672, 1092], patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., qkv_bias=True, drop_path_rate=0.1, out_indices=(3, 5, 7, 11), use_checkpoint=False), neck=dict( type='FPN', in_channels=[384, 384, 384, 384], out_channels=256, num_outs=5), rpn_head=dict( type='RPNHead', in_channels=256, feat_channels=256, anchor_generator=dict( type='AnchorGenerator', scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]), bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[.0, .0, .0, .0], target_stds=[1.0, 1.0, 1.0, 1.0]), loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), loss_bbox=dict(type='L1Loss', loss_weight=1.0)), roi_head=dict( type='StandardRoIHead', bbox_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), bbox_head=dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.1, 0.1, 0.2, 0.2]), reg_class_agnostic=False, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='L1Loss', loss_weight=1.0)), mask_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), mask_head=dict( type='FCNMaskHead', num_convs=4, in_channels=256, conv_out_channels=256, num_classes=80, loss_mask=dict( type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), # model training and testing settings train_cfg=dict( rpn=dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, match_low_quality=True, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), allowed_border=-1, pos_weight=-1, debug=False), rpn_proposal=dict( nms_across_levels=False, nms_pre=2000, nms_post=2000, max_per_img=2000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0), rcnn=dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, match_low_quality=True, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), mask_size=28, pos_weight=-1, debug=False)), test_cfg=dict( rpn=dict( nms_across_levels=False, nms_pre=1000, nms_post=1000, max_per_img=1000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0), 
rcnn=dict( score_thr=0.05, nms=dict(type='nms', iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5))) ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/configs/_base_/schedules/schedule_1x.py ================================================ # optimizer optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) optimizer_config = dict(grad_clip=None) # learning policy lr_config = dict( policy='step', warmup='linear', warmup_iters=500, warmup_ratio=0.001, step=[8, 11]) runner = dict(type='EpochBasedRunner', max_epochs=12) ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/configs/mask_rcnn/vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py ================================================ # Copyright (c) ByteDance, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Mostly copy-paste from timm, mmdet, and swin code bases https://github.com/rwightman/pytorch-image-models/tree/master/timm https://github.com/open-mmlab/mmdetection https://github.com/SwinTransformer/Swin-Transformer-Object-Detection """ _base_ = [ '../_base_/models/mask_rcnn_vit_fpn.py', '../_base_/datasets/coco_instance.py', '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' ] model = dict( backbone=dict( embed_dim=768, depth=12, num_heads=12, init_values=0.1, mlp_ratio=4., drop_path_rate=0.2, #see if 0.1 larger than vit-small is better use_abs_pos_emb=False, use_sincos_pos_emb=True, use_rel_pos_bias=False, ), neck=dict(in_channels=[768, 768, 768, 768]), roi_head=dict( bbox_head=dict( type='ConvFCBBoxHead', num_shared_convs=4, num_shared_fcs=1, in_channels=256, conv_out_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.1, 0.1, 0.2, 0.2]), reg_class_agnostic=False, reg_decoded_bbox=True, norm_cfg=dict(type='SyncBN', requires_grad=True), loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='GIoULoss', loss_weight=10.0)))) img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) # augmentation strategy originates from DETR / Sparse RCNN train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict(type='RandomFlip', flip_ratio=0.5), dict(type='AutoAugment', policies=[ [ dict(type='Resize', img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), (608, 1333), (640, 1333), (672, 1333), (704, 1333), (736, 1333), (768, 1333), (800, 1333)], multiscale_mode='value', keep_ratio=True) ], [ dict(type='Resize', img_scale=[(400, 1333), (500, 1333), (600, 1333)], multiscale_mode='value', keep_ratio=True), dict(type='RandomCrop', crop_type='absolute_range', crop_size=(384, 600), allow_negative_crop=True), dict(type='Resize', img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), (608, 1333), (640, 1333), (672, 1333), (704, 1333), (736, 1333), (768, 1333), (800, 1333)], multiscale_mode='value', override=True, keep_ratio=True) ] ]), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size_divisor=32), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), ] data = dict( samples_per_gpu=2, workers_per_gpu=2, 
train=dict(pipeline=train_pipeline)) optimizer = dict(_delete_=True, type='AdamW', lr=0.0003, betas=(0.9, 0.999), weight_decay=0.05, constructor='LayerDecayOptimizerConstructor', paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.75)) lr_config = dict(step=[9, 11]) runner = dict(type='EpochBasedRunnerAmp', max_epochs=12) # do not use mmdet version fp16 fp16 = None optimizer_config = dict( type="DistOptimizerHook", update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=True, ) ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/configs/mask_rcnn/vit_large_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00002_lrdr0.85_dp0.2.py ================================================ #Copyright (c) ByteDance, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Mostly copy-paste from timm, mmdet, and swin code bases https://github.com/rwightman/pytorch-image-models/tree/master/timm https://github.com/open-mmlab/mmdetection https://github.com/SwinTransformer/Swin-Transformer-Object-Detection """ _base_ = [ '../_base_/models/mask_rcnn_vit_fpn.py', '../_base_/datasets/coco_instance.py', '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' ] find_unused_parameters = False model = dict( backbone=dict( embed_dim=1024, depth=24, num_heads=16, init_values=0.00001, mlp_ratio=4., drop_path_rate=0.2, #see if 0.1 larger than vit-small is better use_abs_pos_emb=False, use_sincos_pos_emb=True, use_rel_pos_bias=False, out_indices=[7, 11, 15, 23], ), neck=dict(in_channels=[1024, 1024, 1024, 1024]), roi_head=dict( bbox_head=dict( type='ConvFCBBoxHead', num_shared_convs=4, num_shared_fcs=1, in_channels=256, conv_out_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.1, 0.1, 0.2, 0.2]), reg_class_agnostic=False, reg_decoded_bbox=True, norm_cfg=dict(type='SyncBN', requires_grad=True), loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='GIoULoss', loss_weight=10.0)))) img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) # augmentation strategy originates from DETR / Sparse RCNN train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict(type='RandomFlip', flip_ratio=0.5), dict(type='AutoAugment', policies=[ [ dict(type='Resize', img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), (608, 1333), (640, 1333), (672, 1333), (704, 1333), (736, 1333), (768, 1333), (800, 1333)], multiscale_mode='value', keep_ratio=True) ], [ dict(type='Resize', img_scale=[(400, 1333), (500, 1333), (600, 1333)], multiscale_mode='value', keep_ratio=True), dict(type='RandomCrop', crop_type='absolute_range', crop_size=(384, 600), allow_negative_crop=True), dict(type='Resize', img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), (608, 1333), (640, 1333), (672, 1333), (704, 1333), (736, 1333), (768, 1333), (800, 1333)], multiscale_mode='value', override=True, keep_ratio=True) ] ]), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size_divisor=32), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), ] data = dict( samples_per_gpu=2, workers_per_gpu=2, train=dict(pipeline=train_pipeline)) optimizer = 
dict(_delete_=True, type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05, constructor='LayerDecayOptimizerConstructor', paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.85)) lr_config = dict(step=[9, 11]) runner = dict(type='EpochBasedRunnerAmp', max_epochs=12) # do not use mmdet version fp16 fp16 = None optimizer_config = dict( type="DistOptimizerHook", update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=True, ) ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/__init__.py ================================================ # -*- coding: utf-8 -*- from .checkpoint import load_checkpoint from .layer_decay_optimizer_constructor import LayerDecayOptimizerConstructor from .register_backbone import VisionTransformer __all__ = ['load_checkpoint', 'LayerDecayOptimizerConstructor'] ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/checkpoint.py ================================================ # Copyright (c) ByteDance, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Copy-paste from mmcv library: https://github.com/open-mmlab/mmcv/ """ import io import os import os.path as osp import pkgutil import time import warnings from collections import OrderedDict from importlib import import_module from tempfile import TemporaryDirectory import torch import torchvision from torch.optim import Optimizer from torch.nn import functional as F import mmcv from mmcv.fileio import FileClient from mmcv.fileio import load as load_file from mmcv.parallel import is_module_wrapper from mmcv.utils import mkdir_or_exist from mmcv.runner import get_dist_info from scipy import interpolate import numpy as np import math ENV_MMCV_HOME = 'MMCV_HOME' ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' DEFAULT_CACHE_DIR = '~/.cache' def _get_mmcv_home(): mmcv_home = os.path.expanduser( os.getenv( ENV_MMCV_HOME, os.path.join( os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv'))) mkdir_or_exist(mmcv_home) return mmcv_home def load_state_dict(module, state_dict, strict=False, logger=None): """Load state_dict to a module. This method is modified from :meth:`torch.nn.Module.load_state_dict`. Default value for ``strict`` is set to ``False`` and the message for param mismatch will be shown even if strict is False. Args: module (Module): Module that receives the state_dict. state_dict (OrderedDict): Weights. strict (bool): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. logger (:obj:`logging.Logger`, optional): Logger to log the error message. If not specified, print function will be used. 
""" unexpected_keys = [] all_missing_keys = [] err_msg = [] metadata = getattr(state_dict, '_metadata', None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata # use _load_from_state_dict to enable checkpoint version control def load(module, prefix=''): # recursively check parallel module in case that the model has a # complicated structure, e.g., nn.Module(nn.Module(DDP)) if is_module_wrapper(module): module = module.module local_metadata = {} if metadata is None else metadata.get( prefix[:-1], {}) module._load_from_state_dict(state_dict, prefix, local_metadata, True, all_missing_keys, unexpected_keys, err_msg) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + '.') load(module) load = None # break load->load reference cycle # ignore "num_batches_tracked" of BN layers missing_keys = [ key for key in all_missing_keys if 'num_batches_tracked' not in key ] if unexpected_keys: err_msg.append('unexpected key in source ' f'state_dict: {", ".join(unexpected_keys)}\n') if missing_keys: err_msg.append( f'missing keys in source state_dict: {", ".join(missing_keys)}\n') rank, _ = get_dist_info() if len(err_msg) > 0 and rank == 0: err_msg.insert( 0, 'The model and loaded state dict do not match exactly\n') err_msg = '\n'.join(err_msg) if strict: raise RuntimeError(err_msg) elif logger is not None: logger.warning(err_msg) else: print(err_msg) def load_url_dist(url, model_dir=None, map_location="cpu"): """In distributed setting, this function only download checkpoint at local rank 0.""" rank, world_size = get_dist_info() rank = int(os.environ.get('LOCAL_RANK', rank)) if rank == 0: checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location) if world_size > 1: torch.distributed.barrier() if rank > 0: checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location) return checkpoint def load_pavimodel_dist(model_path, map_location=None): """In distributed setting, this function only download checkpoint at local rank 0.""" try: from pavi import modelscloud except ImportError: raise ImportError( 'Please install pavi to load checkpoint from modelcloud.') rank, world_size = get_dist_info() rank = int(os.environ.get('LOCAL_RANK', rank)) if rank == 0: model = modelcloud.get(model_path) with TemporaryDirectory() as tmp_dir: downloaded_file = osp.join(tmp_dir, model.name) model.download(downloaded_file) checkpoint = torch.load(downloaded_file, map_location=map_location) if world_size > 1: torch.distributed.barrier() if rank > 0: model = modelcloud.get(model_path) with TemporaryDirectory() as tmp_dir: downloaded_file = osp.join(tmp_dir, model.name) model.download(downloaded_file) checkpoint = torch.load( downloaded_file, map_location=map_location) return checkpoint def load_fileclient_dist(filename, backend, map_location): """In distributed setting, this function only download checkpoint at local rank 0.""" rank, world_size = get_dist_info() rank = int(os.environ.get('LOCAL_RANK', rank)) allowed_backends = ['ceph'] if backend not in allowed_backends: raise ValueError(f'Load from Backend {backend} is not supported.') if rank == 0: fileclient = FileClient(backend=backend) buffer = io.BytesIO(fileclient.get(filename)) checkpoint = torch.load(buffer, map_location=map_location) if world_size > 1: torch.distributed.barrier() if rank > 0: fileclient = FileClient(backend=backend) buffer = io.BytesIO(fileclient.get(filename)) checkpoint = torch.load(buffer, map_location=map_location) return 
checkpoint def get_torchvision_models(): model_urls = dict() for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): if ispkg: continue _zoo = import_module(f'torchvision.models.{name}') if hasattr(_zoo, 'model_urls'): _urls = getattr(_zoo, 'model_urls') model_urls.update(_urls) return model_urls def get_external_models(): mmcv_home = _get_mmcv_home() default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') default_urls = load_file(default_json_path) assert isinstance(default_urls, dict) external_json_path = osp.join(mmcv_home, 'open_mmlab.json') if osp.exists(external_json_path): external_urls = load_file(external_json_path) assert isinstance(external_urls, dict) default_urls.update(external_urls) return default_urls def get_mmcls_models(): mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') mmcls_urls = load_file(mmcls_json_path) return mmcls_urls def get_deprecated_model_names(): deprecate_json_path = osp.join(mmcv.__path__[0], 'model_zoo/deprecated.json') deprecate_urls = load_file(deprecate_json_path) assert isinstance(deprecate_urls, dict) return deprecate_urls def _process_mmcls_checkpoint(checkpoint): state_dict = checkpoint['state_dict'] new_state_dict = OrderedDict() for k, v in state_dict.items(): if k.startswith('backbone.'): new_state_dict[k[9:]] = v new_checkpoint = dict(state_dict=new_state_dict) return new_checkpoint def _load_checkpoint(filename, map_location=None): """Load checkpoint from somewhere (modelzoo, file, url). Args: filename (str): Accept local filepath, URL, ``torchvision://xxx``, ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for details. map_location (str | None): Same as :func:`torch.load`. Default: None. Returns: dict | OrderedDict: The loaded checkpoint. It can be either an OrderedDict storing model weights or a dict containing other information, which depends on the checkpoint. 
""" if filename.startswith('modelzoo://'): warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' 'use "torchvision://" instead') model_urls = get_torchvision_models() model_name = filename[11:] checkpoint = load_url_dist(model_urls[model_name]) elif filename.startswith('torchvision://'): model_urls = get_torchvision_models() model_name = filename[14:] checkpoint = load_url_dist(model_urls[model_name]) elif filename.startswith('open-mmlab://'): model_urls = get_external_models() model_name = filename[13:] deprecated_urls = get_deprecated_model_names() if model_name in deprecated_urls: warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' f'of open-mmlab://{deprecated_urls[model_name]}') model_name = deprecated_urls[model_name] model_url = model_urls[model_name] # check if is url if model_url.startswith(('http://', 'https://')): checkpoint = load_url_dist(model_url) else: filename = osp.join(_get_mmcv_home(), model_url) if not osp.isfile(filename): raise IOError(f'{filename} is not a checkpoint file') checkpoint = torch.load(filename, map_location=map_location) elif filename.startswith('mmcls://'): model_urls = get_mmcls_models() model_name = filename[8:] checkpoint = load_url_dist(model_urls[model_name]) checkpoint = _process_mmcls_checkpoint(checkpoint) elif filename.startswith(('http://', 'https://')): checkpoint = load_url_dist(filename) elif filename.startswith('pavi://'): model_path = filename[7:] checkpoint = load_pavimodel_dist(model_path, map_location=map_location) elif filename.startswith('s3://'): checkpoint = load_fileclient_dist( filename, backend='ceph', map_location=map_location) else: if not osp.isfile(filename): raise IOError(f'{filename} is not a checkpoint file') checkpoint = torch.load(filename, map_location=map_location) return checkpoint def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0, warmup_steps=-1): warmup_schedule = np.array([]) warmup_iters = warmup_epochs * niter_per_ep if warmup_steps > 0: warmup_iters = warmup_steps print("Set warmup steps = %d" % warmup_iters) if warmup_epochs > 0: warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) iters = np.arange(epochs * niter_per_ep - warmup_iters) schedule = np.array( [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) schedule = np.concatenate((warmup_schedule, schedule)) assert len(schedule) == epochs * niter_per_ep return schedule def load_checkpoint(model, filename, map_location='cpu', strict=False, logger=None): """Load checkpoint from a file or URI. Args: model (Module): Module to load checkpoint. filename (str): Accept local filepath, URL, ``torchvision://xxx``, ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for details. map_location (str): Same as :func:`torch.load`. strict (bool): Whether to allow different params for the model and checkpoint. logger (:mod:`logging.Logger` or None): The logger for error message. Returns: dict or OrderedDict: The loaded checkpoint. 
""" checkpoint = _load_checkpoint(filename, map_location) # OrderedDict is a subclass of dict if not isinstance(checkpoint, dict): raise RuntimeError( f'No state_dict found in checkpoint file {filename}') # get state_dict from checkpoint if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] elif 'model' in checkpoint: state_dict = checkpoint['model'] elif 'module' in checkpoint: state_dict = checkpoint['module'] else: state_dict = checkpoint # strip prefix of state_dict if list(state_dict.keys())[0].startswith('module.'): state_dict = {k[7:]: v for k, v in state_dict.items()} # for MoBY, load model of online branch if sorted(list(state_dict.keys()))[0].startswith('encoder'): state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} all_keys = list(state_dict.keys()) if all_keys[-1].startswith('encoder_to_decoder') or all_keys[-1].startswith('decoder'): # NOTE: remove all decoder keys all_keys = [key for key in all_keys if key.startswith('encoder.')] for key in all_keys: new_key = key.replace('encoder.','') state_dict[new_key] = state_dict[key] state_dict.pop(key) for key in list(state_dict.keys()): if key.startswith('decoder.'): state_dict.pop(key) # NOTE: replace norm with fc_norm for key in list(state_dict.keys()): if key.startswith('norm.'): new_key = key.replace('norm.','fc_norm.') state_dict[new_key] = state_dict[key] state_dict.pop(key) # reshape absolute position embedding for Swin if state_dict.get('absolute_pos_embed') is not None: absolute_pos_embed = state_dict['absolute_pos_embed'] N1, L, C1 = absolute_pos_embed.size() N2, C2, H, W = model.absolute_pos_embed.size() if N1 != N2 or C1 != C2 or L != H*W: logger.warning("Error in loading absolute_pos_embed, pass") else: state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) rank, _ = get_dist_info() if "rel_pos_bias.relative_position_bias_table" in state_dict: if rank == 0: rel_pos_bias = state_dict["rel_pos_bias.relative_position_bias_table"] state_dict["relative_position_bias_table"] = rel_pos_bias state_dict.pop("rel_pos_bias.relative_position_bias_table") all_keys = list(state_dict.keys()) for key in all_keys: if "relative_position_index" in key: state_dict.pop(key) if "relative_position_bias_table" in key and key in model.state_dict(): rel_pos_bias = state_dict[key] src_num_pos, num_attn_heads = rel_pos_bias.size() dst_num_pos, _ = model.state_dict()[key].size() dst_patch_shape = model.patch_embed.patch_shape if dst_patch_shape[0] != dst_patch_shape[1]: num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) src_size = int((src_num_pos - num_extra_tokens) ** 0.5) # 27 dst_size_0 = dst_patch_shape[0] * 2 - 1 # 42 dst_size_1 = dst_patch_shape[1] * 2 - 1 # 68 if src_size != dst_size_0: if rank == 0: print("Position interpolate for %s from %dx%d to %dx%d" % ( key, src_size, src_size, dst_size_0, dst_size_1)) extra_tokens = rel_pos_bias[-num_extra_tokens:, :] rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] def geometric_progression(a, r, n): return a * (1.0 - r ** n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_size // 2) if gp > dst_size_0 // 2: right = q else: left = q dis_0 = [] cur = 1 for i in range(src_size // 2): dis_0.append(cur) cur += q ** (i + 1) r_ids_0 = [-_ for _ in reversed(dis_0)] top, bottom = 1.01, 1.5 while bottom - top > 1e-6: q = (top + bottom) / 2.0 gp = geometric_progression(1, q, src_size // 2) 
if gp > dst_size_1 // 2: bottom = q else: top = q dis_1 = [] cur = 1 for i in range(src_size // 2): dis_1.append(cur) cur += q ** (i + 1) r_ids_1 = [-_ for _ in reversed(dis_1)] # if q > 1.13492: # q = 1.13492 x = r_ids_0 + [0] + dis_0 y = r_ids_1 + [0] + dis_1 t_0 = dst_size_0 // 2.0 t_1 = dst_size_1 // 2.0 dx = np.arange(-t_0, t_0 + 0.1, 1.0) dy = np.arange(-t_1, t_1 + 0.1, 1.0) if rank == 0: print("x = {}".format(x)) print("dx = {}".format(dx)) all_rel_pos_bias = [] for i in range(num_attn_heads): z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() f = interpolate.interp2d(x, y, z, kind='cubic') all_rel_pos_bias.append( torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) state_dict[key] = new_rel_pos_bias else: num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) src_size = int((src_num_pos - num_extra_tokens) ** 0.5) # 27 dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) # if src_size != dst_size: if rank == 0: print("Position interpolate for %s from %dx%d to %dx%d" % ( key, src_size, src_size, dst_size, dst_size)) extra_tokens = rel_pos_bias[-num_extra_tokens:, :] rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] def geometric_progression(a, r, n): return a * (1.0 - r ** n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_size // 2) if gp > dst_size // 2: right = q else: left = q # if q > 1.13492: # q = 1.13492 dis = [] cur = 1 for i in range(src_size // 2): dis.append(cur) cur += q ** (i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis t = dst_size // 2.0 dx = np.arange(-t, t + 0.1, 1.0) dy = np.arange(-t, t + 0.1, 1.0) if rank == 0: print("x = {}".format(x)) print("dx = {}".format(dx)) all_rel_pos_bias = [] for i in range(num_attn_heads): z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() f = interpolate.interp2d(x, y, z, kind='cubic') all_rel_pos_bias.append( torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) state_dict[key] = new_rel_pos_bias if 'pos_embed' in state_dict: pos_embed_checkpoint = state_dict['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.patch_embed.num_patches num_extra_tokens = model.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) # height (== width) for the new position embedding #new_size = int(num_patches ** 0.5) new_size_w = model.patch_embed.num_patches_w new_size_h = model.patch_embed.num_patches_h # class_token and dist_token are kept unchanged if orig_size != new_size_h or orig_size != new_size_w: if rank == 0: print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size_w, new_size_h)) extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( pos_tokens, size=(new_size_w, new_size_h), mode='bicubic', align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = 
torch.cat((extra_tokens, pos_tokens), dim=1) state_dict['pos_embed'] = new_pos_embed # interpolate position bias table if needed relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k and k in model.state_dict()] for table_key in relative_position_bias_table_keys: table_pretrained = state_dict[table_key] table_current = model.state_dict()[table_key] L1, nH1 = table_pretrained.size() L2, nH2 = table_current.size() if nH1 != nH2: logger.warning(f"Error in loading {table_key}, pass") else: if L1 != L2: S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) table_pretrained_resized = F.interpolate( table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), mode='bicubic') state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0) # load state_dict load_state_dict(model, state_dict, strict, logger) return checkpoint def weights_to_cpu(state_dict): """Copy a model state_dict to cpu. Args: state_dict (OrderedDict): Model weights on GPU. Returns: OrderedDict: Model weights on GPU. """ state_dict_cpu = OrderedDict() for key, val in state_dict.items(): state_dict_cpu[key] = val.cpu() return state_dict_cpu def _save_to_state_dict(module, destination, prefix, keep_vars): """Saves module state to `destination` dictionary. This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. Args: module (nn.Module): The module to generate state_dict. destination (dict): A dict where state will be stored. prefix (str): The prefix for parameters and buffers used in this module. """ for name, param in module._parameters.items(): if param is not None: destination[prefix + name] = param if keep_vars else param.detach() for name, buf in module._buffers.items(): # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d if buf is not None: destination[prefix + name] = buf if keep_vars else buf.detach() def get_state_dict(module, destination=None, prefix='', keep_vars=False): """Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. This method is modified from :meth:`torch.nn.Module.state_dict` to recursively check parallel module in case that the model has a complicated structure, e.g., nn.Module(nn.Module(DDP)). Args: module (nn.Module): The module to generate state_dict. destination (OrderedDict): Returned dict for the state of the module. prefix (str): Prefix of the key. keep_vars (bool): Whether to keep the variable property of the parameters. Default: False. Returns: dict: A dictionary containing a whole state of the module. """ # recursively check parallel module in case that the model has a # complicated structure, e.g., nn.Module(nn.Module(DDP)) if is_module_wrapper(module): module = module.module # below is the same as torch.nn.Module.state_dict() if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() destination._metadata[prefix[:-1]] = local_metadata = dict( version=module._version) _save_to_state_dict(module, destination, prefix, keep_vars) for name, child in module._modules.items(): if child is not None: get_state_dict( child, destination, prefix + name + '.', keep_vars=keep_vars) for hook in module._state_dict_hooks.values(): hook_result = hook(module, destination, prefix, local_metadata) if hook_result is not None: destination = hook_result return destination def save_checkpoint(model, filename, optimizer=None, meta=None): """Save checkpoint to file. 
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()

================================================
FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/layer_decay_optimizer_constructor.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Mostly copy-paste from BEiT library:
https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/layer_decay_optimizer_constructor.py
"""
import json

from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
from mmcv.runner import get_dist_info


def get_num_layer_for_vit(var_name, num_max_layer):
    if var_name in ("backbone.cls_token", "backbone.mask_token",
                    "backbone.pos_embed"):
        return 0
    elif var_name.startswith("backbone.patch_embed"):
        return 0
    elif var_name.startswith("backbone.blocks"):
        layer_id = int(var_name.split('.')[2])
        return layer_id + 1
    else:
        return num_max_layer - 1


@OPTIMIZER_BUILDERS.register_module()
class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
prefix (str): The prefix of the module is_dcn_module (int|float|None): If the current module is a submodule of DCN, `is_dcn_module` will be passed to control conv_offset layer's learning rate. Defaults to None. """ parameter_groups = {} print(self.paramwise_cfg) num_layers = self.paramwise_cfg.get('num_layers') + 2 layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate') print("Build LayerDecayOptimizerConstructor %f - %d" % (layer_decay_rate, num_layers)) weight_decay = self.base_wd for name, param in module.named_parameters(): if not param.requires_grad: continue # frozen weights if len(param.shape) == 1 or name.endswith(".bias") or name in ('pos_embed', 'cls_token'): group_name = "no_decay" this_weight_decay = 0. else: group_name = "decay" this_weight_decay = weight_decay layer_id = get_num_layer_for_vit(name, num_layers) group_name = "layer_%d_%s" % (layer_id, group_name) if group_name not in parameter_groups: scale = layer_decay_rate ** (num_layers - layer_id - 1) parameter_groups[group_name] = { "weight_decay": this_weight_decay, "params": [], "param_names": [], "lr_scale": scale, "group_name": group_name, "lr": scale * self.base_lr, } parameter_groups[group_name]["params"].append(param) parameter_groups[group_name]["param_names"].append(name) rank, _ = get_dist_info() if rank == 0: to_display = {} for key in parameter_groups: to_display[key] = { "param_names": parameter_groups[key]["param_names"], "lr_scale": parameter_groups[key]["lr_scale"], "lr": parameter_groups[key]["lr"], "weight_decay": parameter_groups[key]["weight_decay"], } print("Param groups = %s" % json.dumps(to_display, indent=2)) # state_dict = module.state_dict() # for group_name in parameter_groups: # group = parameter_groups[group_name] # for name in group["param_names"]: # group["params"].append(state_dict[name]) params.extend(parameter_groups.values()) ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/prepare_rpe.py ================================================ import torch import numpy as np from scipy import interpolate from mmcv.runner import get_dist_info import torch.nn as nn def rpe_index(window_size): num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = num_relative_distance - 3 relative_position_index[0:, 0] = num_relative_distance - 2 relative_position_index[0, 0] = num_relative_distance - 1 return relative_position_index def prepare_rpe(rel_pos_bias, src_patch_shape, dst_patch_shape): src_num_pos, num_attn_heads = rel_pos_bias.size() # 732 rank, _ = get_dist_info() dst_num_pos = (dst_patch_shape[0]*2 -1) * (dst_patch_shape[1]*2 -1) + 3 if dst_patch_shape[0] != 
src_patch_shape[0] or dst_patch_shape[1] != src_patch_shape[1]:
        num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
        # src_size = int((src_num_pos - num_extra_tokens) ** 0.5)  # 27
        src_size_0, src_size_1 = src_patch_shape[0] * 2 - 1, src_patch_shape[1] * 2 - 1
        extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
        rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
        dst_size_0 = dst_patch_shape[0] * 2 - 1  # 42
        dst_size_1 = dst_patch_shape[1] * 2 - 1  # 68
        dim = rel_pos_bias.shape[-1]
        rel_pos_bias = rel_pos_bias.reshape(1, src_size_0, src_size_1, dim).permute(0, 3, 1, 2)
        new_rel_pos_bias = nn.functional.interpolate(
            rel_pos_bias,
            scale_factor=(dst_size_0 / src_size_0, dst_size_1 / src_size_1),
            mode='bicubic',
        )
        new_rel_pos_bias = new_rel_pos_bias.permute(0, 2, 3, 1).view(1, -1, dim).squeeze(0)
        new_rel_pos_bias = torch.cat((new_rel_pos_bias, extra_tokens), dim=0)
    else:
        new_rel_pos_bias = rel_pos_bias
    # get rpe_index
    relative_position_index = rpe_index(dst_patch_shape)
    new_rel_pos_bias = new_rel_pos_bias[relative_position_index.view(-1)].view(
        dst_patch_shape[0] * dst_patch_shape[1] + 1,
        dst_patch_shape[0] * dst_patch_shape[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
    new_rel_pos_bias = new_rel_pos_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    return new_rel_pos_bias

================================================
FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/register_backbone.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from mmcv_custom import load_checkpoint
from mmdet.utils import get_root_logger
from mmdet.models.builder import BACKBONES
from models import VisionTransformer
from .prepare_rpe import prepare_rpe
import time


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding """
    def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches_w = img_size[0] // patch_size
        self.num_patches_h = img_size[1] // patch_size
        num_patches = self.num_patches_w * self.num_patches_h
        self.patch_shape = (img_size[0] // patch_size, img_size[1] // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, mask=None):
        B, C, H, W = x.shape
        return self.proj(x)


@BACKBONES.register_module()
class VisionTransformer(VisionTransformer):
    def __init__(self,
                 img_size,
                 patch_size,
                 embed_dim,
                 in_chans=3,
                 with_fpn=True,
                 frozen_stages=-1,
                 out_indices=[3, 5, 7, 11],
                 out_with_norm=False,
                 use_checkpoint=False,
                 **kwargs):
        super(VisionTransformer, self).__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            **kwargs)

        # support non-square image as input
        if len(img_size) == 1:
            img_size = img_size * 2
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        if self.use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        elif self.use_sincos_pos_emb:
            self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)
        else:
            self.pos_embed = None

        self.patch_size = patch_size
        self.with_fpn = with_fpn
        self.frozen_stages = frozen_stages
        self.out_indices = out_indices
        self.use_checkpoint = use_checkpoint

        if not out_with_norm:
            self.norm = nn.Identity()

        if with_fpn and patch_size == 16:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                nn.SyncBatchNorm(embed_dim),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn3 = nn.Identity()
            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif with_fpn and patch_size == 8:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Identity()
            self.fpn3 = nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
            )
            self.fpn4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=4, stride=4),
            )
        else:
            logger = get_root_logger()
            logger.info('Build model without FPN.')

    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., decode=False):
        h, w = self.patch_embed.patch_shape
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w),
                             torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        pos_embed.requires_grad = False
        return pos_embed

    def train(self, mode=True):
        """Convert the model into training mode while keeping frozen layers frozen."""
        super(VisionTransformer, self).train(mode)
        self._freeze_stages()
        if self.pos_embed is not None:
            if self.pos_embed.requires_grad:
                print("=================pos_embed update ================")
            else:
                print("=================pos_embed static ================")

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            self.cls_token.requires_grad = False
            if self.pos_embed is not None and self.use_sincos_pos_emb:
                self.pos_embed.requires_grad = False
            self.pos_drop.eval()
            for i in range(1, self.frozen_stages + 1):
                if i == len(self.blocks):
                    norm_layer = getattr(self, 'norm')  # f'norm{i-1}')
                    norm_layer.eval()
                    for param in norm_layer.parameters():
                        param.requires_grad = False
                m = self.blocks[i - 1]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
""" if isinstance(pretrained, str): self.apply(self._init_weights) logger = get_root_logger() if os.path.isfile(pretrained): load_checkpoint(self, pretrained, strict=False, logger=logger) else: logger.info(f"checkpoint path {pretrained} is invalid, we skip it and initialize net randomly") elif pretrained is None: self.apply(self._init_weights) else: raise TypeError('pretrained must be a str or None') def interpolate_pos_encoding(self, x, w, h): npatch = x.shape[1] - 1 N = self.pos_embed.shape[1] - 1 w0 = w // self.patch_embed.patch_size h0 = h // self.patch_embed.patch_size if npatch == N and w0 == self.patch_embed.num_patches_w and h0 == self.patch_embed.num_patches_h: return self.pos_embed class_pos_embed = self.pos_embed[:, 0] patch_pos_embed = self.pos_embed[:, 1:] dim = x.shape[-1] # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 w0, h0 = w0 + 0.1, h0 + 0.1 tmp=patch_pos_embed.reshape(1, self.patch_embed.num_patches_w, self.patch_embed.num_patches_h, dim).permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed.reshape(1, self.patch_embed.num_patches_w, self.patch_embed.num_patches_h, dim).permute(0, 3, 1, 2), scale_factor=(w0 / self.patch_embed.num_patches_w, h0 / self.patch_embed.num_patches_h), mode='bicubic', ) assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) def forward(self, x): B, _, H, W = x.shape Hp, Wp = H // self.patch_size, W // self.patch_size x = self.prepare_tokens(x) features = [] time_begin = time.time() if self.relative_position_bias_table is None: x_rpe = None else: dst_rpe_shape = (Wp, Hp) if H <= W else(Hp, Wp) x_rpe = prepare_rpe(self.relative_position_bias_table, self.patch_embed.patch_shape, dst_rpe_shape) for i, blk in enumerate(self.blocks): if self.use_checkpoint: x = checkpoint.checkpoint(blk, x, x_rpe) else: x = blk(x, x_rpe) if i in self.out_indices: xp = self.norm(x[:, 1:, :]).permute(0, 2, 1).reshape(B, -1, Hp, Wp) features.append(xp.contiguous()) time_backbone = time.time() if self.with_fpn: ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] for i in range(len(features)): features[i] = ops[i](features[i]) time_end = time.time() return tuple(features) ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/runner/__init__.py ================================================ # Copyright (c) Open-MMLab. All rights reserved. from .checkpoint import save_checkpoint from .epoch_based_runner import EpochBasedRunnerAmp __all__ = [ 'EpochBasedRunnerAmp', 'save_checkpoint' ] ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/runner/checkpoint.py ================================================ # Copyright (c) ByteDance, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
""" Copy-paste from mmcv library: https://github.com/open-mmlab/mmcv/ """ import os.path as osp import time import torch import mmcv try: import apex except: print('apex is not installed') from tempfile import TemporaryDirectory from torch.optim import Optimizer from mmcv.parallel import is_module_wrapper from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict def save_checkpoint(model, filename, optimizer=None, meta=None): """Save checkpoint to file. The checkpoint will have 4 fields: ``meta``, ``state_dict`` and ``optimizer``, ``amp``. By default ``meta`` will contain version and time info. Args: model (Module): Module whose params are to be saved. filename (str): Checkpoint filename. optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. meta (dict, optional): Metadata to be saved in checkpoint. """ if meta is None: meta = {} elif not isinstance(meta, dict): raise TypeError(f'meta must be a dict or None, but got {type(meta)}') meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) if is_module_wrapper(model): model = model.module if hasattr(model, 'CLASSES') and model.CLASSES is not None: # save class name to the meta meta.update(CLASSES=model.CLASSES) checkpoint = { 'meta': meta, 'state_dict': weights_to_cpu(get_state_dict(model)) } # save optimizer state dict in the checkpoint if isinstance(optimizer, Optimizer): checkpoint['optimizer'] = optimizer.state_dict() elif isinstance(optimizer, dict): checkpoint['optimizer'] = {} for name, optim in optimizer.items(): checkpoint['optimizer'][name] = optim.state_dict() # save amp state dict in the checkpoint checkpoint['amp'] = apex.amp.state_dict() if filename.startswith('pavi://'): try: from pavi import modelscloud from pavi.exception import NodeNotFoundError except ImportError: raise ImportError( 'Please install pavi to load checkpoint from modelcloud.') model_path = filename[7:] root = modelcloud.Folder() model_dir, model_name = osp.split(model_path) try: model = modelcloud.get(model_dir) except NodeNotFoundError: model = root.create_training_model(model_dir) with TemporaryDirectory() as tmp_dir: checkpoint_file = osp.join(tmp_dir, model_name) with open(checkpoint_file, 'wb') as f: torch.save(checkpoint, f) f.flush() model.create_file(checkpoint_file, name=model_name) else: mmcv.mkdir_or_exist(osp.dirname(filename)) # immediately flush buffer with open(filename, 'wb') as f: torch.save(checkpoint, f) f.flush() ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/mmcv_custom/runner/epoch_based_runner.py ================================================ # Copyright (c) ByteDance, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Copy-paste from mmcv library: https://github.com/open-mmlab/mmcv/ """ import os.path as osp import platform import shutil import torch import mmcv try: import apex except: print('apex is not installed') from torch.optim import Optimizer from mmcv.runner import RUNNERS, EpochBasedRunner from .checkpoint import save_checkpoint @RUNNERS.register_module() class EpochBasedRunnerAmp(EpochBasedRunner): """Epoch-based Runner with AMP support. This runner train models epoch by epoch. """ def save_checkpoint(self, out_dir, filename_tmpl='epoch_{}.pth', save_optimizer=True, meta=None, create_symlink=True): """Save the checkpoint. Args: out_dir (str): The directory that checkpoints are saved. 
filename_tmpl (str, optional): The checkpoint filename template, which contains a placeholder for the epoch number. Defaults to 'epoch_{}.pth'. save_optimizer (bool, optional): Whether to save the optimizer to the checkpoint. Defaults to True. meta (dict, optional): The meta information to be saved in the checkpoint. Defaults to None. create_symlink (bool, optional): Whether to create a symlink "latest.pth" to point to the latest checkpoint. Defaults to True. """ if meta is None: meta = dict(epoch=self.epoch + 1, iter=self.iter) elif isinstance(meta, dict): meta.update(epoch=self.epoch + 1, iter=self.iter) else: raise TypeError( f'meta should be a dict or None, but got {type(meta)}') if self.meta is not None: meta.update(self.meta) filename = filename_tmpl.format(self.epoch + 1) filepath = osp.join(out_dir, filename) optimizer = self.optimizer if save_optimizer else None save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) # in some environments, `os.symlink` is not supported, you may need to # set `create_symlink` to False if create_symlink: dst_file = osp.join(out_dir, 'latest.pth') if platform.system() != 'Windows': mmcv.symlink(filename, dst_file) else: shutil.copy(filepath, dst_file) def resume(self, checkpoint, resume_optimizer=True, map_location='default'): if map_location == 'default': if torch.cuda.is_available(): device_id = torch.cuda.current_device() checkpoint = self.load_checkpoint( checkpoint, map_location=lambda storage, loc: storage.cuda(device_id)) else: checkpoint = self.load_checkpoint(checkpoint) else: checkpoint = self.load_checkpoint( checkpoint, map_location=map_location) self._epoch = checkpoint['meta']['epoch'] self._iter = checkpoint['meta']['iter'] if 'optimizer' in checkpoint and resume_optimizer: if isinstance(self.optimizer, Optimizer): self.optimizer.load_state_dict(checkpoint['optimizer']) elif isinstance(self.optimizer, dict): for k in self.optimizer.keys(): self.optimizer[k].load_state_dict( checkpoint['optimizer'][k]) else: raise TypeError( 'Optimizer should be dict or torch.optim.Optimizer ' f'but got {type(self.optimizer)}') if 'amp' in checkpoint: apex.amp.load_state_dict(checkpoint['amp']) self.logger.info('load amp state dict') self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter) ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/test.py ================================================ # Copyright (c) ByteDance, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
""" Mostly copy-paste from mmdetection library: https://github.com/open-mmlab/mmdetection/blob/master/tools/test.py """ import argparse import os import warnings import mmcv import torch from mmcv import Config, DictAction from mmcv.cnn import fuse_conv_bn from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, wrap_fp16_model) from mmdet.apis import multi_gpu_test, single_gpu_test from mmdet.datasets import (build_dataloader, build_dataset, replace_ImageToTensor) from mmdet.models import build_detector def parse_args(): parser = argparse.ArgumentParser( description='MMDet test (and eval) a model') parser.add_argument('config', help='test config file path') parser.add_argument('checkpoint', help='checkpoint file') parser.add_argument('--out', help='output result file in pickle format') parser.add_argument( '--fuse-conv-bn', action='store_true', help='Whether to fuse conv and bn, this will slightly increase' 'the inference speed') parser.add_argument( '--format-only', action='store_true', help='Format the output results without perform evaluation. It is' 'useful when you want to format the result to a specific format and ' 'submit it to the test server') parser.add_argument( '--eval', type=str, nargs='+', help='evaluation metrics, which depends on the dataset, e.g., "bbox",' ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC') parser.add_argument('--show', action='store_true', help='show results') parser.add_argument( '--show-dir', help='directory where painted images will be saved') parser.add_argument( '--show-score-thr', type=float, default=0.3, help='score threshold (default: 0.3)') parser.add_argument( '--gpu-collect', action='store_true', help='whether to use gpu to collect results.') parser.add_argument( '--tmpdir', help='tmp directory used for collecting results from multiple ' 'workers, available when gpu-collect is not specified') parser.add_argument( '--cfg-options', nargs='+', action=DictAction, help='override some settings in the used config, the key-value pair ' 'in xxx=yyy format will be merged into config file. If the value to ' 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 'Note that the quotation marks are necessary and that no white space ' 'is allowed.') parser.add_argument( '--options', nargs='+', action=DictAction, help='custom options for evaluation, the key-value pair in xxx=yyy ' 'format will be kwargs for dataset.evaluate() function (deprecate), ' 'change to --eval-options instead.') parser.add_argument( '--eval-options', nargs='+', action=DictAction, help='custom options for evaluation, the key-value pair in xxx=yyy ' 'format will be kwargs for dataset.evaluate() function') parser.add_argument( '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) if args.options and args.eval_options: raise ValueError( '--options and --eval-options cannot be both ' 'specified, --options is deprecated in favor of --eval-options') if args.options: warnings.warn('--options is deprecated in favor of --eval-options') args.eval_options = args.options return args def main(): args = parse_args() assert args.out or args.eval or args.format_only or args.show \ or args.show_dir, \ ('Please specify at least one operation (save/eval/format/show the ' 'results / save the results) with the argument "--out", "--eval"' ', "--format-only", "--show" or "--show-dir"') if args.eval and args.format_only: raise ValueError('--eval and --format_only cannot be both specified') if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): raise ValueError('The output file must be a pkl file.') cfg = Config.fromfile(args.config) if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) # import modules from string list. if cfg.get('custom_imports', None): from mmcv.utils import import_modules_from_strings import_modules_from_strings(**cfg['custom_imports']) # set cudnn_benchmark if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True cfg.model.pretrained = None if cfg.model.get('neck'): if isinstance(cfg.model.neck, list): for neck_cfg in cfg.model.neck: if neck_cfg.get('rfp_backbone'): if neck_cfg.rfp_backbone.get('pretrained'): neck_cfg.rfp_backbone.pretrained = None elif cfg.model.neck.get('rfp_backbone'): if cfg.model.neck.rfp_backbone.get('pretrained'): cfg.model.neck.rfp_backbone.pretrained = None # in case the test dataset is concatenated samples_per_gpu = 1 if isinstance(cfg.data.test, dict): cfg.data.test.test_mode = True samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1) if samples_per_gpu > 1: # Replace 'ImageToTensor' to 'DefaultFormatBundle' cfg.data.test.pipeline = replace_ImageToTensor( cfg.data.test.pipeline) elif isinstance(cfg.data.test, list): for ds_cfg in cfg.data.test: ds_cfg.test_mode = True samples_per_gpu = max( [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test]) if samples_per_gpu > 1: for ds_cfg in cfg.data.test: ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline) # init distributed env first, since logger depends on the dist info. 
if args.launcher == 'none': distributed = False else: distributed = True init_dist(args.launcher, **cfg.dist_params) # build the dataloader dataset = build_dataset(cfg.data.test) data_loader = build_dataloader( dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=cfg.data.workers_per_gpu, dist=distributed, shuffle=False) # build the model and load checkpoint cfg.model.train_cfg = None model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) fp16_cfg = cfg.get('fp16', None) if fp16_cfg is not None: wrap_fp16_model(model) checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') if args.fuse_conv_bn: model = fuse_conv_bn(model) # old versions did not save class info in checkpoints, this walkaround is # for backward compatibility if 'CLASSES' in checkpoint.get('meta', {}): model.CLASSES = checkpoint['meta']['CLASSES'] else: model.CLASSES = dataset.CLASSES if not distributed: model = MMDataParallel(model, device_ids=[0]) outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, args.show_score_thr) else: model = MMDistributedDataParallel( model.cuda(), device_ids=[torch.cuda.current_device()], broadcast_buffers=False) outputs = multi_gpu_test(model, data_loader, args.tmpdir, args.gpu_collect) rank, _ = get_dist_info() if rank == 0: if args.out: print(f'\nwriting results to {args.out}') mmcv.dump(outputs, args.out) kwargs = {} if args.eval_options is None else args.eval_options if args.format_only: dataset.format_results(outputs, **kwargs) if args.eval: eval_kwargs = cfg.get('evaluation', {}).copy() # hard-code way to remove EvalHook args for key in [ 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 'rule' ]: eval_kwargs.pop(key, None) eval_kwargs.update(dict(metric=args.eval, **kwargs)) print(dataset.evaluate(outputs, **eval_kwargs)) if __name__ == '__main__': main() ================================================ FILE: downstream_tasks/detection/evaluation/object_detection/train.py ================================================ # Copyright (c) ByteDance, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
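# A typical multi-GPU invocation of this script (an illustrative sketch;
# scripts/run_train_maskrcnn_vit_base.sh in this repo carries the
# authoritative command, and the work-dir name here is arbitrary):
#   python -m torch.distributed.launch --nproc_per_node=8 train.py \
#       configs/mask_rcnn/vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py \
#       --launcher pytorch --work-dir ./work_dirs/cae_base_maskrcnn_1x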
""" Mostly copy-paste from mmdetection library: https://github.com/open-mmlab/mmdetection/blob/master/tools/train.py """ import argparse import copy import os import os.path as osp import time import warnings import mmcv import torch from mmcv import Config, DictAction from mmcv.runner import get_dist_info, init_dist from mmcv.utils import get_git_hash from mmdet import __version__ from mmdet.apis import set_random_seed, train_detector from mmdet.datasets import build_dataset from mmdet.models import build_detector from mmdet.utils import collect_env, get_root_logger def parse_args(): parser = argparse.ArgumentParser(description='Train a detector') parser.add_argument('config', help='train config file path') parser.add_argument('--work-dir', help='the dir to save logs and models') parser.add_argument( '--resume-from', help='the checkpoint file to resume from') parser.add_argument( '--no-validate', action='store_true', help='whether not to evaluate the checkpoint during training') group_gpus = parser.add_mutually_exclusive_group() group_gpus.add_argument( '--gpus', type=int, help='number of gpus to use ' '(only applicable to non-distributed training)') group_gpus.add_argument( '--gpu-ids', type=int, nargs='+', help='ids of gpus to use ' '(only applicable to non-distributed training)') parser.add_argument('--seed', type=int, default=None, help='random seed') parser.add_argument( '--deterministic', action='store_true', help='whether to set deterministic options for CUDNN backend.') parser.add_argument( '--options', nargs='+', action=DictAction, help='override some settings in the used config, the key-value pair ' 'in xxx=yyy format will be merged into config file (deprecate), ' 'change to --cfg-options instead.') parser.add_argument( '--cfg-options', nargs='+', action=DictAction, help='override some settings in the used config, the key-value pair ' 'in xxx=yyy format will be merged into config file. If the value to ' 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 'Note that the quotation marks are necessary and that no white space ' 'is allowed.') parser.add_argument( '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) if args.options and args.cfg_options: raise ValueError( '--options and --cfg-options cannot be both ' 'specified, --options is deprecated in favor of --cfg-options') if args.options: warnings.warn('--options is deprecated in favor of --cfg-options') args.cfg_options = args.options return args def main(): args = parse_args() cfg = Config.fromfile(args.config) if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) # import modules from string list. 
if cfg.get('custom_imports', None): from mmcv.utils import import_modules_from_strings import_modules_from_strings(**cfg['custom_imports']) # set cudnn_benchmark if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True # work_dir is determined in this priority: CLI > segment in file > filename if args.work_dir is not None: # update configs according to CLI args if args.work_dir is not None cfg.work_dir = args.work_dir elif cfg.get('work_dir', None) is None: # use config filename as default work_dir if cfg.work_dir is None cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) if args.resume_from is not None: cfg.resume_from = args.resume_from if args.gpu_ids is not None: cfg.gpu_ids = args.gpu_ids else: cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) # init distributed env first, since logger depends on the dist info. if args.launcher == 'none': distributed = False else: distributed = True init_dist(args.launcher, **cfg.dist_params) # re-set gpu_ids with distributed training mode _, world_size = get_dist_info() cfg.gpu_ids = range(world_size) # create work_dir mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) # dump config cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) # init the logger before other steps timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) log_file = osp.join(cfg.work_dir, f'{timestamp}.log') logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) # init the meta dict to record some important information such as # environment info and seed, which will be logged meta = dict() # log env info env_info_dict = collect_env() env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) dash_line = '-' * 60 + '\n' logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line) meta['env_info'] = env_info meta['config'] = cfg.pretty_text # log some basic info logger.info(f'Distributed training: {distributed}') logger.info(f'Config:\n{cfg.pretty_text}') # set random seeds if args.seed is not None: logger.info(f'Set random seed to {args.seed}, ' f'deterministic: {args.deterministic}') set_random_seed(args.seed, deterministic=args.deterministic) cfg.seed = args.seed meta['seed'] = args.seed meta['exp_name'] = osp.basename(args.config) model = build_detector( cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg')) datasets = [build_dataset(cfg.data.train)] if len(cfg.workflow) == 2: val_dataset = copy.deepcopy(cfg.data.val) val_dataset.pipeline = cfg.data.train.pipeline datasets.append(build_dataset(val_dataset)) if cfg.checkpoint_config is not None: # save mmdet version, config file content and class names in # checkpoints as meta data cfg.checkpoint_config.meta = dict( mmdet_version=__version__ + get_git_hash()[:7], CLASSES=datasets[0].CLASSES) # add an attribute for visualization convenience model.CLASSES = datasets[0].CLASSES train_detector( model, datasets, cfg, distributed=distributed, validate=(not args.no_validate), timestamp=timestamp, meta=meta) if __name__ == '__main__': main() ================================================ FILE: downstream_tasks/detection/loader.py ================================================ # Copyright (c) ByteDance, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
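# Usage sketch for the ImageFolderMask dataset defined below (values are
# illustrative rather than taken from this repo's scripts, and a multi-crop
# transform returning a list of views is assumed):
#   dataset = ImageFolderMask(
#       '/path/to/imagenet/train',
#       transform=multi_crop_transform,  # hypothetical transform object
#       patch_size=16,
#       pred_ratio=0.3, pred_ratio_var=0.0,
#       pred_aspect_ratio=(0.3, 1 / 0.3),
#       pred_shape='block')
#   views, target, masks = dataset[0]  # one (H/16, W/16) bool mask per view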
import random import math import numpy as np from torchvision.datasets import ImageFolder class ImageFolderInstance(ImageFolder): def __getitem__(self, index): img, target = super(ImageFolderInstance, self).__getitem__(index) return img, target, index class ImageFolderMask(ImageFolder): def __init__(self, *args, patch_size, pred_ratio, pred_ratio_var, pred_aspect_ratio, pred_shape='block', pred_start_epoch=0, **kwargs): super(ImageFolderMask, self).__init__(*args, **kwargs) self.psz = patch_size self.pred_ratio = pred_ratio[0] if isinstance(pred_ratio, list) and \ len(pred_ratio) == 1 else pred_ratio self.pred_ratio_var = pred_ratio_var[0] if isinstance(pred_ratio_var, list) and \ len(pred_ratio_var) == 1 else pred_ratio_var if isinstance(self.pred_ratio, list) and not isinstance(self.pred_ratio_var, list): self.pred_ratio_var = [self.pred_ratio_var] * len(self.pred_ratio) self.log_aspect_ratio = tuple(map(lambda x: math.log(x), pred_aspect_ratio)) self.pred_shape = pred_shape self.pred_start_epoch = pred_start_epoch def get_pred_ratio(self): if hasattr(self, 'epoch') and self.epoch < self.pred_start_epoch: return 0 if isinstance(self.pred_ratio, list): pred_ratio = [] for prm, prv in zip(self.pred_ratio, self.pred_ratio_var): assert prm >= prv pr = random.uniform(prm - prv, prm + prv) if prv > 0 else prm pred_ratio.append(pr) pred_ratio = random.choice(pred_ratio) else: assert self.pred_ratio >= self.pred_ratio_var pred_ratio = random.uniform(self.pred_ratio - self.pred_ratio_var, self.pred_ratio + \ self.pred_ratio_var) if self.pred_ratio_var > 0 else self.pred_ratio return pred_ratio def set_epoch(self, epoch): self.epoch = epoch def __getitem__(self, index): output = super(ImageFolderMask, self).__getitem__(index) masks = [] for img in output[0]: try: H, W = img.shape[1] // self.psz, img.shape[2] // self.psz except: # skip non-image continue high = self.get_pred_ratio() * H * W if self.pred_shape == 'block': # following BEiT (https://arxiv.org/abs/2106.08254), see at # https://github.com/microsoft/unilm/blob/b94ec76c36f02fb2b0bf0dcb0b8554a2185173cd/beit/masking_generator.py#L55 mask = np.zeros((H, W), dtype=bool) mask_count = 0 while mask_count < high: max_mask_patches = high - mask_count delta = 0 for attempt in range(10): low = (min(H, W) // 3) ** 2 target_area = random.uniform(low, max_mask_patches) aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) h = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio))) if w < W and h < H: top = random.randint(0, H - h) left = random.randint(0, W - w) num_masked = mask[top: top + h, left: left + w].sum() if 0 < h * w - num_masked <= max_mask_patches: for i in range(top, top + h): for j in range(left, left + w): if mask[i, j] == 0: mask[i, j] = 1 delta += 1 if delta > 0: break if delta == 0: break else: mask_count += delta elif self.pred_shape == 'rand': mask = np.hstack([ np.zeros(H * W - int(high)), np.ones(int(high)), ]).astype(bool) np.random.shuffle(mask) mask = mask.reshape(H, W) else: # no implementation assert False masks.append(mask) return output + (masks,) ================================================ FILE: downstream_tasks/detection/models/__init__.py ================================================ from .vision_transformer import VisionTransformer, vit_tiny, vit_small, vit_base, vit_large from .swin_transformer import SwinTransformer, swin_tiny, swin_small, swin_base, swin_large ================================================ FILE: 
downstream_tasks/detection/models/head.py ================================================ # Copyright (c) ByteDance, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn as nn import utils from utils import trunc_normal_ class CSyncBatchNorm(nn.SyncBatchNorm): def __init__(self, *args, with_var=False, **kwargs): super(CSyncBatchNorm, self).__init__(*args, **kwargs) self.with_var = with_var def forward(self, x): # center norm self.training = False if not self.with_var: self.running_var = torch.ones_like(self.running_var) normed_x = super(CSyncBatchNorm, self).forward(x) # udpate center self.training = True _ = super(CSyncBatchNorm, self).forward(x) return normed_x class PSyncBatchNorm(nn.SyncBatchNorm): def __init__(self, *args, bunch_size, **kwargs): procs_per_bunch = min(bunch_size, utils.get_world_size()) assert utils.get_world_size() % procs_per_bunch == 0 n_bunch = utils.get_world_size() // procs_per_bunch # ranks = list(range(utils.get_world_size())) print('---ALL RANKS----\n{}'.format(ranks)) rank_groups = [ranks[i*procs_per_bunch: (i+1)*procs_per_bunch] for i in range(n_bunch)] print('---RANK GROUPS----\n{}'.format(rank_groups)) process_groups = [torch.distributed.new_group(pids) for pids in rank_groups] bunch_id = utils.get_rank() // procs_per_bunch process_group = process_groups[bunch_id] print('---CURRENT GROUP----\n{}'.format(process_group)) super(PSyncBatchNorm, self).__init__(*args, process_group=process_group, **kwargs) class CustomSequential(nn.Sequential): bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) def forward(self, input): for module in self: dim = len(input.shape) if isinstance(module, self.bn_types) and dim > 2: perm = list(range(dim - 1)); perm.insert(1, dim - 1) inv_perm = list(range(dim)) + [1]; inv_perm.pop(1) input = module(input.permute(*perm)).permute(*inv_perm) else: input = module(input) return input class DINOHead(nn.Module): def __init__(self, in_dim, out_dim, norm=None, act='gelu', last_norm=None, nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True, **kwargs): super().__init__() norm = self._build_norm(norm, hidden_dim) last_norm = self._build_norm(last_norm, out_dim, affine=False, **kwargs) act = self._build_act(act) nlayers = max(nlayers, 1) if nlayers == 1: if bottleneck_dim > 0: self.mlp = nn.Linear(in_dim, bottleneck_dim) else: self.mlp = nn.Linear(in_dim, out_dim) else: layers = [nn.Linear(in_dim, hidden_dim)] if norm is not None: layers.append(norm) layers.append(act) for _ in range(nlayers - 2): layers.append(nn.Linear(hidden_dim, hidden_dim)) if norm is not None: layers.append(norm) layers.append(act) if bottleneck_dim > 0: layers.append(nn.Linear(hidden_dim, bottleneck_dim)) else: layers.append(nn.Linear(hidden_dim, out_dim)) self.mlp = CustomSequential(*layers) self.apply(self._init_weights) if bottleneck_dim > 0: self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) self.last_layer.weight_g.data.fill_(1) if norm_last_layer: self.last_layer.weight_g.requires_grad = False else: self.last_layer = None self.last_norm = last_norm def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x): x = self.mlp(x) if self.last_layer is not None: x = nn.functional.normalize(x, dim=-1, p=2) x = 
self.last_layer(x) if self.last_norm is not None: x = self.last_norm(x) return x def _build_norm(self, norm, hidden_dim, **kwargs): if norm == 'bn': norm = nn.BatchNorm1d(hidden_dim, **kwargs) elif norm == 'syncbn': norm = nn.SyncBatchNorm(hidden_dim, **kwargs) elif norm == 'csyncbn': norm = CSyncBatchNorm(hidden_dim, **kwargs) elif norm == 'psyncbn': norm = PSyncBatchNorm(hidden_dim, **kwargs) elif norm == 'ln': norm = nn.LayerNorm(hidden_dim, **kwargs) else: assert norm is None, "unknown norm type {}".format(norm) return norm def _build_act(self, act): if act == 'relu': act = nn.ReLU() elif act == 'gelu': act = nn.GELU() else: assert False, "unknown act type {}".format(act) return act class iBOTHead(DINOHead): def __init__(self, *args, patch_out_dim=8192, norm=None, act='gelu', last_norm=None, nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True, shared_head=False, **kwargs): super(iBOTHead, self).__init__(*args, norm=norm, act=act, last_norm=last_norm, nlayers=nlayers, hidden_dim=hidden_dim, bottleneck_dim=bottleneck_dim, norm_last_layer=norm_last_layer, **kwargs) if not shared_head: if bottleneck_dim > 0: self.last_layer2 = nn.utils.weight_norm(nn.Linear(bottleneck_dim, patch_out_dim, bias=False)) self.last_layer2.weight_g.data.fill_(1) if norm_last_layer: self.last_layer2.weight_g.requires_grad = False else: self.mlp2 = nn.Linear(hidden_dim, patch_out_dim) self.last_layer2 = None self.last_norm2 = self._build_norm(last_norm, patch_out_dim, affine=False, **kwargs) else: if bottleneck_dim > 0: self.last_layer2 = self.last_layer else: self.mlp2 = self.mlp[-1] self.last_layer2 = None self.last_norm2 = self.last_norm def forward(self, x): if len(x.shape) == 2: return super(iBOTHead, self).forward(x) if self.last_layer is not None: x = self.mlp(x) x = nn.functional.normalize(x, dim=-1, p=2) x1 = self.last_layer(x[:, 0]) x2 = self.last_layer2(x[:, 1:]) else: x = self.mlp[:-1](x) x1 = self.mlp[-1](x[:, 0]) x2 = self.mlp2(x[:, 1:]) if self.last_norm is not None: x1 = self.last_norm(x1) x2 = self.last_norm2(x2) return x1, x2 ================================================ FILE: downstream_tasks/detection/models/swin_transformer.py ================================================ # Copyright (c) ByteDance, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
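# --- NOTE ----------------------------------------------------------------
# The Swin blocks defined below rely on window_partition / window_reverse
# being exact inverses of one another. A self-contained round-trip check of
# that reshape logic (shapes here are illustrative):
#
#   import torch
#
#   B, H, W, C, ws = 2, 8, 8, 4, 4
#   x = torch.randn(B, H, W, C)
#   # partition: (B, H, W, C) -> (num_windows*B, ws, ws, C)
#   win = (x.view(B, H // ws, ws, W // ws, ws, C)
#           .permute(0, 1, 3, 2, 4, 5).contiguous()
#           .view(-1, ws, ws, C))
#   # reverse: (num_windows*B, ws, ws, C) -> (B, H, W, C)
#   y = (win.view(B, H // ws, W // ws, ws, ws, C)
#           .permute(0, 1, 3, 2, 4, 5).contiguous()
#           .view(B, H, W, C))
#   assert torch.equal(x, y)  # the round trip is lossless
#   assert win.shape == (B * (H // ws) * (W // ws), ws, ws, C)
# -------------------------------------------------------------------------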
""" Mostly copy-paste from Swin-Transformer libarary: https://github.com/facebookresearch/dino https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py """ import os import logging import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist from math import sqrt from functools import partial from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.models.registry import register_model class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super(Mlp, self).__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): r"""Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 """ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): super(WindowAttention, self).__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2 Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn_out = attn attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x, attn_out def extra_repr(self) -> str: return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) flops += N * self.dim * 3 * self.dim # attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) # x = self.proj(x) flops += N * self.dim * self.dim return flops @staticmethod def compute_macs(module, input, output): B, N, C = input[0].shape module.__flops__ += module.flops(N) * B class SwinTransformerBlock(nn.Module): r"""Swin Transformer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. num_heads (int): Number of attention heads. 
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = input_resolution[0]
        self.W = input_resolution[1]
        self.attn_mask_dict = {}  # {self.H: self.create_attn_mask(self.H, self.W)}

    def create_attn_mask(self, H, W):
        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1))  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        B, L, C = x.shape
        H = int(sqrt(L))
        W = H

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            if H in self.attn_mask_dict:
                attn_mask = self.attn_mask_dict[H]
            else:
                self.attn_mask_dict[H] = self.create_attn_mask(H, W).to(x.device)
                attn_mask = self.attn_mask_dict[H]
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows, attn = self.attn(x_windows, attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x if pad_r > 0 or pad_b > 0: x = x[:, :H, :W, :].contiguous() x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x, attn def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ f"window_size={self.window_size}, shift_size={self.shift_size} mlp_ratio={self.mlp_ratio}" def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # W-MSA/SW-MSA nW = H * W / self.window_size / self.window_size flops += nW * self.attn.flops(self.window_size * self.window_size) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops class PatchMerging(nn.Module): r"""Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. """ B, L, C = x.shape H = int(sqrt(L)) W = H x = x.view(B, H, W, C) # padding pad_input = (H % 2 == 1) or (W % 2 == 1) if pad_input: x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = H * W * self.dim flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim return flops class BasicLayer(nn.Module): """A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None """ def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.blocks = nn.ModuleList([ SwinTransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) if downsample is not None: self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x): for blk in self.blocks: x, _ = blk(x) if self.downsample is not None: x = self.downsample(x) return x def forward_with_features(self, x): fea = [] for blk in self.blocks: x, _ = blk(x) fea.append(x) if self.downsample is not None: x = self.downsample(x) return x, fea def forward_with_attention(self, x): attns = [] for blk in self.blocks: x, attn = blk(x) attns.append(attn) if self.downsample is not None: x = self.downsample(x) return x, attns def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): # # FIXME look at relaxing size constraints # assert H == self.img_size[0] and W == self.img_size[1], \ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x) B, C, H, W = x.shape x = x.flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) return x.transpose(1, 2).reshape(B, C, H, W) def flops(self): Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinTransformer(nn.Module): r""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: img_size (int | tuple(int)): Input image size. patch_size (int | tuple(int)): Patch size. in_chans (int): Number of input channels. num_classes (int): Number of classes for classification head. embed_dim (int): Embedding dimension. depths (tuple(int)): Depth of Swin Transformer layers. num_heads (tuple(int)): Number of attention heads in different layers. 
window_size (int): Window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. drop_rate (float): Dropout rate. attn_drop_rate (float): Attention dropout rate. drop_path_rate (float): Stochastic depth rate. norm_layer (nn.Module): normalization layer. ape (bool): If True, add absolute position embedding to the patch embedding. patch_norm (bool): If True, add normalization after patch embedding. """ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False, patch_norm=True, return_all_tokens=False, use_mean_pooling=True, masked_im_modeling=False): super().__init__() self.num_classes = num_classes self.depths = depths self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio self.return_all_tokens = return_all_tokens self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution if self.ape: self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) trunc_normal_(self.absolute_pos_embed, std=.02) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) self.layers.append(layer) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) # masked image modeling self.masked_im_modeling = masked_im_modeling if masked_im_modeling: self.masked_embed = nn.Parameter(torch.zeros(1, embed_dim)) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): return {'absolute_pos_embed'} @torch.jit.ignore def no_weight_decay_keywords(self): # todo: to be implemented return {'relative_position_bias_table'} def forward(self, x, return_all_tokens=None, mask=None): # patch linear embedding x = self.patch_embed(x) # mask image modeling if mask is not None: x = self.mask_model(x, mask) x = x.flatten(2).transpose(1, 2) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) 
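        # NOTE: each BasicLayer below halves the spatial resolution (via
        # PatchMerging) and doubles the channel width. For the defaults in
        # this file (img_size=224, patch_size=4, embed_dim=96, depths=[2, 2, 6, 2]):
        #   stage 0: 56x56 = 3136 tokens,  96 channels
        #   stage 1: 28x28 =  784 tokens, 192 channels
        #   stage 2: 14x14 =  196 tokens, 384 channels
        #   stage 3:  7x7  =   49 tokens, 768 channels (= num_features)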
for layer in self.layers: x = layer(x) x_region = self.norm(x) # B L C x = self.avgpool(x_region.transpose(1, 2)) # B C 1 x = torch.flatten(x, 1) return_all_tokens = self.return_all_tokens if \ return_all_tokens is None else return_all_tokens if return_all_tokens: return torch.cat([x.unsqueeze(1), x_region], dim=1) return x def get_selfattention(self, x, n=1): # n=1 return the last layer attn map; otherwise return attn maps in all layers x = self.patch_embed(x) x = x.flatten(2).transpose(1, 2) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) if n==1: return self.get_last_selfattention(x) else: return self.get_all_selfattention(x) def get_last_selfattention(self, x): for i, layer in enumerate(self.layers): if i < len(self.layers) - 1: x = layer(x) else: x, attns = layer.forward_with_attention(x) return attns[-1] def get_all_selfattention(self, x): attn_out = [] for layer in self.layers: x, attns = layer.forward_with_attention(x) attn_out += attns return attn_out def get_intermediate_layers(self, x, n=1, return_patch_avgpool=False): num_blks = sum(self.depths) start_idx = num_blks - n sum_cur = 0 for i, d in enumerate(self.depths): sum_cur_new = sum_cur + d if start_idx >= sum_cur and start_idx < sum_cur_new: start_stage = i start_blk = start_idx - sum_cur sum_cur = sum_cur_new x = self.patch_embed(x) x = x.flatten(2).transpose(1, 2) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) # we will return the averaged token features from the `n` last blocks # note: there is no [CLS] token in Swin Transformer output = [] s = 0 for i, layer in enumerate(self.layers): x, fea = layer.forward_with_features(x) if i >= start_stage: for x_ in fea[start_blk:]: if i == len(self.layers)-1: # use the norm in the last stage x_ = self.norm(x_) x_avg = torch.flatten(self.avgpool(x_.transpose(1, 2)), 1) # B C if return_patch_avgpool: x_o = x_avg else: x_o = torch.cat((x_avg.unsqueeze(1), x_), dim=1) # print(f'Stage {i}, x_o {x_o.shape}') output.append(x_o) start_blk = 0 #return torch.cat(output, dim=-1) return output def flops(self): flops = 0 flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() if dist.get_rank() == 0: print(f"GFLOPs layer_{i}: {layer.flops() / 1e9}") flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) flops += self.num_features * self.num_classes return flops def init_weights(self, pretrained='', pretrained_layers=[], verbose=True): if os.path.isfile(pretrained): pretrained_dict = torch.load(pretrained, map_location='cpu') logging.info(f'=> loading pretrained model {pretrained}') model_dict = self.state_dict() pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict.keys() } need_init_state_dict = {} for k, v in pretrained_dict.items(): need_init = ( k.split('.')[0] in pretrained_layers or pretrained_layers[0] is '*' or 'relative_position_index' not in k or 'attn_mask' not in k ) if need_init: if verbose: logging.info(f'=> init {k} from {pretrained}') if 'relative_position_bias_table' in k and v.size() != model_dict[k].size(): relative_position_bias_table_pretrained = v relative_position_bias_table_current = model_dict[k] L1, nH1 = relative_position_bias_table_pretrained.size() L2, nH2 = relative_position_bias_table_current.size() if nH1 != nH2: logging.info(f"Error in loading {k}, passing") else: if L1 != L2: logging.info( '=> load_pretrained: resized variant: {} to {}' .format((L1, nH1), (L2, nH2)) ) S1 = int(L1 ** 0.5) S2 = int(L2 ** 
0.5)
                                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                                    size=(S2, S2), mode='bicubic')
                                v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

                    if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():
                        absolute_pos_embed_pretrained = v
                        absolute_pos_embed_current = model_dict[k]
                        _, L1, C1 = absolute_pos_embed_pretrained.size()
                        _, L2, C2 = absolute_pos_embed_current.size()
                        if C1 != C2:
                            logging.info(f"Error in loading {k}, passing")
                        else:
                            if L1 != L2:
                                logging.info(
                                    '=> load_pretrained: resized variant: {} to {}'
                                    .format((1, L1, C1), (1, L2, C2))
                                )
                                S1 = int(L1 ** 0.5)
                                S2 = int(L2 ** 0.5)
                                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
                                v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)

                    need_init_state_dict[k] = v
            self.load_state_dict(need_init_state_dict, strict=False)

    def freeze_pretrained_layers(self, frozen_layers=[]):
        for name, module in self.named_modules():
            if (
                name.split('.')[0] in frozen_layers
                or '.'.join(name.split('.')[0:2]) in frozen_layers
                or (len(frozen_layers) > 0 and frozen_layers[0] == '*')
            ):
                for _name, param in module.named_parameters():
                    param.requires_grad = False
                logging.info(
                    '=> set param {} requires grad to False'
                    .format(name)
                )
        for name, param in self.named_parameters():
            if (
                name.split('.')[0] in frozen_layers
                or (len(frozen_layers) > 0 and frozen_layers[0] == '*')
                and param.requires_grad is True
            ):
                param.requires_grad = False
                logging.info(
                    '=> set param {} requires grad to False'
                    .format(name)
                )
        return self

    def get_num_layers(self):
        # return len(self.layers)
        return sum(self.depths)

    def mask_model(self, x, mask):
        # extend mask for hierarchical features
        if x.shape[-2:] != mask.shape[-2:]:
            htimes, wtimes = np.array(x.shape[-2:]) // np.array(mask.shape[-2:])
            mask = mask.repeat_interleave(htimes, -2).repeat_interleave(wtimes, -1)

        # mask embed
        x.permute(0, 2, 3, 1)[mask, :] = self.masked_embed.to(x.dtype)
        return x


@register_model
def swin_tiny(window_size=7, **kwargs):
    model = SwinTransformer(
        window_size=window_size, embed_dim=96, depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24], mlp_ratio=4, qkv_bias=True,
        drop_path_rate=kwargs.pop('drop_path_rate', 0.1), **kwargs)
    return model


@register_model
def swin_small(window_size=7, **kwargs):
    model = SwinTransformer(
        window_size=window_size, embed_dim=96, depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24], mlp_ratio=4, qkv_bias=True,
        drop_path_rate=kwargs.pop('drop_path_rate', 0.2), **kwargs)
    return model


@register_model
def swin_base(window_size=7, **kwargs):
    model = SwinTransformer(
        window_size=window_size, embed_dim=128, depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32], mlp_ratio=4, qkv_bias=True,
        drop_path_rate=kwargs.pop('drop_path_rate', 0.2), **kwargs)
    return model


@register_model
def swin_large(window_size=7, **kwargs):
    model = SwinTransformer(
        window_size=window_size, embed_dim=192, depths=[2, 2, 18, 2],
        num_heads=[6, 12, 24, 48], mlp_ratio=4, qkv_bias=True,
        drop_path_rate=kwargs.pop('drop_path_rate', 0.2), **kwargs)
    return model


================================================
FILE: downstream_tasks/detection/models/vision_transformer.py
================================================
# Copyright (c) ByteDance, Inc.
and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Mostly copy-paste from DINO and timm library: https://github.com/facebookresearch/dino https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ import math import torch import torch.nn as nn import torch.nn.functional as F from functools import partial from utils import trunc_normal_ from timm.models.registry import register_model def drop_path(x, drop_prob: float = 0., training: bool = False): if drop_prob == 0. or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor.floor_() # binarize output = x.div(keep_prob) * random_tensor return output class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 #self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) ####add by wxd self.qkv = nn.Linear(dim, dim * 3, bias=False) all_head_dim = head_dim * self.num_heads if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) else: self.q_bias = None self.v_bias = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, x_rel_pos_bias = None): B, N, C = x.shape #qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) ########add by wxd qkv_bias = None if self.q_bias is not None: qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale # if self.relative_position_bias_table is not None: # relative_position_bias = \ # self.relative_position_bias_table[self.relative_position_index.view(-1)].view( # self.window_size[0] * self.window_size[1] + 1, # self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH # print("################before relative:", relative_position_bias.shape) # relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww # print("################after relative:", relative_position_bias.shape) # relative_position_bias = intepolate_rpe(relative_position_bias) 
# print("################ater inter relative:", relative_position_bias.shape) if x_rel_pos_bias is not None: attn = attn + x_rel_pos_bias.unsqueeze(0) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x, attn class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=None, init_values=0): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if init_values > 0: self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) else: self.gamma_1, self.gamma_2 = None, None def forward(self, x, x_rel_pos_bias=None, return_attention=False): y, attn = self.attn(self.norm1(x), x_rel_pos_bias) if return_attention: return attn if self.gamma_1 is None: x = x + self.drop_path(y) x = x + self.drop_path(self.mlp(self.norm2(x))) else: x = x + self.drop_path(self.gamma_1 * y) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x #class PatchEmbed(nn.Module): # """ Image to Patch Embedding # """ # def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): # super().__init__() # num_patches = (img_size // patch_size) * (img_size // patch_size) # self.img_size = img_size # self.patch_size = patch_size # self.num_patches = num_patches # print("#################patch in!!!") # self.patch_shape = (img_size // patch_size, img_size // patch_size) # # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # # def forward(self, x): # B, C, H, W = x.shape # return self.proj(x) class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768): super().__init__() self.num_patches_w = img_size[0] // patch_size self.num_patches_h = img_size[1] // patch_size num_patches = self.num_patches_w * self.num_patches_h self.patch_shape = (img_size[0] // patch_size, img_size[1] // patch_size) self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches print("##############patch here!!!") self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x, mask=None): B, C, H, W = x.shape return self.proj(x) class VisionTransformer(nn.Module): """ Vision Transformer """ def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), return_all_tokens=False, init_values=0, use_sincos_pos_emb=False, use_abs_pos_emb=False, use_rel_pos_bias=False, use_mean_pooling=False, masked_im_modeling=False): super().__init__() self.num_features = self.embed_dim = embed_dim self.return_all_tokens = return_all_tokens print("############use_abs_pos:", use_abs_pos_emb) print("############use_sincos_pos:", use_sincos_pos_emb) 
print("############use_rel_pos_bias:", use_rel_pos_bias) self.use_abs_pos_emb = use_abs_pos_emb self.use_sincos_pos_emb = use_sincos_pos_emb self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if use_abs_pos_emb: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) else: self.pos_embed = None self.pos_drop = nn.Dropout(p=drop_rate) self.use_rel_pos_bias = use_rel_pos_bias if self.use_rel_pos_bias: print("=================use RelativePositionBias===================") window_size=self.patch_embed.patch_shape self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls else: self.window_size = None self.relative_position_bias_table = None self.relative_position_index = None dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth)]) self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None # Classifier head self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() if use_abs_pos_emb: trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) self.apply(self._init_weights) # masked image modeling self.masked_im_modeling = masked_im_modeling if masked_im_modeling: self.masked_embed = nn.Parameter(torch.zeros(1, embed_dim)) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def interpolate_pos_encoding(self, x, w, h): npatch = x.shape[1] - 1 print("############self.pos_embed:", self.pos_embed) N = self.pos_embed.shape[1] - 1 if npatch == N and w == h: return self.pos_embed class_pos_embed = self.pos_embed[:, 0] patch_pos_embed = self.pos_embed[:, 1:] dim = x.shape[-1] w0 = w // self.patch_embed.patch_size h0 = h // self.patch_embed.patch_size # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 w0, h0 = w0 + 0.1, h0 + 0.1 patch_pos_embed = nn.functional.interpolate( patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), mode='bicubic', ) assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) def prepare_tokens(self, x, mask=None): B, nc, w, h = x.shape # patch linear embedding x = self.patch_embed(x) # mask image modeling if mask is not None: x = self.mask_model(x, mask) x = x.flatten(2).transpose(1, 2) # add the [CLS] token to the embed patch 
tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        if self.pos_embed is not None:
            x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x, return_all_tokens=None, mask=None):
        # mim
        if self.masked_im_modeling:
            assert mask is not None
            x = self.prepare_tokens(x, mask=mask)
        else:
            x = self.prepare_tokens(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        if self.fc_norm is not None:
            x[:, 0] = self.fc_norm(x[:, 1:, :].mean(1))

        return_all_tokens = self.return_all_tokens if \
            return_all_tokens is None else return_all_tokens
        if return_all_tokens:
            return x
        return x[:, 0]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output

    def get_num_layers(self):
        return len(self.blocks)

    def mask_model(self, x, mask):
        x.permute(0, 2, 3, 1)[mask, :] = self.masked_embed.to(x.dtype)
        return x


def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3,
        mlp_ratio=4, qkv_bias=True, **kwargs)
    return model


def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6,
        mlp_ratio=4, qkv_bias=True, **kwargs)
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, **kwargs)
    return model


def vit_large(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16,
        mlp_ratio=4, qkv_bias=True, **kwargs)
    return model
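# --- NOTE ----------------------------------------------------------------
# A minimal smoke test for the ViT backbone above. It assumes the working
# directory is downstream_tasks/detection/ -- as in the launch scripts
# below -- so that `models.vision_transformer` and the local `utils` module
# resolve. With patch_size=16 on a 224x224 input, the model yields
# 14*14 = 196 patch tokens plus one [CLS] token.
import torch

from models.vision_transformer import vit_base

if __name__ == "__main__":
    model = vit_base(use_abs_pos_emb=True, return_all_tokens=True)
    with torch.no_grad():
        tokens = model(torch.randn(1, 3, 224, 224))
    print(tokens.shape)  # expected: torch.Size([1, 197, 768])
# -------------------------------------------------------------------------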
================================================
FILE: downstream_tasks/detection/scripts/run_eval.sh
================================================
#!/usr/bin/env bash

echo "EVAL MODEL:"$MODEL
python -m torch.distributed.launch --nproc_per_node=8 \
  evaluation/object_detection/test.py \
  $CONFIG \
  $MODEL \
  --launcher pytorch \
  --eval bbox segm \
  --cfg-options model.backbone.use_checkpoint=True \
  ${@:6}


================================================
FILE: downstream_tasks/detection/scripts/run_train_maskrcnn_vit_base.sh
================================================
#!/usr/bin/env bash

python -m torch.distributed.launch --nproc_per_node=8 \
  --nnodes=$NNODES \
  --node_rank=$RANK \
  --master_addr=$ADDRESS \
  --master_port=$PORT \
  evaluation/object_detection/train.py \
  evaluation/object_detection/configs/mask_rcnn/vit_base_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00003.py \
  --launcher pytorch \
  --work-dir $OUTPUT_DIR \
  --no-validate \
  --deterministic \
  --cfg-options model.backbone.use_checkpoint=True \
  model.pretrained=$PRETRAINED \
  ${@:6}


================================================
FILE: downstream_tasks/detection/scripts/run_train_maskrcnn_vit_large.sh
================================================
#!/usr/bin/env bash

python -m torch.distributed.launch --nproc_per_node=8 \
  --nnodes=$NNODES \
  --node_rank=$RANK \
  --master_addr=$ADDRESS \
  --master_port=$PORT \
  evaluation/object_detection/train.py \
  evaluation/object_detection/configs/mask_rcnn/vit_large_giou_4conv1f_coco_maskrcnn_1x_cae_sincos_init0.1_lr00002_lrdr0.85_dp0.2.py \
  --launcher pytorch \
  --work-dir $OUTPUT_DIR \
  --no-validate \
  --deterministic \
  --cfg-options model.backbone.use_checkpoint=True \
  model.pretrained=$PRETRAINED \
  ${@:6}


================================================
FILE: downstream_tasks/detection/utils.py
================================================
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Mostly copy-paste from torchvision references or other public repos like DETR:
https://github.com/facebookresearch/detr/blob/master/util/misc.py
"""
import argparse  # needed by bool_flag below
import os
import sys
import time
import math
import json
import random
import datetime
import subprocess
import warnings  # needed by _no_grad_trunc_normal_ below
import numpy as np
import torch
import torch.distributed as dist

from collections import defaultdict, deque
from pathlib import Path
from torch import nn
from PIL import ImageFilter, ImageOps, Image, ImageDraw


class GaussianBlur(object):
    """
    Apply Gaussian Blur to the PIL image.
    """
    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        do_it = random.random() <= self.prob
        if not do_it:
            return img

        return img.filter(
            ImageFilter.GaussianBlur(
                radius=random.uniform(self.radius_min, self.radius_max)
            )
        )


class Solarization(object):
    """
    Apply Solarization to the PIL image.
    """
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img


class PermutePatch(object):
    """
    Apply Patch permutation to the PIL image.
    """
    def __init__(self, psz):
        self.psz = psz

    def __call__(self, img):
        imgs = []
        imgwidth, imgheight = img.size
        for i in range(0, imgheight, self.psz):
            for j in range(0, imgwidth, self.psz):
                box = (j, i, j + self.psz, i + self.psz)
                imgs.append(img.crop(box))
        random.shuffle(imgs)
        new_img = Image.new('RGB', (imgwidth, imgheight))
        k = 0
        for i in range(0, imgheight, self.psz):
            for j in range(0, imgwidth, self.psz):
                new_img.paste(imgs[k], (j, i))
                k += 1
        return new_img


class HideAndSeek(object):
    """
    Randomly hide (black out) patches of the PIL image.
""" def __init__(self, ratio, psz): self.ratio = ratio self.psz = psz def __call__(self, img): imgwidth, imgheight = img.size numw, numh = imgwidth // self.psz, imgheight // self.psz mask_num = int(numw * numh * self.ratio) mask_patch = np.random.choice(np.arange(numw * numh), mask_num, replace=False) mask_w, mask_h = mask_patch % numh, mask_patch // numh # img.save('test1.png') draw = ImageDraw.Draw(img) for mw, mh in zip(mask_w, mask_h): draw.rectangle((mw * self.psz, mh * self.psz, (mw + 1) * self.psz, (mh + 1) * self.psz), fill="black") # img.save('test2.png') return img def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size): if os.path.isfile(pretrained_weights): state_dict = torch.load(pretrained_weights, map_location="cpu") if checkpoint_key is not None and checkpoint_key in state_dict: print(f"Take key {checkpoint_key} in provided checkpoint dict") state_dict = state_dict[checkpoint_key] # remove `module.` prefix state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} # remove `backbone.` prefix induced by multicrop wrapper state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} msg = model.load_state_dict(state_dict, strict=False) print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg)) return elif pretrained_weights == 'download': url = None if model_name == "vit_small" and patch_size == 16: url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" elif model_name == "vit_small" and patch_size == 8: url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth" elif model_name == "vit_base" and patch_size == 16: url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" elif model_name == "vit_base" and patch_size == 8: url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" if url is not None: print("Since no pretrained weights are provided, we load the pretrained weights from {}.".format(url)) state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) model.load_state_dict(state_dict, strict=True) return elif pretrained_weights == 'supervised': url = None if model_name == "vit_small" and patch_size == 16: url = "deit_small_patch16_224-cd65a155.pth" elif model_name == "vit_base" and patch_size == 16: url = "deit_base_patch16_224-b5f2ef4d.pth" if url is not None: print("Since no pretrained weights are provided, we load the pretrained weights from {}.".format(url)) state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/deit/" + url) msg = model.load_state_dict(state_dict['model'], strict=False) print('Supervised weights found at {} and loaded with msg: {}'.format(url, msg)) return print("There is no reference weights available for this model => We use random weights.") def clip_gradients(model, clip): norms = [] for name, p in model.named_parameters(): if p.grad is not None: param_norm = p.grad.data.norm(2) norms.append(param_norm.item()) clip_coef = clip / (param_norm + 1e-6) if clip_coef < 1: p.grad.data.mul_(clip_coef) return norms def cancel_gradients_last_layer(epoch, model, freeze_last_layer): if epoch >= freeze_last_layer: return for n, p in model.named_parameters(): if "last_layer" in n: p.grad = None def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs): """ Re-start from checkpoint """ if not os.path.isfile(ckp_path): return print("Found checkpoint at {}".format(ckp_path)) # open checkpoint file checkpoint = torch.load(ckp_path, map_location="cpu") # key is what to 
look for in the checkpoint file # value is the object to load # example: {'state_dict': model} for key, value in kwargs.items(): if key in checkpoint and value is not None: try: msg = value.load_state_dict(checkpoint[key], strict=False) print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg)) except TypeError: try: msg = value.load_state_dict(checkpoint[key]) print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path)) except ValueError: print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path)) else: print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path)) # re load variable important for the run if run_variables is not None: for var_name in run_variables: if var_name in checkpoint: run_variables[var_name] = checkpoint[var_name] def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): warmup_schedule = np.array([]) warmup_iters = warmup_epochs * niter_per_ep if warmup_epochs > 0: warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) iters = np.arange(epochs * niter_per_ep - warmup_iters) schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) schedule = np.concatenate((warmup_schedule, schedule)) assert len(schedule) == epochs * niter_per_ep return schedule def bool_flag(s): """ Parse boolean arguments from the command line. """ FALSY_STRINGS = {"off", "false", "0"} TRUTHY_STRINGS = {"on", "true", "1"} if s.lower() in FALSY_STRINGS: return False elif s.lower() in TRUTHY_STRINGS: return True else: raise argparse.ArgumentTypeError("invalid value for a boolean flag") def fix_random_seeds(seed=31): """ Fix random seeds. """ random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a window or the global series average. """ def __init__(self, window_size=20, fmt=None): if fmt is None: fmt = "{median:.6f} ({global_avg:.6f})" self.deque = deque(maxlen=window_size) self.total = 0.0 self.count = 0 self.fmt = fmt def update(self, value, n=1): self.deque.append(value) self.count += n self.total += value * n def synchronize_between_processes(self): """ Warning: does not synchronize the deque! """ if not is_dist_avail_and_initialized(): return t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') dist.barrier() dist.all_reduce(t) t = t.tolist() self.count = int(t[0]) self.total = t[1] @property def median(self): d = torch.tensor(list(self.deque)) return d.median().item() @property def avg(self): d = torch.tensor(list(self.deque), dtype=torch.float32) return d.mean().item() @property def global_avg(self): return self.total / self.count @property def max(self): return max(self.deque) @property def value(self): return self.deque[-1] def __str__(self): return self.fmt.format( median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value) def reduce_dict(input_dict, average=True): """ Args: input_dict (dict): all the values will be reduced average (bool): whether to do average or sum Reduce the values in the dictionary from all processes so that all processes have the averaged results. Returns a dict with the same fields as input_dict, after reduction. 
""" world_size = get_world_size() if world_size < 2: return input_dict with torch.no_grad(): names = [] values = [] # sort the keys so that they are consistent across processes for k in sorted(input_dict.keys()): names.append(k) values.append(input_dict[k]) values = torch.stack(values, dim=0) dist.all_reduce(values) if average: values /= world_size reduced_dict = {k: v for k, v in zip(names, values)} return reduced_dict class MetricLogger(object): def __init__(self, delimiter="\t"): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter def update(self, **kwargs): for k, v in kwargs.items(): if isinstance(v, torch.Tensor): v = v.item() assert isinstance(v, (float, int)) self.meters[k].update(v) def __getattr__(self, attr): if attr in self.meters: return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): loss_str.append( "{}: {}".format(name, str(meter)) ) return self.delimiter.join(loss_str) def synchronize_between_processes(self): for meter in self.meters.values(): meter.synchronize_between_processes() def add_meter(self, name, meter): self.meters[name] = meter def log_every(self, iterable, print_freq, header=None): i = 0 if not header: header = '' start_time = time.time() end = time.time() iter_time = SmoothedValue(fmt='{avg:.6f}') data_time = SmoothedValue(fmt='{avg:.6f}') space_fmt = ':' + str(len(str(len(iterable)))) + 'd' if torch.cuda.is_available(): log_msg = self.delimiter.join([ header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}', 'max mem: {memory:.0f}' ]) else: log_msg = self.delimiter.join([ header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}' ]) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) yield obj iter_time.update(time.time() - end) if i % print_freq == 0 or i == len(iterable) - 1: eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): print(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), memory=torch.cuda.max_memory_allocated() / MB)) else: print(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time))) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('{} Total time: {} ({:.6f} s / it)'.format( header, total_time_str, total_time / len(iterable))) def get_sha(): cwd = os.path.dirname(os.path.abspath(__file__)) def _run(command): return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() sha = 'N/A' diff = "clean" branch = 'N/A' try: sha = _run(['git', 'rev-parse', 'HEAD']) subprocess.check_output(['git', 'diff'], cwd=cwd) diff = _run(['git', 'diff-index', 'HEAD']) diff = "has uncommited changes" if diff else "clean" branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) except Exception: pass message = f"sha: {sha}, status: {diff}, branch: {branch}" return message def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True def get_world_size(): if not is_dist_avail_and_initialized(): return 1 return dist.get_world_size() def get_rank(): if not is_dist_avail_and_initialized(): return 
0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def init_distributed_mode(args):
    # launched with torch.distributed.launch
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    # launched with submitit on a slurm cluster
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    # launched naively with `python main_dino.py`
    # we manually add MASTER_ADDR and MASTER_PORT to env variables
    elif torch.cuda.is_available():
        print('Will run the code on one GPU.')
        args.rank, args.gpu, args.world_size = 0, 0, 1
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
    else:
        print('Does not support training without GPU.')
        sys.exit(1)

    dist.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )

    torch.cuda.set_device(args.gpu)
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    dist.barrier()
    setup_for_distributed(args.rank == 0)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
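# Illustrative sanity check for accuracy() (a sketch; the tensors below are
# example values, not part of the original file): with two samples and top-1,
# one correct prediction out of two yields 50%.
#
#   logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # predicted classes: 1, 0
#   target = torch.tensor([1, 1])
#   accuracy(logits, target, topk=(1,))              # -> [tensor(50.)]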
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


class LARS(torch.optim.Optimizer):
    """
    Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=None, lars_adaptation_filter=None):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, eta=eta,
                        weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad
                if dp is None:
                    continue
                # apply weight decay and the LARS trust-ratio scaling only to
                # non-1-d parameters (i.e. skip biases and norm parameters)
                if p.ndim != 1:
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)
                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])


def create_ds_config(args):
    args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
    with open(args.deepspeed_config, mode="w") as writer:
        ds_config = {
            "train_batch_size": args.batch_size * get_world_size(),
            "train_micro_batch_size_per_gpu": args.batch_size,
            "steps_per_print": 1000,
            "optimizer": {
                "type": "Adam",
                "adam_w_mode": True,
                "params": {
                    "lr": args.lr,
                    "weight_decay": args.weight_decay,
                    "bias_correction": True,
                    "betas": [0.9, 0.999],
                    "eps": 1e-8
                }
            },
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "initial_scale_power": 7,
                "loss_scale_window": 128
            }
        }
        writer.write(json.dumps(ds_config, indent=2))
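# Usage sketch for LARS (illustrative; the hyper-parameter values below are
# placeholders, not recommendations): it is typically constructed with two
# parameter groups, e.g. via get_params_groups() defined further down, so that
# biases and norm parameters skip weight decay.
#
#   optimizer = LARS(get_params_groups(model), lr=1.6, weight_decay=1e-6)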
class MultiCropWrapper(nn.Module):
    """
    Perform the forward pass separately on each resolution input.
    The inputs corresponding to a single resolution are grouped, and a single
    forward pass is run on the same-resolution inputs. Hence we do several
    forward passes = number of different resolutions used. We then concatenate
    all the output features and run the head forward on these concatenated
    features.
    """
    def __init__(self, backbone, head=None):
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        if head is None:
            self.head = nn.Identity()
        else:
            self.head = head

    def forward(self, x, mask=None, return_backbone_feat=False, **kwargs):
        # convert to list
        if not isinstance(x, list):
            x = [x]
            mask = [mask] if mask is not None else None
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)
        start_idx = 0
        for end_idx in idx_crops:
            inp_x = torch.cat(x[start_idx: end_idx])
            if mask is not None:
                inp_m = torch.cat(mask[start_idx: end_idx])
                kwargs.update(dict(mask=inp_m))
            _out = self.backbone(inp_x, **kwargs)
            if start_idx == 0:
                output = _out
            else:
                output = torch.cat((output, _out))
            start_idx = end_idx
        # Run the head forward on the concatenated features.
        output_ = self.head(output)
        if return_backbone_feat:
            return output, output_
        return output_


def get_params_groups(model):
    regularized = []
    not_regularized = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # we do not regularize biases nor Norm parameters
        if name.endswith(".bias") or len(param.shape) == 1:
            not_regularized.append(param)
        else:
            regularized.append(param)
    return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]


def has_batchnorms(model):
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for name, module in model.named_modules():
        if isinstance(module, bn_types):
            return True
    return False


def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output


class PCA():
    """
    Class to compute and apply PCA.
    """
    def __init__(self, dim=256, whit=0.5):
        self.dim = dim
        self.whit = whit
        self.mean = None

    def train_pca(self, cov):
        """
        Takes a covariance matrix (np.ndarray) as input.
        """
        d, v = np.linalg.eigh(cov)
        eps = d.max() * 1e-5
        n_0 = (d < eps).sum()
        if n_0 > 0:
            d[d < eps] = eps
        # total energy
        totenergy = d.sum()
        # sort eigenvectors with eigenvalues order
        idx = np.argsort(d)[::-1][:self.dim]
        d = d[idx]
        v = v[:, idx]
        print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0))
        # for the whitening
        d = np.diag(1. / d**self.whit)
        # principal components
        self.dvt = np.dot(d, v.T)

    def apply(self, x):
        # input is from numpy
        if isinstance(x, np.ndarray):
            if self.mean is not None:
                x -= self.mean
            return np.dot(self.dvt, x.T).T
        # input is from torch and is on GPU
        if x.is_cuda:
            if self.mean is not None:
                x -= torch.cuda.FloatTensor(self.mean)
            return torch.mm(torch.cuda.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
        # input is from torch, on CPU
        if self.mean is not None:
            x -= torch.FloatTensor(self.mean)
        return torch.mm(torch.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)


def compute_ap(ranks, nres):
    """
    Computes average precision for given ranked indexes.

    Arguments
    ---------
    ranks : zero-based ranks of positive images
    nres  : number of positive images

    Returns
    -------
    ap : average precision
    """
    # number of images ranked by the system
    nimgranks = len(ranks)
    # accumulate trapezoids in PR-plot
    ap = 0
    recall_step = 1. / nres
    for j in np.arange(nimgranks):
        rank = ranks[j]
        if rank == 0:
            precision_0 = 1.
        else:
            precision_0 = float(j) / rank
        precision_1 = float(j + 1) / (rank + 1)
        ap += (precision_0 + precision_1) * recall_step / 2.
    return ap
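# Worked example for compute_ap() (illustrative, not part of the original
# file): one query with nres=2 positives retrieved at zero-based ranks 0 and 2
# accumulates two trapezoids of the PR curve:
#
#   compute_ap(np.array([0, 2]), 2)
#   = (1.0 + 1.0) * 0.5 / 2 + (0.5 + 2/3) * 0.5 / 2
#   ~= 0.7917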
def compute_map(ranks, gnd, kappas=[]):
    """
    Computes the mAP for a given set of returned results.

    Usage:
        map = compute_map(ranks, gnd)
            computes mean average precision (map) only

        map, aps, pr, prs = compute_map(ranks, gnd, kappas)
            computes mean average precision (map), average precision for each
            query (aps), mean precision at kappas (pr), and precision at
            kappas for each query (prs)

    Notes:
    1) ranks starts from 0, ranks.shape = db_size X #queries
    2) The junk results (e.g., the query itself) should be declared in the gnd struct array
    3) If there are no positive images for some query, that query is excluded from the evaluation
    """
    map = 0.
    nq = len(gnd)  # number of queries
    aps = np.zeros(nq)
    pr = np.zeros(len(kappas))
    prs = np.zeros((nq, len(kappas)))
    nempty = 0

    for i in np.arange(nq):
        qgnd = np.array(gnd[i]['ok'])

        # no positive images, skip from the average
        if qgnd.shape[0] == 0:
            aps[i] = float('nan')
            prs[i, :] = float('nan')
            nempty += 1
            continue

        try:
            qgndj = np.array(gnd[i]['junk'])
        except KeyError:
            qgndj = np.empty(0)

        # sorted positions of positive and junk images (0 based)
        pos = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgnd)]
        junk = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgndj)]

        k = 0
        ij = 0
        if len(junk):
            # decrease positions of positives based on the number of
            # junk images appearing before them
            ip = 0
            while ip < len(pos):
                while ij < len(junk) and pos[ip] > junk[ij]:
                    k += 1
                    ij += 1
                pos[ip] = pos[ip] - k
                ip += 1

        # compute ap
        ap = compute_ap(pos, len(qgnd))
        map = map + ap
        aps[i] = ap

        # compute precision @ k
        pos += 1  # get it to 1-based
        for j in np.arange(len(kappas)):
            kq = min(max(pos), kappas[j])
            prs[i, j] = (pos <= kq).sum() / kq
        pr = pr + prs[i, :]

    map = map / (nq - nempty)
    pr = pr / (nq - nempty)

    return map, aps, pr, prs


================================================
FILE: downstream_tasks/semantic_segmentation/README.md
================================================
# ADE20k Semantic segmentation with CAE

## Getting started

1. Install the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) library and some required packages.

```bash
pip install mmcv-full==1.3.0 mmsegmentation==0.11.0
pip install scipy timm==0.3.2
```

2. Install [apex](https://github.com/NVIDIA/apex) for mixed-precision training.

```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```

3. Follow the guide in [mmseg](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/dataset_prepare.md) to prepare the ADE20k dataset.

## Fine-tuning

Command format (placeholders in angle brackets are to be filled in):

```
tools/dist_train.sh <CONFIG_PATH> <NUM_GPUS> --work-dir <SAVE_PATH> --seed 0 --deterministic --options model.pretrained=<PRETRAINED_CHECKPOINT_PATH>
```

For example, using a CAE-base backbone with UperNet:

```bash
bash tools/dist_train.sh \
    configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_4e-4.py 8 \
    --work-dir /path/to/save --seed 0 --deterministic \
    --options model.pretrained=<PRETRAINED_CHECKPOINT_PATH>
```

More config files can be found at [`configs_local/cae/upernet`](configs_local/cae/upernet).

## Evaluation

Command format:

```
tools/dist_test.sh <CONFIG_PATH> <CHECKPOINT_PATH> <NUM_GPUS> --eval mIoU
```

For example, evaluate a CAE-base backbone with UperNet:

```bash
bash tools/dist_test.sh configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_4e-4.py \
    <CHECKPOINT_PATH> 8 --eval mIoU
```

Note that evaluation is also conducted automatically during training.

## Results (pretrained models are trained on ImageNet-1K without labels)

| Backbone | #Pretraining Epochs | mIoU | Config |
| -------- | ------------------- | ---- | ------ |
| ViT-B | 300 | 48.1 | [3e-4](./configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_3e-4.py) |
| ViT-B | 800 | 49.7 | [2e-4](./configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_2e-4.py) |
| ViT-B | 1600 | 50.3 | [1e-4](./configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_1e-4.py) |
| ViT-L | 1600 | 54.9 | [4e-5](./configs_local/cae/upernet/upernet_cae_large_24_512_slide_160k_ade20k_pt_decay095_4e-5_dp015.py) |

We find that a better pretrained model favors a smaller fine-tuning learning rate. That said, different learning rates do not lead to significantly different results: for example, the 800-epoch pretrained ViT-B still obtains 49.6 mIoU (averaged over two runs) with lr=4e-4.
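The learning-rate ablation above does not require editing the config files. As a sketch, assuming the standard mmseg `--options key=value` override mechanism and that these configs expose the learning rate as `optimizer.lr`, the 800-epoch ViT-B run with lr=4e-4 could be launched as:

```bash
# Fine-tune with lr=4e-4 instead of the config's default learning rate.
bash tools/dist_train.sh \
    configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_2e-4.py 8 \
    --work-dir /path/to/save --seed 0 --deterministic \
    --options model.pretrained=<PRETRAINED_CHECKPOINT_PATH> optimizer.lr=4e-4
```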
## Acknowledgment

This code is built using the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) library, the [Timm](https://github.com/rwightman/pytorch-image-models) library, the [Swin](https://github.com/microsoft/Swin-Transformer) repository, [XCiT](https://github.com/facebookresearch/xcit), and the [SETR](https://github.com/fudan-zvg/SETR) repository.


================================================
FILE: downstream_tasks/semantic_segmentation/backbone/beit.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
import math
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
import numpy as np

from mmcv_custom import load_checkpoint
from mmseg.utils import get_root_logger
from mmseg.models.builder import BACKBONES


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
""" def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) def extra_repr(self) -> str: return 'p={}'.format(self.drop_prob) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) # x = self.drop(x) # commit this for the orignal BERT implement x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None, attn_head_dim=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads if attn_head_dim is not None: head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) else: self.q_bias = None self.v_bias = None if window_size: self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) # trunc_normal_(self.relative_position_bias_table, std=.0) else: self.window_size = None self.relative_position_bias_table = None self.relative_position_index = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(all_head_dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, rel_pos_bias=None): B, N, C = x.shape qkv_bias = None if self.q_bias is not None: qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B, N, 3, self.num_heads, 
-1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) if self.relative_position_bias_table is not None: relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if rel_pos_bias is not None: attn = attn + rel_pos_bias attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=None, attn_head_dim=None): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if init_values is not None: self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) else: self.gamma_1, self.gamma_2 = None, None def forward(self, x, rel_pos_bias=None): if self.gamma_1 is None: x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) x = x + self.drop_path(self.mlp(self.norm2(x))) else: x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x, **kwargs): B, C, H, W = x.shape # FIXME look at relaxing size constraints # assert H == self.img_size[0] and W == self.img_size[1], \ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x) Hp, Wp = x.shape[2], x.shape[3] x = x.flatten(2).transpose(1, 2) return x, (Hp, Wp) class HybridEmbed(nn.Module): """ CNN Feature Map Embedding Extract feature map from CNN, flatten, project to embedding dim. 
""" def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): super().__init__() assert isinstance(backbone, nn.Module) img_size = to_2tuple(img_size) self.img_size = img_size self.backbone = backbone if feature_size is None: with torch.no_grad(): # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature # map for all networks, the feature metadata has reliable channel and stride info, but using # stride to calc feature dim requires info about padding of each stage that isn't captured. training = backbone.training if training: backbone.eval() o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] feature_size = o.shape[-2:] feature_dim = o.shape[1] backbone.train(training) else: feature_size = to_2tuple(feature_size) feature_dim = self.backbone.feature_info.channels()[-1] self.num_patches = feature_size[0] * feature_size[1] self.proj = nn.Linear(feature_dim, embed_dim) def forward(self, x): x = self.backbone(x)[-1] x = x.flatten(2).transpose(1, 2) x = self.proj(x) return x class RelativePositionBias(nn.Module): def __init__(self, window_size, num_heads): super().__init__() self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) # trunc_normal_(self.relative_position_bias_table, std=.02) def forward(self): relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww def get_sinusoid_encoding_table(n_position, d_hid, token=False): ''' Sinusoid position encoding table ''' def get_position_angle_vec(position): return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 if token: sinusoid_table = np.concatenate([sinusoid_table, np.zeros([1, d_hid])], dim=0) return torch.FloatTensor(sinusoid_table).unsqueeze(0) @BACKBONES.register_module() class 
BEiT(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False, use_abs_pos_emb=True, use_rel_pos_bias=False, use_sincos_pos_embed=True, use_shared_rel_pos_bias=False, out_indices=[3, 5, 7, 11], out_with_norm=False): super().__init__() norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models if hybrid_backbone is not None: self.patch_embed = HybridEmbed( hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) else: self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.out_indices = out_indices self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.use_abs_pos_emb = use_abs_pos_emb if use_abs_pos_emb: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) else: # self.pos_embed = None # self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) if use_sincos_pos_embed: self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim) else: self.pos_embed = None self.pos_drop = nn.Dropout(p=drop_rate) if use_shared_rel_pos_bias: self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) else: self.rel_pos_bias = None dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.use_rel_pos_bias = use_rel_pos_bias self.use_checkpoint = use_checkpoint self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth)]) if self.pos_embed is not None: trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) # trunc_normal_(self.mask_token, std=.02) self.out_indices = out_indices if patch_size == 16: self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), nn.SyncBatchNorm(embed_dim), nn.GELU(), nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn3 = nn.Identity() self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) elif patch_size == 8: self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Identity() self.fpn3 = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), ) self.fpn4 = nn.Sequential( nn.MaxPool2d(kernel_size=4, stride=4), ) if not out_with_norm: self.norm = nn.Identity() else: self.norm = norm_layer(embed_dim) self.apply(self._init_weights) self.fix_init_weight() def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., decode=False): h, w = self.patch_embed.patch_shape grid_w = torch.arange(w, dtype=torch.float32) grid_h = torch.arange(h, dtype=torch.float32) grid_w, grid_h = torch.meshgrid(grid_w, grid_h) assert embed_dim 
% 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' pos_dim = embed_dim // 4 omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim omega = 1. / (temperature ** omega) out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32) pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1)) pos_embed.requires_grad = False return pos_embed def fix_init_weight(self): def rescale(param, layer_id): param.div_(math.sqrt(2.0 * layer_id)) for layer_id, layer in enumerate(self.blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) rescale(layer.mlp.fc2.weight.data, layer_id + 1) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) if isinstance(pretrained, str): self.apply(_init_weights) logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: self.apply(_init_weights) else: raise TypeError('pretrained must be a str or None') def get_num_layers(self): return len(self.blocks) @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token'} def forward_features(self, x): B, C, H, W = x.shape x, (Hp, Wp) = self.patch_embed(x) batch_size, seq_len, _ = x.size() cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) if self.pos_embed is not None: ''' if self.use_abs_pos_emb: x = x + self.pos_embed.expand(batch_size, -1, -1).type_as(x).to(x.device).clone().detach() else: x = x[:,1:] + self.pos_embed.expand(batch_size, -1, -1).type_as(x[:,1:]).to(x.device).clone().detach() x = torch.cat([x[:,:1],x],dim=1) ''' x = x + self.pos_embed x = self.pos_drop(x) rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None features = [] for i, blk in enumerate(self.blocks): if self.use_checkpoint: x = checkpoint.checkpoint(blk, x, rel_pos_bias) else: x = blk(x, rel_pos_bias) if i in self.out_indices: xp = self.norm(x[:, 1:, :]).permute(0, 2, 1).reshape(B, -1, Hp, Wp) # xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp) features.append(xp.contiguous()) ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] for i in range(len(features)): features[i] = ops[i](features[i]) return tuple(features) def forward(self, x): x = self.forward_features(x) return x ================================================ FILE: downstream_tasks/semantic_segmentation/backbone/beit_fapn.py ================================================ # -------------------------------------------------------- # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) # Github source: https://github.com/microsoft/unilm/tree/master/beit # Copyright (c) 2021 
Microsoft # Licensed under The MIT License [see LICENSE for details] # By Hangbo Bao # Based on timm, mmseg, setr, xcit and swin code bases # https://github.com/rwightman/pytorch-image-models/tree/master/timm # https://github.com/fudan-zvg/SETR # https://github.com/facebookresearch/xcit/ # https://github.com/microsoft/Swin-Transformer # --------------------------------------------------------' import math import torch from functools import partial import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import drop_path, to_2tuple, trunc_normal_ from mmcv_custom import load_checkpoint from mmseg.utils import get_root_logger from mmseg.models.builder import BACKBONES from mmcv.ops import DeformConv2d from mmcv.cnn import xavier_init class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) def extra_repr(self) -> str: return 'p={}'.format(self.drop_prob) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) # x = self.drop(x) # commit this for the orignal BERT implement x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None, attn_head_dim=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads if attn_head_dim is not None: head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) else: self.q_bias = None self.v_bias = None if window_size: self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] 
= self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) # trunc_normal_(self.relative_position_bias_table, std=.0) else: self.window_size = None self.relative_position_bias_table = None self.relative_position_index = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(all_head_dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, rel_pos_bias=None): B, N, C = x.shape qkv_bias = None if self.q_bias is not None: qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) if self.relative_position_bias_table is not None: relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if rel_pos_bias is not None: attn = attn + rel_pos_bias attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=None, attn_head_dim=None): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if init_values is not None: self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) else: self.gamma_1, self.gamma_2 = None, None def forward(self, x, rel_pos_bias=None): if self.gamma_1 is None: x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) x = x + self.drop_path(self.mlp(self.norm2(x))) else: x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x, **kwargs): B, C, H, W = x.shape # FIXME look at relaxing size constraints # assert H == self.img_size[0] and W == self.img_size[1], \ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x) Hp, Wp = x.shape[2], x.shape[3] x = x.flatten(2).transpose(1, 2) return x, (Hp, Wp) class HybridEmbed(nn.Module): """ CNN Feature Map Embedding Extract feature map from CNN, flatten, project to embedding dim. """ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): super().__init__() assert isinstance(backbone, nn.Module) img_size = to_2tuple(img_size) self.img_size = img_size self.backbone = backbone if feature_size is None: with torch.no_grad(): # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature # map for all networks, the feature metadata has reliable channel and stride info, but using # stride to calc feature dim requires info about padding of each stage that isn't captured. 
training = backbone.training if training: backbone.eval() o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] feature_size = o.shape[-2:] feature_dim = o.shape[1] backbone.train(training) else: feature_size = to_2tuple(feature_size) feature_dim = self.backbone.feature_info.channels()[-1] self.num_patches = feature_size[0] * feature_size[1] self.proj = nn.Linear(feature_dim, embed_dim) def forward(self, x): x = self.backbone(x)[-1] x = x.flatten(2).transpose(1, 2) x = self.proj(x) return x class RelativePositionBias(nn.Module): def __init__(self, window_size, num_heads): super().__init__() self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) # trunc_normal_(self.relative_position_bias_table, std=.02) def forward(self): relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww class FeatureSelectionModule(nn.Module): def __init__(self, in_c, out_c, norm="GM"): super(FeatureSelectionModule, self).__init__() self.conv_attn = nn.Sequential( nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, bias=False).cuda(), # without norm and activation ) self.sigmoid = nn.Sigmoid().cuda() self.conv = nn.Sequential( nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, bias=False).cuda(), nn.BatchNorm2d(out_c).cuda(), nn.ReLU(inplace=True).cuda() ) xavier_init(self.conv_attn) for m in self.conv: if isinstance(m, nn.Conv2d): xavier_init(m, distribution='uniform') def forward(self, x): attn = self.sigmoid(self.conv_attn(F.avg_pool2d(x, x.size()[2:]))) feat = torch.mul(x, attn) x = x + feat feat = self.conv(x) return feat class FeatureAlign(nn.Module): def __init__(self, in_c, out_c, norm=None): super(FeatureAlign, self).__init__() self.lateral_conv = FeatureSelectionModule(in_c, out_c, norm="") self.relu = nn.ReLU(inplace=True).cuda() self.offset = nn.Conv2d(out_c*2, 144, kernel_size=1, stride=1, padding=0, bias=False).cuda() # 144=kernel_size[0]*kernel_size[1]*deform_groups*2 self.deform_conv2d = DeformConv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, dilation=1, 
deform_groups=8).cuda() def forward(self, feat_l, feat_s): feat_l = feat_l.float() feat_s = feat_s.float() HW = feat_l.size()[2:] if feat_l.size()[2:] != feat_s.size()[2:]: feat_up = F.interpolate(feat_s, HW, mode='bilinear', align_corners=False) else: feat_up = feat_s feat_arm = self.lateral_conv(feat_l) offset = self.offset(torch.cat([feat_arm, feat_up], dim=1)).float() feat_align = self.relu(self.deform_conv2d(feat_up, offset)) return feat_align + feat_arm @BACKBONES.register_module() class BEiT_FaPN(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, out_indices=[3, 5, 7, 11]): super().__init__() norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.out_conv = [] self.FaPN = [] for i in range(len(out_indices)): self.out_conv.append(nn.Conv2d(embed_dim, embed_dim, kernel_size=3, stride=1, padding=1, bias=False).cuda()) self.FaPN.append(FeatureAlign(embed_dim, embed_dim)) if hybrid_backbone is not None: self.patch_embed = HybridEmbed( hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) else: self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.out_indices = out_indices self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if use_abs_pos_emb: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) else: self.pos_embed = None self.pos_drop = nn.Dropout(p=drop_rate) if use_shared_rel_pos_bias: self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) else: self.rel_pos_bias = None dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.use_rel_pos_bias = use_rel_pos_bias self.use_checkpoint = use_checkpoint self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth)]) if self.pos_embed is not None: trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) # trunc_normal_(self.mask_token, std=.02) self.out_indices = out_indices if patch_size == 16: self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), nn.SyncBatchNorm(embed_dim), nn.GELU(), nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn3 = nn.Identity() self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) elif patch_size == 8: self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Identity() self.fpn3 = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), ) self.fpn4 = nn.Sequential( 
nn.MaxPool2d(kernel_size=4, stride=4), ) self.apply(self._init_weights) self.fix_init_weight() def fix_init_weight(self): def rescale(param, layer_id): param.div_(math.sqrt(2.0 * layer_id)) for layer_id, layer in enumerate(self.blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) rescale(layer.mlp.fc2.weight.data, layer_id + 1) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) if isinstance(pretrained, str): self.apply(_init_weights) logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: self.apply(_init_weights) else: raise TypeError('pretrained must be a str or None') def get_num_layers(self): return len(self.blocks) @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token'} def forward_features(self, x): B, C, H, W = x.shape x, (Hp, Wp) = self.patch_embed(x) batch_size, seq_len, _ = x.size() cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) if self.pos_embed is not None: x = x + self.pos_embed x = self.pos_drop(x) rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None features = [] for i, blk in enumerate(self.blocks): if self.use_checkpoint: x = checkpoint.checkpoint(blk, x, rel_pos_bias) else: x = blk(x, rel_pos_bias) if i in self.out_indices: xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp) features.append(xp.contiguous()) ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] for i in range(len(features)): features[i] = ops[i](features[i]) features_fapn = [] features_fapn.append(features[-1]) for i in range(len(ops)-1, 0, -1): new_feature = self.FaPN[i](features[i-1], features[i]) new_feature = self.out_conv[i](new_feature) features_fapn.append(new_feature) features_fapn.reverse() return tuple(features_fapn) def forward(self, x): x = self.forward_features(x) return x ================================================ FILE: downstream_tasks/semantic_segmentation/backbone/cae.py ================================================ # -------------------------------------------------------- # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) # Github source: https://github.com/microsoft/unilm/tree/master/beit # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # By Hangbo Bao # Based on timm, mmseg, setr, xcit and swin code bases # https://github.com/rwightman/pytorch-image-models/tree/master/timm # https://github.com/fudan-zvg/SETR # https://github.com/facebookresearch/xcit/ # https://github.com/microsoft/Swin-Transformer # --------------------------------------------------------' import math import torch from functools import partial import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import drop_path, to_2tuple, 
trunc_normal_ import numpy as np from mmcv_custom import load_checkpoint from mmseg.utils import get_root_logger from mmseg.models.builder import BACKBONES class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) def extra_repr(self) -> str: return 'p={}'.format(self.drop_prob) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) # x = self.drop(x) # commit this for the orignal BERT implement x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None, attn_head_dim=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads if attn_head_dim is not None: head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) else: self.q_bias = None self.v_bias = None if window_size: self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) # trunc_normal_(self.relative_position_bias_table, std=.0) else: self.window_size = None self.relative_position_bias_table = None self.relative_position_index = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(all_head_dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, rel_pos_bias=None): B, N, C = x.shape qkv_bias = None if self.q_bias is not None: qkv_bias = 
torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) if self.relative_position_bias_table is not None: relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if rel_pos_bias is not None: attn = attn + rel_pos_bias attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=None, attn_head_dim=None): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if init_values is not None: self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) else: self.gamma_1, self.gamma_2 = None, None def forward(self, x, rel_pos_bias=None): if self.gamma_1 is None: x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) x = x + self.drop_path(self.mlp(self.norm2(x))) else: x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x, **kwargs): B, C, H, W = x.shape # FIXME look at relaxing size constraints # assert H == self.img_size[0] and W == self.img_size[1], \ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x) Hp, Wp = x.shape[2], x.shape[3] x = x.flatten(2).transpose(1, 2) return x, (Hp, Wp) class HybridEmbed(nn.Module): """ CNN Feature Map Embedding Extract feature map from CNN, flatten, project to embedding dim. 
""" def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): super().__init__() assert isinstance(backbone, nn.Module) img_size = to_2tuple(img_size) self.img_size = img_size self.backbone = backbone if feature_size is None: with torch.no_grad(): # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature # map for all networks, the feature metadata has reliable channel and stride info, but using # stride to calc feature dim requires info about padding of each stage that isn't captured. training = backbone.training if training: backbone.eval() o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] feature_size = o.shape[-2:] feature_dim = o.shape[1] backbone.train(training) else: feature_size = to_2tuple(feature_size) feature_dim = self.backbone.feature_info.channels()[-1] self.num_patches = feature_size[0] * feature_size[1] self.proj = nn.Linear(feature_dim, embed_dim) def forward(self, x): x = self.backbone(x)[-1] x = x.flatten(2).transpose(1, 2) x = self.proj(x) return x class RelativePositionBias(nn.Module): def __init__(self, window_size, num_heads): super().__init__() self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) # trunc_normal_(self.relative_position_bias_table, std=.02) def forward(self): relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww def get_sinusoid_encoding_table(n_position, d_hid, token=False): ''' Sinusoid position encoding table ''' def get_position_angle_vec(position): return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 if token: sinusoid_table = np.concatenate([sinusoid_table, np.zeros([1, d_hid])], dim=0) return torch.FloatTensor(sinusoid_table).unsqueeze(0) @BACKBONES.register_module() class 
CAE(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False, use_abs_pos_emb=True, use_rel_pos_bias=False, use_sincos_pos_embed=True, use_shared_rel_pos_bias=False, out_indices=[3, 5, 7, 11], out_with_norm=False): super().__init__() norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models if hybrid_backbone is not None: self.patch_embed = HybridEmbed( hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) else: self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.out_indices = out_indices self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.use_abs_pos_emb = use_abs_pos_emb if use_abs_pos_emb: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) else: # self.pos_embed = None # self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) if use_sincos_pos_embed: self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim) else: self.pos_embed = None self.pos_drop = nn.Dropout(p=drop_rate) if use_shared_rel_pos_bias: self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) else: self.rel_pos_bias = None dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.use_rel_pos_bias = use_rel_pos_bias self.use_checkpoint = use_checkpoint self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth)]) if self.pos_embed is not None: trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) # trunc_normal_(self.mask_token, std=.02) self.out_indices = out_indices if patch_size == 16: self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), nn.SyncBatchNorm(embed_dim), nn.GELU(), nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn3 = nn.Identity() self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) elif patch_size == 8: self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Identity() self.fpn3 = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), ) self.fpn4 = nn.Sequential( nn.MaxPool2d(kernel_size=4, stride=4), ) if not out_with_norm: self.norm = nn.Identity() else: self.norm = norm_layer(embed_dim) self.apply(self._init_weights) self.fix_init_weight() def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., decode=False): h, w = self.patch_embed.patch_shape grid_w = torch.arange(w, dtype=torch.float32) grid_h = torch.arange(h, dtype=torch.float32) grid_w, grid_h = torch.meshgrid(grid_w, grid_h) assert embed_dim 
% 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' pos_dim = embed_dim // 4 omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim omega = 1. / (temperature ** omega) out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32) pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1)) pos_embed.requires_grad = False return pos_embed def fix_init_weight(self): def rescale(param, layer_id): param.div_(math.sqrt(2.0 * layer_id)) for layer_id, layer in enumerate(self.blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) rescale(layer.mlp.fc2.weight.data, layer_id + 1) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) if isinstance(pretrained, str): self.apply(_init_weights) logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: self.apply(_init_weights) else: raise TypeError('pretrained must be a str or None') def get_num_layers(self): return len(self.blocks) @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token'} def forward_features(self, x): B, C, H, W = x.shape x, (Hp, Wp) = self.patch_embed(x) batch_size, seq_len, _ = x.size() cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) if self.pos_embed is not None: ''' if self.use_abs_pos_emb: x = x + self.pos_embed.expand(batch_size, -1, -1).type_as(x).to(x.device).clone().detach() else: x = x[:,1:] + self.pos_embed.expand(batch_size, -1, -1).type_as(x[:,1:]).to(x.device).clone().detach() x = torch.cat([x[:,:1],x],dim=1) ''' x = x + self.pos_embed x = self.pos_drop(x) rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None features = [] for i, blk in enumerate(self.blocks): if self.use_checkpoint: x = checkpoint.checkpoint(blk, x, rel_pos_bias) else: x = blk(x, rel_pos_bias) if i in self.out_indices: xp = self.norm(x[:, 1:, :]).permute(0, 2, 1).reshape(B, -1, Hp, Wp) # xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp) features.append(xp.contiguous()) ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] for i in range(len(features)): features[i] = ops[i](features[i]) return tuple(features) def forward(self, x): x = self.forward_features(x) return x ================================================ FILE: downstream_tasks/semantic_segmentation/backbone/fapn.py ================================================ # NOTE (added): this file originally had no imports; the block below is an assumption, taking xavier_init and DeformConv2d from mmcv. import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import xavier_init from mmcv.ops import DeformConv2d class FeatureSelectionModule(nn.Module): def __init__(self, in_c, out_c, norm="GM"): super(FeatureSelectionModule, self).__init__() self.conv_attn = nn.Sequential( nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, bias=False), # without norm and activation ) self.sigmoid = nn.Sigmoid() self.conv = nn.Sequential( nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, bias=False), nn.BatchNorm2d(out_c), nn.ReLU(inplace=True) ) # mmcv's xavier_init silently skips modules without a .weight attribute, so initialise the convs inside each Sequential rather than the Sequential itself for m in self.conv_attn.modules(): if isinstance(m, nn.Conv2d): xavier_init(m, distribution='uniform') for m in self.conv.modules(): if isinstance(m, nn.Conv2d): xavier_init(m, distribution='uniform') def forward(self, x): attn = self.sigmoid(self.conv_attn(F.avg_pool2d(x, x.size()[2:]))) feat = torch.mul(x, attn) x = x + feat feat = self.conv(x) return feat class FeatureAlign(nn.Module): def __init__(self, in_c, out_c, norm=None): super(FeatureAlign, self).__init__() self.lateral_conv = FeatureSelectionModule(in_c, out_c, norm="") self.relu = nn.ReLU(inplace=True) self.offset = nn.Conv2d(out_c * 2, 2 * 3 * 3 * 8, kernel_size=1, stride=1, padding=0, bias=False) # fixed: DeformConv2d expects 2*kH*kW*deform_groups offset channels (2*3*3*8 = 144 here), not out_c self.deform_conv2d = DeformConv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, dilation=1, deform_groups=8) def forward(self, feat_l, feat_s): HW = feat_l.size()[2:] if feat_l.size()[2:] != feat_s.size()[2:]: feat_up = F.interpolate(feat_s, HW, mode='bilinear', align_corners=False) else: feat_up = feat_s feat_arm = self.lateral_conv(feat_l) offset = self.offset(torch.cat([feat_arm, feat_up], dim=1)) feat_align = self.relu(self.deform_conv2d(feat_up, offset)) return feat_align + feat_arm
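
# NOTE (added): hypothetical smoke test for the FaPN modules above; it is not part of the
# repository. FeatureSelectionModule re-weights the lateral feature channel-wise, and
# FeatureAlign predicts deformable-conv offsets from the concatenated pair so the upsampled
# coarse feature is aligned with the fine one. DeformConv2d may need a CUDA-enabled mmcv
# build, hence the device guard.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    fa = FeatureAlign(in_c=256, out_c=128).to(device)
    feat_l = torch.randn(2, 256, 64, 64, device=device)  # fine, high-resolution lateral feature
    feat_s = torch.randn(2, 128, 32, 32, device=device)  # coarse feature from the level above
    out = fa(feat_l, feat_s)  # upsample feat_s, predict offsets, align, fuse with selected lateral
    assert out.shape == (2, 128, 64, 64)

================================================ FILE: downstream_tasks/semantic_segmentation/backbone/mae.py ================================================ # -------------------------------------------------------- # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) # Github source: https://github.com/microsoft/unilm/tree/master/beit # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # By Hangbo Bao # Based on timm, mmseg, setr, xcit and swin code bases # https://github.com/rwightman/pytorch-image-models/tree/master/timm # https://github.com/fudan-zvg/SETR # https://github.com/facebookresearch/xcit/ # https://github.com/microsoft/Swin-Transformer # -------------------------------------------------------- import math import torch from functools import partial import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import drop_path, to_2tuple, trunc_normal_ import numpy as np from mmcv_custom import load_checkpoint from mmseg.utils import get_root_logger from mmseg.models.builder import BACKBONES class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).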
""" def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) def extra_repr(self) -> str: return 'p={}'.format(self.drop_prob) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) # x = self.drop(x) # commit this for the orignal BERT implement x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None, attn_head_dim=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads if attn_head_dim is not None: head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, all_head_dim * 3, bias=True) # if qkv_bias: # self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) # self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) # else: # self.q_bias = None # self.v_bias = None if window_size: self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) # trunc_normal_(self.relative_position_bias_table, std=.0) else: self.window_size = None self.relative_position_bias_table = None self.relative_position_index = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(all_head_dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, rel_pos_bias=None): B, N, C = x.shape qkv_bias = None # if self.q_bias is not None: # qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) # qkv = qkv.reshape(B, N, 3, 
self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) if self.relative_position_bias_table is not None: relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if rel_pos_bias is not None: attn = attn + rel_pos_bias attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=None, attn_head_dim=None): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if init_values is not None: self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) else: self.gamma_1, self.gamma_2 = None, None def forward(self, x, rel_pos_bias=None): if self.gamma_1 is None: x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) x = x + self.drop_path(self.mlp(self.norm2(x))) else: x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x, **kwargs): B, C, H, W = x.shape # FIXME look at relaxing size constraints # assert H == self.img_size[0] and W == self.img_size[1], \ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x) Hp, Wp = x.shape[2], x.shape[3] x = x.flatten(2).transpose(1, 2) return x, (Hp, Wp) class HybridEmbed(nn.Module): """ CNN Feature Map Embedding Extract feature map from CNN, flatten, project to embedding dim. 
""" def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): super().__init__() assert isinstance(backbone, nn.Module) img_size = to_2tuple(img_size) self.img_size = img_size self.backbone = backbone if feature_size is None: with torch.no_grad(): # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature # map for all networks, the feature metadata has reliable channel and stride info, but using # stride to calc feature dim requires info about padding of each stage that isn't captured. training = backbone.training if training: backbone.eval() o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] feature_size = o.shape[-2:] feature_dim = o.shape[1] backbone.train(training) else: feature_size = to_2tuple(feature_size) feature_dim = self.backbone.feature_info.channels()[-1] self.num_patches = feature_size[0] * feature_size[1] self.proj = nn.Linear(feature_dim, embed_dim) def forward(self, x): x = self.backbone(x)[-1] x = x.flatten(2).transpose(1, 2) x = self.proj(x) return x class RelativePositionBias(nn.Module): def __init__(self, window_size, num_heads): super().__init__() self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) # trunc_normal_(self.relative_position_bias_table, std=.02) def forward(self): relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww def get_sinusoid_encoding_table(n_position, d_hid, token=False): ''' Sinusoid position encoding table ''' def get_position_angle_vec(position): return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 if token: sinusoid_table = np.concatenate([sinusoid_table, np.zeros([1, d_hid])], dim=0) return torch.FloatTensor(sinusoid_table).unsqueeze(0) @BACKBONES.register_module() class 
MAE(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, out_indices=[3, 5, 7, 11]): super().__init__() norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models if hybrid_backbone is not None: self.patch_embed = HybridEmbed( hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) else: self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.out_indices = out_indices self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.use_abs_pos_emb = use_abs_pos_emb if use_abs_pos_emb: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) else: # self.pos_embed = None # self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim) self.pos_drop = nn.Dropout(p=drop_rate) if use_shared_rel_pos_bias: self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) else: self.rel_pos_bias = None dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.use_rel_pos_bias = use_rel_pos_bias self.use_checkpoint = use_checkpoint self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth)]) if self.pos_embed is not None: trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) # trunc_normal_(self.mask_token, std=.02) self.out_indices = out_indices if patch_size == 16: self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), nn.SyncBatchNorm(embed_dim), nn.GELU(), nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn3 = nn.Identity() self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) elif patch_size == 8: self.fpn1 = nn.Sequential( nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) self.fpn2 = nn.Identity() self.fpn3 = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), ) self.fpn4 = nn.Sequential( nn.MaxPool2d(kernel_size=4, stride=4), ) self.apply(self._init_weights) self.fix_init_weight() def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., decode=False): h, w = self.patch_embed.patch_shape grid_w = torch.arange(w, dtype=torch.float32) grid_h = torch.arange(h, dtype=torch.float32) grid_w, grid_h = torch.meshgrid(grid_w, grid_h) assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' pos_dim = embed_dim // 4 omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim omega = 1. 
/ (temperature ** omega) out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32) pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1)) pos_embed.requires_grad = False return pos_embed def fix_init_weight(self): def rescale(param, layer_id): param.div_(math.sqrt(2.0 * layer_id)) for layer_id, layer in enumerate(self.blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) rescale(layer.mlp.fc2.weight.data, layer_id + 1) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) if isinstance(pretrained, str): self.apply(_init_weights) logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: self.apply(_init_weights) else: raise TypeError('pretrained must be a str or None') def get_num_layers(self): return len(self.blocks) @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token'} def forward_features(self, x): B, C, H, W = x.shape x, (Hp, Wp) = self.patch_embed(x) batch_size, seq_len, _ = x.size() cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) if self.pos_embed is not None: ''' if self.use_abs_pos_emb: x = x + self.pos_embed.expand(batch_size, -1, -1).type_as(x).to(x.device).clone().detach() else: x = x[:,1:] + self.pos_embed.expand(batch_size, -1, -1).type_as(x[:,1:]).to(x.device).clone().detach() x = torch.cat([x[:,:1],x],dim=1) ''' x = x + self.pos_embed x = self.pos_drop(x) rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None features = [] for i, blk in enumerate(self.blocks): if self.use_checkpoint: x = checkpoint.checkpoint(blk, x, rel_pos_bias) else: x = blk(x, rel_pos_bias) if i in self.out_indices: xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp) features.append(xp.contiguous()) ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] for i in range(len(features)): features[i] = ops[i](features[i]) return tuple(features) def forward(self, x): x = self.forward_features(x) return x ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/ade20k.py ================================================ # dataset settings dataset_type = 'ADE20KDataset' data_root = 'data/ade/ADEChallengeData2016' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) crop_size = (512, 512) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 
dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']), ] test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=(2048, 512), # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ]) ] data = dict( samples_per_gpu=4, workers_per_gpu=4, train=dict( type=dataset_type, data_root=data_root, img_dir='images/training', ann_dir='annotations/training', pipeline=train_pipeline), val=dict( type=dataset_type, data_root=data_root, img_dir='images/validation', ann_dir='annotations/validation', pipeline=test_pipeline), test=dict( type=dataset_type, data_root=data_root, img_dir='images/validation', ann_dir='annotations/validation', pipeline=test_pipeline)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/ade20k_640x640.py ================================================ # dataset settings dataset_type = 'ADE20KDataset' data_root = 'data/ade/ADEChallengeData2016' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) crop_size = (640, 640) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']), ] test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=(2560, 640), # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ]) ] data = dict( samples_per_gpu=4, workers_per_gpu=4, train=dict( type=dataset_type, data_root=data_root, img_dir='images/training', ann_dir='annotations/training', pipeline=train_pipeline), val=dict( type=dataset_type, data_root=data_root, img_dir='images/validation', ann_dir='annotations/validation', pipeline=test_pipeline), test=dict( type=dataset_type, data_root=data_root, img_dir='images/validation', ann_dir='annotations/validation', pipeline=test_pipeline)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/chase_db1.py ================================================ # dataset settings dataset_type = 'ChaseDB1Dataset' data_root = 'data/CHASE_DB1' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) img_scale = (960, 999) crop_size = (128, 128) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations'), dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', 
**img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']) ] test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=img_scale, # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']) ]) ] data = dict( samples_per_gpu=4, workers_per_gpu=4, train=dict( type='RepeatDataset', times=40000, dataset=dict( type=dataset_type, data_root=data_root, img_dir='images/training', ann_dir='annotations/training', pipeline=train_pipeline)), val=dict( type=dataset_type, data_root=data_root, img_dir='images/validation', ann_dir='annotations/validation', pipeline=test_pipeline), test=dict( type=dataset_type, data_root=data_root, img_dir='images/validation', ann_dir='annotations/validation', pipeline=test_pipeline)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/cityscapes.py ================================================ # dataset settings dataset_type = 'CityscapesDataset' data_root = 'data/cityscapes/' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) crop_size = (512, 1024) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations'), dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']), ] test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=(2048, 1024), # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ]) ] data = dict( samples_per_gpu=2, workers_per_gpu=2, train=dict( type=dataset_type, data_root=data_root, img_dir='leftImg8bit/train', ann_dir='gtFine/train', pipeline=train_pipeline), val=dict( type=dataset_type, data_root=data_root, img_dir='leftImg8bit/val', ann_dir='gtFine/val', pipeline=test_pipeline), test=dict( type=dataset_type, data_root=data_root, img_dir='leftImg8bit/val', ann_dir='gtFine/val', pipeline=test_pipeline)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/cityscapes_769x769.py ================================================ _base_ = './cityscapes.py' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) crop_size = (769, 769) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations'), dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']), ] test_pipeline = [ 
dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=(2049, 1025), # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ]) ] data = dict( train=dict(pipeline=train_pipeline), val=dict(pipeline=test_pipeline), test=dict(pipeline=test_pipeline)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/coco-stuff10k.py ================================================ # dataset settings dataset_type = 'COCOStuffDataset' data_root = 'data/coco_stuff10k' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) crop_size = (512, 512) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']), ] test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=(2048, 512), # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ]) ] data = dict( samples_per_gpu=4, workers_per_gpu=4, train=dict( type=dataset_type, data_root=data_root, reduce_zero_label=True, img_dir='images/train2014', ann_dir='annotations/train2014', pipeline=train_pipeline), val=dict( type=dataset_type, data_root=data_root, reduce_zero_label=True, img_dir='images/test2014', ann_dir='annotations/test2014', pipeline=test_pipeline), test=dict( type=dataset_type, data_root=data_root, reduce_zero_label=True, img_dir='images/test2014', ann_dir='annotations/test2014', pipeline=test_pipeline)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/drive.py ================================================ # dataset settings dataset_type = 'DRIVEDataset' data_root = 'data/DRIVE' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) img_scale = (584, 565) crop_size = (64, 64) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations'), dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']) ] test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=img_scale, # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']) ]) ] data = dict( samples_per_gpu=4, workers_per_gpu=4, 
train=dict( type='RepeatDataset', times=40000, dataset=dict( type=dataset_type, data_root=data_root, img_dir='images/training', ann_dir='annotations/training', pipeline=train_pipeline)), val=dict( type=dataset_type, data_root=data_root, img_dir='images/validation', ann_dir='annotations/validation', pipeline=test_pipeline), test=dict( type=dataset_type, data_root=data_root, img_dir='images/validation', ann_dir='annotations/validation', pipeline=test_pipeline)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/hrf.py ================================================ # dataset settings dataset_type = 'HRFDataset' data_root = 'data/HRF' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) img_scale = (2336, 3504) crop_size = (256, 256) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations'), dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']) ] test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=img_scale, # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']) ]) ] data = dict( samples_per_gpu=4, workers_per_gpu=4, train=dict( type='RepeatDataset', times=40000, dataset=dict( type=dataset_type, data_root=data_root, img_dir='images/training', ann_dir='annotations/training', pipeline=train_pipeline)), val=dict( type=dataset_type, data_root=data_root, img_dir='images/validation', ann_dir='annotations/validation', pipeline=test_pipeline), test=dict( type=dataset_type, data_root=data_root, img_dir='images/validation', ann_dir='annotations/validation', pipeline=test_pipeline)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/pascal_context.py ================================================ # dataset settings dataset_type = 'PascalContextDataset' data_root = 'data/VOCdevkit/VOC2010/' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) img_scale = (520, 520) crop_size = (480, 480) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations'), dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']), ] test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=img_scale, # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ]) ] data = dict( samples_per_gpu=4, workers_per_gpu=4, train=dict( 
type=dataset_type, data_root=data_root, img_dir='JPEGImages', ann_dir='SegmentationClassContext', split='ImageSets/SegmentationContext/train.txt', pipeline=train_pipeline), val=dict( type=dataset_type, data_root=data_root, img_dir='JPEGImages', ann_dir='SegmentationClassContext', split='ImageSets/SegmentationContext/val.txt', pipeline=test_pipeline), test=dict( type=dataset_type, data_root=data_root, img_dir='JPEGImages', ann_dir='SegmentationClassContext', split='ImageSets/SegmentationContext/val.txt', pipeline=test_pipeline)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/pascal_voc12.py ================================================ # dataset settings dataset_type = 'PascalVOCDataset' data_root = 'data/VOCdevkit/VOC2012' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) crop_size = (512, 512) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations'), dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']), ] test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=(2048, 512), # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ]) ] data = dict( samples_per_gpu=4, workers_per_gpu=4, train=dict( type=dataset_type, data_root=data_root, img_dir='JPEGImages', ann_dir='SegmentationClass', split='ImageSets/Segmentation/train.txt', pipeline=train_pipeline), val=dict( type=dataset_type, data_root=data_root, img_dir='JPEGImages', ann_dir='SegmentationClass', split='ImageSets/Segmentation/val.txt', pipeline=test_pipeline), test=dict( type=dataset_type, data_root=data_root, img_dir='JPEGImages', ann_dir='SegmentationClass', split='ImageSets/Segmentation/val.txt', pipeline=test_pipeline)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/pascal_voc12_aug.py ================================================ _base_ = './pascal_voc12.py' # dataset settings data = dict( train=dict( ann_dir=['SegmentationClass', 'SegmentationClassAug'], split=[ 'ImageSets/Segmentation/train.txt', 'ImageSets/Segmentation/aug.txt' ])) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/datasets/stare.py ================================================ # dataset settings dataset_type = 'STAREDataset' data_root = 'data/STARE' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) img_scale = (605, 700) crop_size = (128, 128) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations'), dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), 
dict(type='Collect', keys=['img', 'gt_semantic_seg']) ] test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=img_scale, # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']) ]) ] data = dict( samples_per_gpu=4, workers_per_gpu=4, train=dict( type='RepeatDataset', times=40000, dataset=dict( type=dataset_type, data_root=data_root, img_dir='images/training', ann_dir='annotations/training', pipeline=train_pipeline)), val=dict( type=dataset_type, data_root=data_root, img_dir='images/validation', ann_dir='annotations/validation', pipeline=test_pipeline), test=dict( type=dataset_type, data_root=data_root, img_dir='images/validation', ann_dir='annotations/validation', pipeline=test_pipeline)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/default_runtime.py ================================================ # yapf:disable log_config = dict( interval=50, hooks=[ dict(type='TextLoggerHook', by_epoch=False), # dict(type='TensorboardLoggerHook') ]) # yapf:enable dist_params = dict(backend='nccl') log_level = 'INFO' load_from = None resume_from = None workflow = [('train', 1)] cudnn_benchmark = True ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/ann_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='ANNHead', in_channels=[1024, 2048], in_index=[2, 3], channels=512, project_channels=256, query_scales=(1, ), key_pool_scales=(1, 3, 6, 8), dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/apcnet_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='APCHead', in_channels=2048, in_index=3, channels=512, pool_scales=(1, 2, 3, 6), dropout_ratio=0.1, num_classes=19, norm_cfg=dict(type='SyncBN', requires_grad=True), align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, 
concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/ccnet_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='CCHead', in_channels=2048, in_index=3, channels=512, recurrence=2, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/cgnet.py ================================================ # model settings norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True) model = dict( type='EncoderDecoder', backbone=dict( type='CGNet', norm_cfg=norm_cfg, in_channels=3, num_channels=(32, 64, 128), num_blocks=(3, 21), dilations=(2, 4), reductions=(8, 16)), decode_head=dict( type='FCNHead', in_channels=256, in_index=2, channels=256, num_convs=0, concat_input=False, dropout_ratio=0, num_classes=19, norm_cfg=norm_cfg, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, class_weight=[ 2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352, 10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905, 10.347791, 6.3927646, 10.226669, 10.241062, 10.280587, 10.396974, 10.055647 ])), # model training and testing settings train_cfg=dict(sampler=None), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/danet_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='DAHead', in_channels=2048, in_index=3, channels=512, pam_channels=64, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) 
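
# NOTE (added): the _base_ fragments in this directory are not run directly; mmseg composes
# them through mmcv config inheritance. A minimal, hypothetical composition, mirroring the
# pattern of the real files under configs_local/cae/upernet/, looks like this:
# hypothetical_upernet_cae_example.py (illustrative only)
_base_ = [
    '../_base_/models/upernet_cae.py',   # model skeleton: CAE backbone + UPerHead
    '../_base_/datasets/ade20k.py',      # the ADE20K pipeline defined above
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py',
]
# dict keys set here are merged over (and override) the inherited _base_ values
model = dict(
    backbone=dict(drop_path_rate=0.1),   # assumed override, for illustration
    decode_head=dict(num_classes=150),   # ADE20K has 150 classes
)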
================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/deeplabv3_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='ASPPHead', in_channels=2048, in_index=3, channels=512, dilations=(1, 12, 24, 36), dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/deeplabv3_unet_s5-d16.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained=None, backbone=dict( type='UNet', in_channels=3, base_channels=64, num_stages=5, strides=(1, 1, 1, 1, 1), enc_num_convs=(2, 2, 2, 2, 2), dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), dec_dilations=(1, 1, 1, 1), with_cp=False, conv_cfg=None, norm_cfg=norm_cfg, act_cfg=dict(type='ReLU'), upsample_cfg=dict(type='InterpConv'), norm_eval=False), decode_head=dict( type='ASPPHead', in_channels=64, in_index=4, channels=16, dilations=(1, 12, 24, 36), dropout_ratio=0.1, num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=128, in_index=3, channels=64, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='slide', crop_size=256, stride=170)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/deeplabv3plus_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='DepthwiseSeparableASPPHead', in_channels=2048, in_index=3, channels=512, dilations=(1, 12, 24, 36), c1_in_channels=256, c1_channels=48, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, 
loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/dmnet_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='DMHead', in_channels=2048, in_index=3, channels=512, filter_sizes=(1, 3, 5, 7), dropout_ratio=0.1, num_classes=19, norm_cfg=dict(type='SyncBN', requires_grad=True), align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/dnl_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='DNLHead', in_channels=2048, in_index=3, channels=512, dropout_ratio=0.1, reduction=2, use_scale=True, mode='embedded_gaussian', num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/emanet_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='EMAHead', in_channels=2048, in_index=3, channels=256, ema_channels=512, num_bases=64, num_stages=3, momentum=0.1, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, 
loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/encnet_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='EncHead', in_channels=[512, 1024, 2048], in_index=(1, 2, 3), channels=512, num_codes=32, use_se_loss=True, add_lateral=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_se_decode=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/fast_scnn.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01) model = dict( type='EncoderDecoder', backbone=dict( type='FastSCNN', downsample_dw_channels=(32, 48), global_in_channels=64, global_block_channels=(64, 96, 128), global_block_strides=(2, 2, 1), global_out_channels=128, higher_in_channels=64, lower_in_channels=128, fusion_out_channels=128, out_indices=(0, 1, 2), norm_cfg=norm_cfg, align_corners=False), decode_head=dict( type='DepthwiseSeparableFCNHead', in_channels=128, channels=128, concat_input=False, num_classes=19, in_index=-1, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), auxiliary_head=[ dict( type='FCNHead', in_channels=128, channels=32, num_convs=1, num_classes=19, in_index=-2, norm_cfg=norm_cfg, concat_input=False, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), dict( type='FCNHead', in_channels=64, channels=32, num_convs=1, num_classes=19, in_index=-3, norm_cfg=norm_cfg, concat_input=False, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), ], # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/fcn_hr18.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://msra/hrnetv2_w18', backbone=dict( type='HRNet', norm_cfg=norm_cfg, norm_eval=False, extra=dict( stage1=dict( num_modules=1, num_branches=1, block='BOTTLENECK', num_blocks=(4, ), num_channels=(64, )), stage2=dict( num_modules=1, num_branches=2, block='BASIC', num_blocks=(4, 4), num_channels=(18, 36)), stage3=dict( num_modules=4, num_branches=3, 
block='BASIC', num_blocks=(4, 4, 4), num_channels=(18, 36, 72)), stage4=dict( num_modules=3, num_branches=4, block='BASIC', num_blocks=(4, 4, 4, 4), num_channels=(18, 36, 72, 144)))), decode_head=dict( type='FCNHead', in_channels=[18, 36, 72, 144], in_index=(0, 1, 2, 3), channels=sum([18, 36, 72, 144]), input_transform='resize_concat', kernel_size=1, num_convs=1, concat_input=False, dropout_ratio=-1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/fcn_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='FCNHead', in_channels=2048, in_index=3, channels=512, num_convs=2, concat_input=True, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/fcn_unet_s5-d16.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained=None, backbone=dict( type='UNet', in_channels=3, base_channels=64, num_stages=5, strides=(1, 1, 1, 1, 1), enc_num_convs=(2, 2, 2, 2, 2), dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), dec_dilations=(1, 1, 1, 1), with_cp=False, conv_cfg=None, norm_cfg=norm_cfg, act_cfg=dict(type='ReLU'), upsample_cfg=dict(type='InterpConv'), norm_eval=False), decode_head=dict( type='FCNHead', in_channels=64, in_index=4, channels=64, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=128, in_index=3, channels=64, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='slide', crop_size=256, stride=170)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/fpn_r50.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 
1, 1), strides=(1, 2, 2, 2), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), neck=dict( type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=4), decode_head=dict( type='FPNHead', in_channels=[256, 256, 256, 256], in_index=[0, 1, 2, 3], feature_strides=[4, 8, 16, 32], channels=128, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/gcnet_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='GCHead', in_channels=2048, in_index=3, channels=512, ratio=1 / 4., pooling_type='att', fusion_types=('channel_add', ), dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/lraspp_m-v3-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True) model = dict( type='EncoderDecoder', backbone=dict( type='MobileNetV3', arch='large', out_indices=(1, 3, 16), norm_cfg=norm_cfg), decode_head=dict( type='LRASPPHead', in_channels=(16, 24, 960), in_index=(0, 1, 2), channels=128, input_transform='multiple_select', dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, act_cfg=dict(type='ReLU'), align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/nonlocal_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='NLHead', in_channels=2048, in_index=3, channels=512, dropout_ratio=0.1, reduction=2, use_scale=True, mode='embedded_gaussian', num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, 
norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/ocrnet_hr18.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='CascadeEncoderDecoder', num_stages=2, pretrained='open-mmlab://msra/hrnetv2_w18', backbone=dict( type='HRNet', norm_cfg=norm_cfg, norm_eval=False, extra=dict( stage1=dict( num_modules=1, num_branches=1, block='BOTTLENECK', num_blocks=(4, ), num_channels=(64, )), stage2=dict( num_modules=1, num_branches=2, block='BASIC', num_blocks=(4, 4), num_channels=(18, 36)), stage3=dict( num_modules=4, num_branches=3, block='BASIC', num_blocks=(4, 4, 4), num_channels=(18, 36, 72)), stage4=dict( num_modules=3, num_branches=4, block='BASIC', num_blocks=(4, 4, 4, 4), num_channels=(18, 36, 72, 144)))), decode_head=[ dict( type='FCNHead', in_channels=[18, 36, 72, 144], channels=sum([18, 36, 72, 144]), in_index=(0, 1, 2, 3), input_transform='resize_concat', kernel_size=1, num_convs=1, concat_input=False, dropout_ratio=-1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), dict( type='OCRHead', in_channels=[18, 36, 72, 144], in_index=(0, 1, 2, 3), input_transform='resize_concat', channels=512, ocr_channels=256, dropout_ratio=-1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), ], # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/ocrnet_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='CascadeEncoderDecoder', num_stages=2, pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=[ dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), dict( type='OCRHead', in_channels=2048, in_index=3, channels=512, ocr_channels=256, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) ], # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/pointrend_r50.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='CascadeEncoderDecoder', num_stages=2, pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 1, 1), strides=(1, 2, 2, 2), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), neck=dict( 
type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=4), decode_head=[ dict( type='FPNHead', in_channels=[256, 256, 256, 256], in_index=[0, 1, 2, 3], feature_strides=[4, 8, 16, 32], channels=128, dropout_ratio=-1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), dict( type='PointHead', in_channels=[256], in_index=[0], channels=256, num_fcs=3, coarse_pred_each_layer=True, dropout_ratio=-1, num_classes=19, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) ], # model training and testing settings train_cfg=dict( num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75), test_cfg=dict( mode='whole', subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/psanet_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='PSAHead', in_channels=2048, in_index=3, channels=512, mask_size=(97, 97), psa_type='bi-direction', compact=False, shrink_factor=2, normalization_factor=1.0, psa_softmax=True, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/pspnet_r50-d8.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), norm_cfg=norm_cfg, norm_eval=False, style='pytorch', contract_dilation=True), decode_head=dict( type='PSPHead', in_channels=2048, in_index=3, channels=512, pool_scales=(1, 2, 3, 6), dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), auxiliary_head=dict( type='FCNHead', in_channels=1024, in_index=2, channels=256, num_convs=1, concat_input=False, dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole')) ================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/pspnet_unet_s5-d16.py ================================================ # model settings norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', 
    pretrained=None,
    backbone=dict(
        type='UNet',
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1),
        with_cp=False,
        conv_cfg=None,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU'),
        upsample_cfg=dict(type='InterpConv'),
        norm_eval=False),
    decode_head=dict(
        type='PSPHead',
        in_channels=64,
        in_index=4,
        channels=16,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=128,
        in_index=3,
        channels=64,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='slide', crop_size=256, stride=170))

================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/upernet_cae.py ================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='XCiT',
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=True,
        use_rel_pos_bias=False,
    ),
    decode_head=dict(
        type='UPerHead',
        in_channels=[384, 384, 384, 384],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=384,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
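# A hedged aside, not part of this config: UPerHead consumes a four-level
# feature pyramid, while a plain ViT-style backbone emits every block at a
# single stride (16 for patch_size=16). Backbones in this family therefore
# adapt the out_indices features to strides 4/8/16/32 before the head. The
# sketch below shows only the shape bookkeeping, using plain interpolation
# and pooling; the real adapters typically use (transposed) convolutions.
import torch
import torch.nn.functional as F

side = 512 // 16  # token grid of a 512x512 crop
feats = [torch.randn(1, 384, side, side) for _ in range(4)]
pyramid = [
    F.interpolate(feats[0], scale_factor=4, mode='bilinear', align_corners=False),  # stride 4
    F.interpolate(feats[1], scale_factor=2, mode='bilinear', align_corners=False),  # stride 8
    feats[2],                                                                       # stride 16
    F.max_pool2d(feats[3], kernel_size=2, stride=2),                                # stride 32
]
assert [p.shape[-1] for p in pyramid] == [128, 64, 32, 16]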
================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/models/upernet_r50.py ================================================
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 1, 1),
        strides=(1, 2, 2, 2),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='UPerHead',
        in_channels=[256, 512, 1024, 2048],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_160k.py ================================================
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=160000)
checkpoint_config = dict(by_epoch=False, interval=16000)
evaluation = dict(interval=16000, metric='mIoU')

================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_20k.py ================================================
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=20000)
checkpoint_config = dict(by_epoch=False, interval=2000)
evaluation = dict(interval=2000, metric='mIoU')

================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_320k.py ================================================
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=320000)
checkpoint_config = dict(by_epoch=False, interval=32000)
evaluation = dict(interval=32000, metric='mIoU')

================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_40k.py ================================================
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=40000)
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(interval=4000, metric='mIoU')

================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/_base_/schedules/schedule_80k.py ================================================
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=80000)
checkpoint_config = dict(by_epoch=False, interval=8000)
evaluation = dict(interval=8000, metric='mIoU')
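# A hedged aside, not part of the schedules: mmcv's 'poly' policy decays the
# learning rate roughly as below (exact warmup/clamping details live in mmcv).
base_lr, min_lr, power, max_iters = 0.01, 1e-4, 0.9, 160000

def poly_lr(i):
    return (base_lr - min_lr) * (1 - i / max_iters) ** power + min_lr

assert abs(poly_lr(0) - base_lr) < 1e-12
assert abs(poly_lr(max_iters) - min_lr) < 1e-12
print(round(poly_lr(80000), 6))  # ~0.005405 halfway through schedule_160k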
================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/beit/upernet_beit_base_12_512_slide_160k_ade20k_pt_4e-4.py ================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
_base_ = [
    '../../_base_/models/upernet_beit.py', '../../_base_/datasets/ade20k.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)

model = dict(
    backbone=dict(
        type='BEiT',
        img_size=512,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,
        use_sincos_pos_embed=False,
        use_rel_pos_bias=True,
        init_values=0.1,
        drop_path_rate=0.1,
        out_indices=[3, 5, 7, 11]
    ),
    decode_head=dict(
        in_channels=[768, 768, 768, 768],
        num_classes=150,
        channels=768,
    ),
    auxiliary_head=dict(
        in_channels=768,
        num_classes=150
    ),
    test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341))
)

# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
# optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
#                  paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
#                                                  'relative_position_bias_table': dict(decay_mult=0.),
#                                                  'norm': dict(decay_mult=0.)}))
optimizer = dict(_delete_=True, type='AdamW', lr=4e-4, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))

lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)

#img_norm_cfg = dict(
#    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
#crop_size = (512, 512)
## test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))

find_unused_parameters = True

#test_pipeline = [
#    dict(type='LoadImageFromFile'),
#    dict(
#        type='MultiScaleFlipAug',
#        img_scale=(2048, 512),
#        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
#        flip=True,
#        transforms=[
#            dict(type='SETR_Resize', keep_ratio=True,
#                 crop_size=crop_size, setr_multi_scale=True),
#            dict(type='RandomFlip'),
#            dict(type='Normalize', **img_norm_cfg),
#            dict(type='ImageToTensor', keys=['img']),
#            dict(type='Collect', keys=['img']),
#        ])
#]
#data = dict(
#    val=dict(pipeline=test_pipeline),
#    test=dict(pipeline=test_pipeline),
#    samples_per_gpu=2,
#)

runner = dict(type='IterBasedRunnerAmp')

checkpoint_config = dict(by_epoch=False, interval=8000)

# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)
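# A hedged sketch, not the constructor's code: LayerDecayOptimizerConstructor
# scales per-group learning rates geometrically with depth. Exact group
# indexing varies by repo, but the shape of the rule is:
base_lr, decay, depth = 4e-4, 0.65, 12
num_groups = depth + 2  # embeddings, each transformer block, then the head
scales = [decay ** (num_groups - 1 - g) for g in range(num_groups)]
lrs = [base_lr * s for s in scales]
# the last group trains at base_lr; the embedding group is scaled hardest
assert lrs[-1] == base_lr and lrs[0] < lrs[-1] / 100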
================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_1e-4.py ================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
_base_ = [
    '../../_base_/models/upernet_cae.py', '../../_base_/datasets/ade20k.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)

model = dict(
    backbone=dict(
        type='CAE',
        img_size=512,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,
        use_rel_pos_bias=True,
        init_values=0.1,
        drop_path_rate=0.1,
        out_indices=[3, 5, 7, 11]
    ),
    decode_head=dict(
        in_channels=[768, 768, 768, 768],
        num_classes=150,
        channels=768,
    ),
    auxiliary_head=dict(
        in_channels=768,
        num_classes=150
    ),
    test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341))
)

optimizer = dict(_delete_=True, type='AdamW', lr=1e-4, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))

lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)

find_unused_parameters = True

#test_pipeline = [
#    dict(type='LoadImageFromFile'),
#    dict(
#        type='MultiScaleFlipAug',
#        img_scale=(2048, 512),
#        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
#        flip=True,
#        transforms=[
#            dict(type='SETR_Resize', keep_ratio=True,
#                 crop_size=crop_size, setr_multi_scale=True),
#            dict(type='RandomFlip'),
#            dict(type='Normalize', **img_norm_cfg),
#            dict(type='ImageToTensor', keys=['img']),
#            dict(type='Collect', keys=['img']),
#        ])
#]
#data = dict(
#    val=dict(pipeline=test_pipeline),
#    test=dict(pipeline=test_pipeline),
#    samples_per_gpu=2,
#)

runner = dict(type='IterBasedRunnerAmp')

checkpoint_config = dict(by_epoch=False, interval=8000)

# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)
================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_2e-4.py ================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
_base_ = [
    '../../_base_/models/upernet_cae.py', '../../_base_/datasets/ade20k.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)

model = dict(
    backbone=dict(
        type='CAE',
        img_size=512,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,
        use_rel_pos_bias=True,
        init_values=0.1,
        drop_path_rate=0.1,
        out_indices=[3, 5, 7, 11]
    ),
    decode_head=dict(
        in_channels=[768, 768, 768, 768],
        num_classes=150,
        channels=768,
    ),
    auxiliary_head=dict(
        in_channels=768,
        num_classes=150
    ),
    test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341))
)

optimizer = dict(_delete_=True, type='AdamW', lr=2e-4, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))

lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)

find_unused_parameters = True

#test_pipeline = [
#    dict(type='LoadImageFromFile'),
#    dict(
#        type='MultiScaleFlipAug',
#        img_scale=(2048, 512),
#        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
#        flip=True,
#        transforms=[
#            dict(type='SETR_Resize', keep_ratio=True,
#                 crop_size=crop_size, setr_multi_scale=True),
#            dict(type='RandomFlip'),
#            dict(type='Normalize', **img_norm_cfg),
#            dict(type='ImageToTensor', keys=['img']),
#            dict(type='Collect', keys=['img']),
#        ])
#]
#data = dict(
#    val=dict(pipeline=test_pipeline),
#    test=dict(pipeline=test_pipeline),
#    samples_per_gpu=2,
#)

runner = dict(type='IterBasedRunnerAmp')

checkpoint_config = dict(by_epoch=False, interval=8000)

# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)
================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/cae/upernet/upernet_cae_base_12_512_slide_160k_ade20k_pt_3e-4.py ================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
_base_ = [
    '../../_base_/models/upernet_cae.py', '../../_base_/datasets/ade20k.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)

model = dict(
    backbone=dict(
        type='CAE',
        img_size=512,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,
        use_rel_pos_bias=True,
        init_values=0.1,
        drop_path_rate=0.1,
        out_indices=[3, 5, 7, 11]
    ),
    decode_head=dict(
        in_channels=[768, 768, 768, 768],
        num_classes=150,
        channels=768,
    ),
    auxiliary_head=dict(
        in_channels=768,
        num_classes=150
    ),
    test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341))
)

optimizer = dict(_delete_=True, type='AdamW', lr=3e-4, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))

lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)

find_unused_parameters = True

#test_pipeline = [
#    dict(type='LoadImageFromFile'),
#    dict(
#        type='MultiScaleFlipAug',
#        img_scale=(2048, 512),
#        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
#        flip=True,
#        transforms=[
#            dict(type='SETR_Resize', keep_ratio=True,
#                 crop_size=crop_size, setr_multi_scale=True),
#            dict(type='RandomFlip'),
#            dict(type='Normalize', **img_norm_cfg),
#            dict(type='ImageToTensor', keys=['img']),
#            dict(type='Collect', keys=['img']),
#        ])
#]
#data = dict(
#    val=dict(pipeline=test_pipeline),
#    test=dict(pipeline=test_pipeline),
#    samples_per_gpu=2,
#)

runner = dict(type='IterBasedRunnerAmp')

checkpoint_config = dict(by_epoch=False, interval=8000)

# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)
================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/cae/upernet/upernet_cae_large_24_512_slide_160k_ade20k_pt_decay095_4e-5_dp015.py ================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
_base_ = [
    '../../_base_/models/upernet_cae.py', '../../_base_/datasets/ade20k.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)

model = dict(
    backbone=dict(
        type='CAE',
        img_size=512,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,
        use_rel_pos_bias=True,
        init_values=1e-5,
        drop_path_rate=0.15,
        out_indices=[7, 11, 15, 23],
    ),
    decode_head=dict(
        in_channels=[1024, 1024, 1024, 1024],
        num_classes=150,
        channels=1024,
    ),
    auxiliary_head=dict(
        in_channels=1024,
        num_classes=150
    ),
    test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341))
)

# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
# optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
#                  paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
#                                                  'relative_position_bias_table': dict(decay_mult=0.),
#                                                  'norm': dict(decay_mult=0.)}))
optimizer = dict(_delete_=True, type='AdamW', lr=4e-5, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.95))

lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)

#img_norm_cfg = dict(
#    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
#crop_size = (512, 512)
## test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))

find_unused_parameters = True

#test_pipeline = [
#    dict(type='LoadImageFromFile'),
#    dict(
#        type='MultiScaleFlipAug',
#        img_scale=(2048, 512),
#        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
#        flip=True,
#        transforms=[
#            dict(type='SETR_Resize', keep_ratio=True,
#                 crop_size=crop_size, setr_multi_scale=True),
#            dict(type='RandomFlip'),
#            dict(type='Normalize', **img_norm_cfg),
#            dict(type='ImageToTensor', keys=['img']),
#            dict(type='Collect', keys=['img']),
#        ])
#]
#data = dict(
#    val=dict(pipeline=test_pipeline),
#    test=dict(pipeline=test_pipeline),
#    samples_per_gpu=2,
#)

runner = dict(type='IterBasedRunnerAmp')

checkpoint_config = dict(by_epoch=False, interval=32000)

# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)

================================================ FILE: downstream_tasks/semantic_segmentation/configs_local/mae/upernet_mae_large_12_512_slide_160k_ade20k_pt_4e-4.py ================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------
_base_ = [
    '../../_base_/models/upernet_beit.py', '../../_base_/datasets/ade20k.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)

model = dict(
    backbone=dict(
        type='MAE',
        img_size=512,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        use_abs_pos_emb=False,
        use_rel_pos_bias=True,
        init_values=1,
        drop_path_rate=0.2,
        out_indices=[7, 11, 15, 23],
    ),
    decode_head=dict(
        in_channels=[1024, 1024, 1024, 1024],
        num_classes=150,
        channels=1024,
    ),
    auxiliary_head=dict(
        in_channels=1024,
        num_classes=150
    ),
    test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341))
)

# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
# optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
#                  paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
#                                                  'relative_position_bias_table': dict(decay_mult=0.),
#                                                  'norm': dict(decay_mult=0.)}))
optimizer = dict(_delete_=True, type='AdamW', lr=4e-4, betas=(0.9, 0.999), weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.65))

lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)

#img_norm_cfg = dict(
#    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
#crop_size = (512, 512)
## test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341))

find_unused_parameters = True

#test_pipeline = [
#    dict(type='LoadImageFromFile'),
#    dict(
#        type='MultiScaleFlipAug',
#        img_scale=(2048, 512),
#        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
#        flip=True,
#        transforms=[
#            dict(type='SETR_Resize', keep_ratio=True,
#                 crop_size=crop_size, setr_multi_scale=True),
#            dict(type='RandomFlip'),
#            dict(type='Normalize', **img_norm_cfg),
#            dict(type='ImageToTensor', keys=['img']),
#            dict(type='Collect', keys=['img']),
#        ])
#]
#data = dict(
#    val=dict(pipeline=test_pipeline),
#    test=dict(pipeline=test_pipeline),
#    samples_per_gpu=2,
#)

runner = dict(type='IterBasedRunnerAmp')

checkpoint_config = dict(by_epoch=False, interval=32000)

# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)
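# A hedged aside, not mmseg source: with test_cfg mode='slide', crop 512 and
# stride 341, adjacent windows overlap by 512 - 341 = 171 px, so every pixel
# is predicted at least once.
import math

crop, stride = 512, 341

def n_windows(size):
    return max(math.ceil((size - crop) / stride), 0) + 1

h, w = 512, 2048  # e.g. an ADE20K val image resized to (2048, 512)
print(n_windows(h) * n_windows(w))  # 1 x 6 = 6 forward passes per image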
================================================ FILE: downstream_tasks/semantic_segmentation/mmcv_custom/__init__.py ================================================
# -*- coding: utf-8 -*-

from .checkpoint import load_checkpoint
from .layer_decay_optimizer_constructor import LayerDecayOptimizerConstructor
from .resize_transform import SETR_Resize
from .apex_runner.optimizer import DistOptimizerHook
from .train_api import train_segmentor

__all__ = ['load_checkpoint', 'LayerDecayOptimizerConstructor', 'SETR_Resize',
           'DistOptimizerHook', 'train_segmentor']

================================================ FILE: downstream_tasks/semantic_segmentation/mmcv_custom/apex_runner/__init__.py ================================================
# Copyright (c) Open-MMLab. All rights reserved.
from .checkpoint import save_checkpoint
from .apex_iter_based_runner import IterBasedRunnerAmp

__all__ = [
    'save_checkpoint', 'IterBasedRunnerAmp',
]
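# A hedged aside, not part of the package: the AMP runner and hooks below
# assume the model and optimizer were wrapped with NVIDIA apex before
# training, roughly as follows. The toy model and opt_level='O1' here are
# assumptions for illustration; the repo's train scripts pick their own,
# and apex itself needs a CUDA build.
import torch

try:
    import apex
    model = torch.nn.Linear(8, 8).cuda()  # toy stand-in for a segmentor
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # after this call, apex.amp.state_dict() exists for the checkpoint helpers
    model, optimizer = apex.amp.initialize(model, optimizer, opt_level='O1')
except ImportError:
    print('apex is not installed')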
""" if meta is None: meta = dict(iter=self.iter + 1, epoch=self.epoch + 1) elif isinstance(meta, dict): meta.update(iter=self.iter + 1, epoch=self.epoch + 1) else: raise TypeError( f'meta should be a dict or None, but got {type(meta)}') if self.meta is not None: meta.update(self.meta) filename = filename_tmpl.format(self.iter + 1) filepath = osp.join(out_dir, filename) optimizer = self.optimizer if save_optimizer else None save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) # in some environments, `os.symlink` is not supported, you may need to # set `create_symlink` to False # if create_symlink: # dst_file = osp.join(out_dir, 'latest.pth') # if platform.system() != 'Windows': # mmcv.symlink(filename, dst_file) # else: # shutil.copy(filepath, dst_file) def resume(self, checkpoint, resume_optimizer=True, map_location='default'): if map_location == 'default': if torch.cuda.is_available(): device_id = torch.cuda.current_device() checkpoint = self.load_checkpoint( checkpoint, map_location=lambda storage, loc: storage.cuda(device_id)) else: checkpoint = self.load_checkpoint(checkpoint) else: checkpoint = self.load_checkpoint( checkpoint, map_location=map_location) self._epoch = checkpoint['meta']['epoch'] self._iter = checkpoint['meta']['iter'] self._inner_iter = checkpoint['meta']['iter'] if 'optimizer' in checkpoint and resume_optimizer: if isinstance(self.optimizer, Optimizer): self.optimizer.load_state_dict(checkpoint['optimizer']) elif isinstance(self.optimizer, dict): for k in self.optimizer.keys(): self.optimizer[k].load_state_dict( checkpoint['optimizer'][k]) else: raise TypeError( 'Optimizer should be dict or torch.optim.Optimizer ' f'but got {type(self.optimizer)}') if 'amp' in checkpoint: apex.amp.load_state_dict(checkpoint['amp']) self.logger.info('load amp state dict') self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}') ================================================ FILE: downstream_tasks/semantic_segmentation/mmcv_custom/apex_runner/checkpoint.py ================================================ # Copyright (c) Open-MMLab. All rights reserved. import os.path as osp import time from tempfile import TemporaryDirectory import torch from torch.optim import Optimizer import mmcv from mmcv.parallel import is_module_wrapper from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict try: import apex except: print('apex is not installed') def save_checkpoint(model, filename, optimizer=None, meta=None): """Save checkpoint to file. The checkpoint will have 4 fields: ``meta``, ``state_dict`` and ``optimizer``, ``amp``. By default ``meta`` will contain version and time info. Args: model (Module): Module whose params are to be saved. filename (str): Checkpoint filename. optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. meta (dict, optional): Metadata to be saved in checkpoint. 
""" if meta is None: meta = {} elif not isinstance(meta, dict): raise TypeError(f'meta must be a dict or None, but got {type(meta)}') meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) if is_module_wrapper(model): model = model.module if hasattr(model, 'CLASSES') and model.CLASSES is not None: # save class name to the meta meta.update(CLASSES=model.CLASSES) checkpoint = { 'meta': meta, 'state_dict': weights_to_cpu(get_state_dict(model)) } # save optimizer state dict in the checkpoint if isinstance(optimizer, Optimizer): checkpoint['optimizer'] = optimizer.state_dict() elif isinstance(optimizer, dict): checkpoint['optimizer'] = {} for name, optim in optimizer.items(): checkpoint['optimizer'][name] = optim.state_dict() # save amp state dict in the checkpoint checkpoint['amp'] = apex.amp.state_dict() if filename.startswith('pavi://'): try: from pavi import modelcloud from pavi.exception import NodeNotFoundError except ImportError: raise ImportError( 'Please install pavi to load checkpoint from modelcloud.') model_path = filename[7:] root = modelcloud.Folder() model_dir, model_name = osp.split(model_path) try: model = modelcloud.get(model_dir) except NodeNotFoundError: model = root.create_training_model(model_dir) with TemporaryDirectory() as tmp_dir: checkpoint_file = osp.join(tmp_dir, model_name) with open(checkpoint_file, 'wb') as f: torch.save(checkpoint, f) f.flush() model.create_file(checkpoint_file, name=model_name) else: mmcv.mkdir_or_exist(osp.dirname(filename)) # immediately flush buffer with open(filename, 'wb') as f: torch.save(checkpoint, f) f.flush() ================================================ FILE: downstream_tasks/semantic_segmentation/mmcv_custom/apex_runner/optimizer.py ================================================ from mmcv.runner import OptimizerHook, HOOKS try: import apex except: print('apex is not installed') @HOOKS.register_module() class DistOptimizerHook(OptimizerHook): """Optimizer hook for distributed training.""" def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False): self.grad_clip = grad_clip self.coalesce = coalesce self.bucket_size_mb = bucket_size_mb self.update_interval = update_interval self.use_fp16 = use_fp16 def before_run(self, runner): runner.optimizer.zero_grad() def after_train_iter(self, runner): runner.outputs['loss'] /= self.update_interval if self.use_fp16: with apex.amp.scale_loss(runner.outputs['loss'], runner.optimizer) as scaled_loss: scaled_loss.backward() else: runner.outputs['loss'].backward() if self.every_n_iters(runner, self.update_interval): if self.grad_clip is not None: self.clip_grads(runner.model.parameters()) runner.optimizer.step() runner.optimizer.zero_grad() ================================================ FILE: downstream_tasks/semantic_segmentation/mmcv_custom/checkpoint.py ================================================ # Copyright (c) Open-MMLab. All rights reserved. 
================================================ FILE: downstream_tasks/semantic_segmentation/mmcv_custom/checkpoint.py ================================================
# Copyright (c) Open-MMLab. All rights reserved.
import io
import os
import os.path as osp
import pkgutil
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory

import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo
from torch.nn import functional as F

import mmcv
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.utils import mkdir_or_exist
from mmcv.runner import get_dist_info

from scipy import interpolate
import numpy as np
import math

ENV_MMCV_HOME = 'MMCV_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'


def _get_mmcv_home():
    mmcv_home = os.path.expanduser(
        os.getenv(
            ENV_MMCV_HOME,
            os.path.join(
                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))

    mkdir_or_exist(mmcv_home)
    return mmcv_home


def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


def load_url_dist(url, model_dir=None, map_location="cpu"):
    """In distributed setting, this function downloads the checkpoint only
    at local rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(
            url, model_dir=model_dir, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(
                url, model_dir=model_dir, map_location=map_location)
    return checkpoint
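# A hedged distillation (not library code) of the download-once pattern the
# *_dist helpers in this module share: rank 0 fetches first, filling the
# local cache; the other ranks wait at the barrier and then re-run the same
# fetch, which is now served from that cache.
import torch.distributed as dist

def _fetch_rank0_first(fetch):
    rank = dist.get_rank() if dist.is_initialized() else 0
    result = fetch() if rank == 0 else None
    if dist.is_initialized() and dist.get_world_size() > 1:
        dist.barrier()
        if rank > 0:
            result = fetch()
    return result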
map_location=None): """In distributed setting, this function only download checkpoint at local rank 0.""" try: from pavi import modelcloud except ImportError: raise ImportError( 'Please install pavi to load checkpoint from modelcloud.') rank, world_size = get_dist_info() rank = int(os.environ.get('LOCAL_RANK', rank)) if rank == 0: model = modelcloud.get(model_path) with TemporaryDirectory() as tmp_dir: downloaded_file = osp.join(tmp_dir, model.name) model.download(downloaded_file) checkpoint = torch.load(downloaded_file, map_location=map_location) if world_size > 1: torch.distributed.barrier() if rank > 0: model = modelcloud.get(model_path) with TemporaryDirectory() as tmp_dir: downloaded_file = osp.join(tmp_dir, model.name) model.download(downloaded_file) checkpoint = torch.load( downloaded_file, map_location=map_location) return checkpoint def load_fileclient_dist(filename, backend, map_location): """In distributed setting, this function only download checkpoint at local rank 0.""" rank, world_size = get_dist_info() rank = int(os.environ.get('LOCAL_RANK', rank)) allowed_backends = ['ceph'] if backend not in allowed_backends: raise ValueError(f'Load from Backend {backend} is not supported.') if rank == 0: fileclient = FileClient(backend=backend) buffer = io.BytesIO(fileclient.get(filename)) checkpoint = torch.load(buffer, map_location=map_location) if world_size > 1: torch.distributed.barrier() if rank > 0: fileclient = FileClient(backend=backend) buffer = io.BytesIO(fileclient.get(filename)) checkpoint = torch.load(buffer, map_location=map_location) return checkpoint def get_torchvision_models(): model_urls = dict() for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): if ispkg: continue _zoo = import_module(f'torchvision.models.{name}') if hasattr(_zoo, 'model_urls'): _urls = getattr(_zoo, 'model_urls') model_urls.update(_urls) return model_urls def get_external_models(): mmcv_home = _get_mmcv_home() default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') default_urls = load_file(default_json_path) assert isinstance(default_urls, dict) external_json_path = osp.join(mmcv_home, 'open_mmlab.json') if osp.exists(external_json_path): external_urls = load_file(external_json_path) assert isinstance(external_urls, dict) default_urls.update(external_urls) return default_urls def get_mmcls_models(): mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') mmcls_urls = load_file(mmcls_json_path) return mmcls_urls def get_deprecated_model_names(): deprecate_json_path = osp.join(mmcv.__path__[0], 'model_zoo/deprecated.json') deprecate_urls = load_file(deprecate_json_path) assert isinstance(deprecate_urls, dict) return deprecate_urls def _process_mmcls_checkpoint(checkpoint): state_dict = checkpoint['state_dict'] new_state_dict = OrderedDict() for k, v in state_dict.items(): if k.startswith('backbone.'): new_state_dict[k[9:]] = v new_checkpoint = dict(state_dict=new_state_dict) return new_checkpoint def _load_checkpoint(filename, map_location=None): """Load checkpoint from somewhere (modelzoo, file, url). Args: filename (str): Accept local filepath, URL, ``torchvision://xxx``, ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for details. map_location (str | None): Same as :func:`torch.load`. Default: None. Returns: dict | OrderedDict: The loaded checkpoint. It can be either an OrderedDict storing model weights or a dict containing other information, which depends on the checkpoint. 
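Example (an illustrative sketch added for this documentation, not part of the original file; it assumes ``resnet18`` is present in torchvision's ``model_urls`` table, which holds for the older torchvision releases this code targets):

    >>> ckpt = _load_checkpoint('torchvision://resnet18')  # routed to load_url_dist
    >>> isinstance(ckpt, dict)
    True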
""" if filename.startswith('modelzoo://'): warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' 'use "torchvision://" instead') model_urls = get_torchvision_models() model_name = filename[11:] checkpoint = load_url_dist(model_urls[model_name]) elif filename.startswith('torchvision://'): model_urls = get_torchvision_models() model_name = filename[14:] checkpoint = load_url_dist(model_urls[model_name]) elif filename.startswith('open-mmlab://'): model_urls = get_external_models() model_name = filename[13:] deprecated_urls = get_deprecated_model_names() if model_name in deprecated_urls: warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' f'of open-mmlab://{deprecated_urls[model_name]}') model_name = deprecated_urls[model_name] model_url = model_urls[model_name] # check if is url if model_url.startswith(('http://', 'https://')): checkpoint = load_url_dist(model_url) else: filename = osp.join(_get_mmcv_home(), model_url) if not osp.isfile(filename): raise IOError(f'{filename} is not a checkpoint file') checkpoint = torch.load(filename, map_location=map_location) elif filename.startswith('mmcls://'): model_urls = get_mmcls_models() model_name = filename[8:] checkpoint = load_url_dist(model_urls[model_name]) checkpoint = _process_mmcls_checkpoint(checkpoint) elif filename.startswith(('http://', 'https://')): checkpoint = load_url_dist(filename) elif filename.startswith('pavi://'): model_path = filename[7:] checkpoint = load_pavimodel_dist(model_path, map_location=map_location) elif filename.startswith('s3://'): checkpoint = load_fileclient_dist( filename, backend='ceph', map_location=map_location) else: if not osp.isfile(filename): raise IOError(f'{filename} is not a checkpoint file') checkpoint = torch.load(filename, map_location=map_location) return checkpoint def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0, warmup_steps=-1): warmup_schedule = np.array([]) warmup_iters = warmup_epochs * niter_per_ep if warmup_steps > 0: warmup_iters = warmup_steps print("Set warmup steps = %d" % warmup_iters) if warmup_epochs > 0: warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) iters = np.arange(epochs * niter_per_ep - warmup_iters) schedule = np.array( [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) schedule = np.concatenate((warmup_schedule, schedule)) assert len(schedule) == epochs * niter_per_ep return schedule def load_checkpoint(model, filename, map_location='cpu', strict=False, logger=None): """Load checkpoint from a file or URI. Args: model (Module): Module to load checkpoint. filename (str): Accept local filepath, URL, ``torchvision://xxx``, ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for details. map_location (str): Same as :func:`torch.load`. strict (bool): Whether to allow different params for the model and checkpoint. logger (:mod:`logging.Logger` or None): The logger for error message. Returns: dict or OrderedDict: The loaded checkpoint. 
""" checkpoint = _load_checkpoint(filename, map_location) # OrderedDict is a subclass of dict if not isinstance(checkpoint, dict): raise RuntimeError( f'No state_dict found in checkpoint file {filename}') # get state_dict from checkpoint if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] elif 'model' in checkpoint: state_dict = checkpoint['model'] elif 'module' in checkpoint: state_dict = checkpoint['module'] else: state_dict = checkpoint # strip prefix of state_dict if list(state_dict.keys())[0].startswith('module.'): state_dict = {k[7:]: v for k, v in state_dict.items()} # for MoBY, load model of online branch if sorted(list(state_dict.keys()))[0].startswith('encoder'): state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} all_keys = list(state_dict.keys()) print("origin keys:", len(all_keys), all_keys) if all_keys[-1].startswith('encoder_to_decoder') or all_keys[-1].startswith('decoder'): # NOTE: remove all decoder keys all_keys = [key for key in all_keys if key.startswith('encoder.')] print("all keys:", all_keys) for key in all_keys: new_key = key.replace('encoder.','') # print("new_key:", new_key) state_dict[new_key] = state_dict[key] state_dict.pop(key) for key in list(state_dict.keys()): if key.startswith('decoder.'): # print("key:", key) state_dict.pop(key) # NOTE: replace norm with fc_norm for key in list(state_dict.keys()): # print("new key:", key) if key.startswith('norm.'): new_key = key.replace('norm.','fc_norm.') state_dict[new_key] = state_dict[key] state_dict.pop(key) print("new keys:", len(state_dict), state_dict.keys()) # reshape absolute position embedding for Swin if state_dict.get('absolute_pos_embed') is not None: absolute_pos_embed = state_dict['absolute_pos_embed'] N1, L, C1 = absolute_pos_embed.size() N2, C2, H, W = model.absolute_pos_embed.size() if N1 != N2 or C1 != C2 or L != H*W: logger.warning("Error in loading absolute_pos_embed, pass") else: state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) rank, _ = get_dist_info() if "rel_pos_bias.relative_position_bias_table" in state_dict: if rank == 0: print("Expand the shared relative position embedding to each layers. ") num_layers = model.get_num_layers() rel_pos_bias = state_dict["rel_pos_bias.relative_position_bias_table"] for i in range(num_layers): state_dict["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone() state_dict.pop("rel_pos_bias.relative_position_bias_table") all_keys = list(state_dict.keys()) # for moco for key in all_keys: if 'base_encoder.' in key: new_key = key.replace('base_encoder.', '') state_dict[new_key] = state_dict[key] state_dict.pop(key) if 'momentum_encoder' in key: state_dict.pop(key) # for ibot for key in all_keys: if 'module.backbone.' in key: new_key = key.replace('module.backbone.', '') state_dict[new_key] = state_dict[key] state_dict.pop(key) elif 'backbone.' 
in key: new_key = key.replace('backbone.', '') state_dict[new_key] = state_dict[key] state_dict.pop(key) for key in all_keys: if "relative_position_index" in key: state_dict.pop(key) if "relative_position_bias_table" in key: rel_pos_bias = state_dict[key] src_num_pos, num_attn_heads = rel_pos_bias.size() dst_num_pos, _ = model.state_dict()[key].size() dst_patch_shape = model.patch_embed.patch_shape if dst_patch_shape[0] != dst_patch_shape[1]: raise NotImplementedError() num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) src_size = int((src_num_pos - num_extra_tokens) ** 0.5) dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) if src_size != dst_size: if rank == 0: print("Position interpolate for %s from %dx%d to %dx%d" % ( key, src_size, src_size, dst_size, dst_size)) extra_tokens = rel_pos_bias[-num_extra_tokens:, :] rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] def geometric_progression(a, r, n): return a * (1.0 - r ** n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_size // 2) if gp > dst_size // 2: right = q else: left = q # if q > 1.13492: # q = 1.13492 dis = [] cur = 1 for i in range(src_size // 2): dis.append(cur) cur += q ** (i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis t = dst_size // 2.0 dx = np.arange(-t, t + 0.1, 1.0) dy = np.arange(-t, t + 0.1, 1.0) if rank == 0: print("x = {}".format(x)) print("dx = {}".format(dx)) all_rel_pos_bias = [] for i in range(num_attn_heads): z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() f = interpolate.interp2d(x, y, z, kind='cubic') all_rel_pos_bias.append( torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) state_dict[key] = new_rel_pos_bias if 'pos_embed' in state_dict: #and model.use_abs_pos_emb: pos_embed_checkpoint = state_dict['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.patch_embed.num_patches num_extra_tokens = model.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) # height (== width) for the new position embedding new_size = int(num_patches ** 0.5) # class_token and dist_token are kept unchanged if orig_size != new_size: if rank == 0: print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) state_dict['pos_embed'] = new_pos_embed # interpolate position bias table if needed relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] for table_key in relative_position_bias_table_keys: table_pretrained = state_dict[table_key] table_current = model.state_dict()[table_key] L1, nH1 = table_pretrained.size() L2, nH2 = table_current.size() if nH1 != nH2: logger.warning(f"Error in loading {table_key}, pass") 
else: if L1 != L2: S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) table_pretrained_resized = F.interpolate( table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), mode='bicubic') state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0) # load state_dict load_state_dict(model, state_dict, strict, logger) return checkpoint def weights_to_cpu(state_dict): """Copy a model state_dict to cpu. Args: state_dict (OrderedDict): Model weights on GPU. Returns: OrderedDict: Model weights on CPU. """ state_dict_cpu = OrderedDict() for key, val in state_dict.items(): state_dict_cpu[key] = val.cpu() return state_dict_cpu def _save_to_state_dict(module, destination, prefix, keep_vars): """Saves module state to `destination` dictionary. This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. Args: module (nn.Module): The module to generate state_dict. destination (dict): A dict where state will be stored. prefix (str): The prefix for parameters and buffers used in this module. keep_vars (bool): Whether to keep the Variable property of the parameters. """ for name, param in module._parameters.items(): if param is not None: destination[prefix + name] = param if keep_vars else param.detach() for name, buf in module._buffers.items(): # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d if buf is not None: destination[prefix + name] = buf if keep_vars else buf.detach() def get_state_dict(module, destination=None, prefix='', keep_vars=False): """Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. This method is modified from :meth:`torch.nn.Module.state_dict` to recursively check parallel module in case that the model has a complicated structure, e.g., nn.Module(nn.Module(DDP)). Args: module (nn.Module): The module to generate state_dict. destination (OrderedDict): Returned dict for the state of the module. prefix (str): Prefix of the key. keep_vars (bool): Whether to keep the variable property of the parameters. Default: False. Returns: dict: A dictionary containing a whole state of the module. """ # recursively check parallel module in case that the model has a # complicated structure, e.g., nn.Module(nn.Module(DDP)) if is_module_wrapper(module): module = module.module # below is the same as torch.nn.Module.state_dict() if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() destination._metadata[prefix[:-1]] = local_metadata = dict( version=module._version) _save_to_state_dict(module, destination, prefix, keep_vars) for name, child in module._modules.items(): if child is not None: get_state_dict( child, destination, prefix + name + '.', keep_vars=keep_vars) for hook in module._state_dict_hooks.values(): hook_result = hook(module, destination, prefix, local_metadata) if hook_result is not None: destination = hook_result return destination def save_checkpoint(model, filename, optimizer=None, meta=None): """Save checkpoint to file. The checkpoint will have 3 fields: ``meta``, ``state_dict`` and ``optimizer``. By default ``meta`` will contain version and time info. Args: model (Module): Module whose params are to be saved. filename (str): Checkpoint filename. optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. meta (dict, optional): Metadata to be saved in checkpoint.
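Example (an illustrative sketch added for this documentation; the output path is hypothetical):

    >>> import torch.nn as nn
    >>> from torch.optim import SGD
    >>> net = nn.Conv2d(3, 8, 3)
    >>> save_checkpoint(net, 'work_dirs/toy.pth',  # hypothetical path
    ...                 optimizer=SGD(net.parameters(), lr=0.1))
    >>> sorted(torch.load('work_dirs/toy.pth'))
    ['meta', 'optimizer', 'state_dict']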
""" if meta is None: meta = {} elif not isinstance(meta, dict): raise TypeError(f'meta must be a dict or None, but got {type(meta)}') meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) if is_module_wrapper(model): model = model.module if hasattr(model, 'CLASSES') and model.CLASSES is not None: # save class name to the meta meta.update(CLASSES=model.CLASSES) checkpoint = { 'meta': meta, 'state_dict': weights_to_cpu(get_state_dict(model)) } # save optimizer state dict in the checkpoint if isinstance(optimizer, Optimizer): checkpoint['optimizer'] = optimizer.state_dict() elif isinstance(optimizer, dict): checkpoint['optimizer'] = {} for name, optim in optimizer.items(): checkpoint['optimizer'][name] = optim.state_dict() if filename.startswith('pavi://'): try: from pavi import modelcloud from pavi.exception import NodeNotFoundError except ImportError: raise ImportError( 'Please install pavi to load checkpoint from modelcloud.') model_path = filename[7:] root = modelcloud.Folder() model_dir, model_name = osp.split(model_path) try: model = modelcloud.get(model_dir) except NodeNotFoundError: model = root.create_training_model(model_dir) with TemporaryDirectory() as tmp_dir: checkpoint_file = osp.join(tmp_dir, model_name) with open(checkpoint_file, 'wb') as f: torch.save(checkpoint, f) f.flush() model.create_file(checkpoint_file, name=model_name) else: mmcv.mkdir_or_exist(osp.dirname(filename)) # immediately flush buffer with open(filename, 'wb') as f: torch.save(checkpoint, f) f.flush() ================================================ FILE: downstream_tasks/semantic_segmentation/mmcv_custom/checkpoint_beit.py ================================================ # Copyright (c) Open-MMLab. All rights reserved. import io import os import os.path as osp import pkgutil import time import warnings from collections import OrderedDict from importlib import import_module from tempfile import TemporaryDirectory import torch import torchvision from torch.optim import Optimizer from torch.utils import model_zoo from torch.nn import functional as F import mmcv from mmcv.fileio import FileClient from mmcv.fileio import load as load_file from mmcv.parallel import is_module_wrapper from mmcv.utils import mkdir_or_exist from mmcv.runner import get_dist_info from scipy import interpolate import numpy as np import math ENV_MMCV_HOME = 'MMCV_HOME' ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' DEFAULT_CACHE_DIR = '~/.cache' def _get_mmcv_home(): mmcv_home = os.path.expanduser( os.getenv( ENV_MMCV_HOME, os.path.join( os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv'))) mkdir_or_exist(mmcv_home) return mmcv_home def load_state_dict(module, state_dict, strict=False, logger=None): """Load state_dict to a module. This method is modified from :meth:`torch.nn.Module.load_state_dict`. Default value for ``strict`` is set to ``False`` and the message for param mismatch will be shown even if strict is False. Args: module (Module): Module that receives the state_dict. state_dict (OrderedDict): Weights. strict (bool): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. logger (:obj:`logging.Logger`, optional): Logger to log the error message. If not specified, print function will be used. 
""" unexpected_keys = [] all_missing_keys = [] err_msg = [] metadata = getattr(state_dict, '_metadata', None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata # use _load_from_state_dict to enable checkpoint version control def load(module, prefix=''): # recursively check parallel module in case that the model has a # complicated structure, e.g., nn.Module(nn.Module(DDP)) if is_module_wrapper(module): module = module.module local_metadata = {} if metadata is None else metadata.get( prefix[:-1], {}) module._load_from_state_dict(state_dict, prefix, local_metadata, True, all_missing_keys, unexpected_keys, err_msg) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + '.') load(module) load = None # break load->load reference cycle # ignore "num_batches_tracked" of BN layers missing_keys = [ key for key in all_missing_keys if 'num_batches_tracked' not in key ] if unexpected_keys: err_msg.append('unexpected key in source ' f'state_dict: {", ".join(unexpected_keys)}\n') if missing_keys: err_msg.append( f'missing keys in source state_dict: {", ".join(missing_keys)}\n') rank, _ = get_dist_info() if len(err_msg) > 0 and rank == 0: err_msg.insert( 0, 'The model and loaded state dict do not match exactly\n') err_msg = '\n'.join(err_msg) if strict: raise RuntimeError(err_msg) elif logger is not None: logger.warning(err_msg) else: print(err_msg) def load_url_dist(url, model_dir=None, map_location="cpu"): """In distributed setting, this function only download checkpoint at local rank 0.""" rank, world_size = get_dist_info() rank = int(os.environ.get('LOCAL_RANK', rank)) if rank == 0: checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location) if world_size > 1: torch.distributed.barrier() if rank > 0: checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location) return checkpoint def load_pavimodel_dist(model_path, map_location=None): """In distributed setting, this function only download checkpoint at local rank 0.""" try: from pavi import modelcloud except ImportError: raise ImportError( 'Please install pavi to load checkpoint from modelcloud.') rank, world_size = get_dist_info() rank = int(os.environ.get('LOCAL_RANK', rank)) if rank == 0: model = modelcloud.get(model_path) with TemporaryDirectory() as tmp_dir: downloaded_file = osp.join(tmp_dir, model.name) model.download(downloaded_file) checkpoint = torch.load(downloaded_file, map_location=map_location) if world_size > 1: torch.distributed.barrier() if rank > 0: model = modelcloud.get(model_path) with TemporaryDirectory() as tmp_dir: downloaded_file = osp.join(tmp_dir, model.name) model.download(downloaded_file) checkpoint = torch.load( downloaded_file, map_location=map_location) return checkpoint def load_fileclient_dist(filename, backend, map_location): """In distributed setting, this function only download checkpoint at local rank 0.""" rank, world_size = get_dist_info() rank = int(os.environ.get('LOCAL_RANK', rank)) allowed_backends = ['ceph'] if backend not in allowed_backends: raise ValueError(f'Load from Backend {backend} is not supported.') if rank == 0: fileclient = FileClient(backend=backend) buffer = io.BytesIO(fileclient.get(filename)) checkpoint = torch.load(buffer, map_location=map_location) if world_size > 1: torch.distributed.barrier() if rank > 0: fileclient = FileClient(backend=backend) buffer = io.BytesIO(fileclient.get(filename)) checkpoint = torch.load(buffer, map_location=map_location) return 
checkpoint def get_torchvision_models(): model_urls = dict() for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): if ispkg: continue _zoo = import_module(f'torchvision.models.{name}') if hasattr(_zoo, 'model_urls'): _urls = getattr(_zoo, 'model_urls') model_urls.update(_urls) return model_urls def get_external_models(): mmcv_home = _get_mmcv_home() default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') default_urls = load_file(default_json_path) assert isinstance(default_urls, dict) external_json_path = osp.join(mmcv_home, 'open_mmlab.json') if osp.exists(external_json_path): external_urls = load_file(external_json_path) assert isinstance(external_urls, dict) default_urls.update(external_urls) return default_urls def get_mmcls_models(): mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') mmcls_urls = load_file(mmcls_json_path) return mmcls_urls def get_deprecated_model_names(): deprecate_json_path = osp.join(mmcv.__path__[0], 'model_zoo/deprecated.json') deprecate_urls = load_file(deprecate_json_path) assert isinstance(deprecate_urls, dict) return deprecate_urls def _process_mmcls_checkpoint(checkpoint): state_dict = checkpoint['state_dict'] new_state_dict = OrderedDict() for k, v in state_dict.items(): if k.startswith('backbone.'): new_state_dict[k[9:]] = v new_checkpoint = dict(state_dict=new_state_dict) return new_checkpoint def _load_checkpoint(filename, map_location=None): """Load checkpoint from somewhere (modelzoo, file, url). Args: filename (str): Accept local filepath, URL, ``torchvision://xxx``, ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for details. map_location (str | None): Same as :func:`torch.load`. Default: None. Returns: dict | OrderedDict: The loaded checkpoint. It can be either an OrderedDict storing model weights or a dict containing other information, which depends on the checkpoint. 
""" if filename.startswith('modelzoo://'): warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' 'use "torchvision://" instead') model_urls = get_torchvision_models() model_name = filename[11:] checkpoint = load_url_dist(model_urls[model_name]) elif filename.startswith('torchvision://'): model_urls = get_torchvision_models() model_name = filename[14:] checkpoint = load_url_dist(model_urls[model_name]) elif filename.startswith('open-mmlab://'): model_urls = get_external_models() model_name = filename[13:] deprecated_urls = get_deprecated_model_names() if model_name in deprecated_urls: warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' f'of open-mmlab://{deprecated_urls[model_name]}') model_name = deprecated_urls[model_name] model_url = model_urls[model_name] # check if is url if model_url.startswith(('http://', 'https://')): checkpoint = load_url_dist(model_url) else: filename = osp.join(_get_mmcv_home(), model_url) if not osp.isfile(filename): raise IOError(f'{filename} is not a checkpoint file') checkpoint = torch.load(filename, map_location=map_location) elif filename.startswith('mmcls://'): model_urls = get_mmcls_models() model_name = filename[8:] checkpoint = load_url_dist(model_urls[model_name]) checkpoint = _process_mmcls_checkpoint(checkpoint) elif filename.startswith(('http://', 'https://')): checkpoint = load_url_dist(filename) elif filename.startswith('pavi://'): model_path = filename[7:] checkpoint = load_pavimodel_dist(model_path, map_location=map_location) elif filename.startswith('s3://'): checkpoint = load_fileclient_dist( filename, backend='ceph', map_location=map_location) else: if not osp.isfile(filename): raise IOError(f'{filename} is not a checkpoint file') checkpoint = torch.load(filename, map_location=map_location) return checkpoint def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0, warmup_steps=-1): warmup_schedule = np.array([]) warmup_iters = warmup_epochs * niter_per_ep if warmup_steps > 0: warmup_iters = warmup_steps print("Set warmup steps = %d" % warmup_iters) if warmup_epochs > 0: warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) iters = np.arange(epochs * niter_per_ep - warmup_iters) schedule = np.array( [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) schedule = np.concatenate((warmup_schedule, schedule)) assert len(schedule) == epochs * niter_per_ep return schedule def load_checkpoint(model, filename, map_location='cpu', strict=False, logger=None): """Load checkpoint from a file or URI. Args: model (Module): Module to load checkpoint. filename (str): Accept local filepath, URL, ``torchvision://xxx``, ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for details. map_location (str): Same as :func:`torch.load`. strict (bool): Whether to allow different params for the model and checkpoint. logger (:mod:`logging.Logger` or None): The logger for error message. Returns: dict or OrderedDict: The loaded checkpoint. 
""" checkpoint = _load_checkpoint(filename, map_location) # OrderedDict is a subclass of dict if not isinstance(checkpoint, dict): raise RuntimeError( f'No state_dict found in checkpoint file {filename}') # get state_dict from checkpoint if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] elif 'model' in checkpoint: state_dict = checkpoint['model'] elif 'module' in checkpoint: state_dict = checkpoint['module'] else: state_dict = checkpoint # strip prefix of state_dict if list(state_dict.keys())[0].startswith('module.'): state_dict = {k[7:]: v for k, v in state_dict.items()} # for MoBY, load model of online branch if sorted(list(state_dict.keys()))[0].startswith('encoder'): state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} # reshape absolute position embedding for Swin if state_dict.get('absolute_pos_embed') is not None: absolute_pos_embed = state_dict['absolute_pos_embed'] N1, L, C1 = absolute_pos_embed.size() N2, C2, H, W = model.absolute_pos_embed.size() if N1 != N2 or C1 != C2 or L != H*W: logger.warning("Error in loading absolute_pos_embed, pass") else: state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) rank, _ = get_dist_info() if "rel_pos_bias.relative_position_bias_table" in state_dict: if rank == 0: print("Expand the shared relative position embedding to each layers. ") num_layers = model.get_num_layers() rel_pos_bias = state_dict["rel_pos_bias.relative_position_bias_table"] for i in range(num_layers): state_dict["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone() state_dict.pop("rel_pos_bias.relative_position_bias_table") all_keys = list(state_dict.keys()) all_keys = sorted(all_keys) print("origin keys:", len(all_keys), all_keys) if all_keys[-2].startswith('encoder_to_decoder'): # NOTE: remove all decoder keys all_keys = [key for key in all_keys if key.startswith('encoder.')] print("all keys:", all_keys) for key in all_keys: new_key = key.replace('encoder.','') # print("new_key:", new_key) state_dict[new_key] = state_dict[key] state_dict.pop(key) for key in list(state_dict.keys()): if key.startswith('decoder.'): # print("key:", key) state_dict.pop(key) # NOTE: replace norm with fc_norm for key in list(state_dict.keys()): # print("new key:", key) if key.startswith('norm.'): new_key = key.replace('norm.','fc_norm.') state_dict[new_key] = state_dict[key] state_dict.pop(key) print("new keys:", len(state_dict), state_dict.keys()) for key in all_keys: if "relative_position_index" in key: state_dict.pop(key) if "relative_position_bias_table" in key: rel_pos_bias = state_dict[key] src_num_pos, num_attn_heads = rel_pos_bias.size() dst_num_pos, _ = model.state_dict()[key].size() dst_patch_shape = model.patch_embed.patch_shape if dst_patch_shape[0] != dst_patch_shape[1]: raise NotImplementedError() num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) src_size = int((src_num_pos - num_extra_tokens) ** 0.5) dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) if src_size != dst_size: if rank == 0: print("Position interpolate for %s from %dx%d to %dx%d" % ( key, src_size, src_size, dst_size, dst_size)) extra_tokens = rel_pos_bias[-num_extra_tokens:, :] rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] def geometric_progression(a, r, n): return a * (1.0 - r ** n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_size // 2) if gp > dst_size // 2: 
right = q else: left = q # if q > 1.13492: # q = 1.13492 dis = [] cur = 1 for i in range(src_size // 2): dis.append(cur) cur += q ** (i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis t = dst_size // 2.0 dx = np.arange(-t, t + 0.1, 1.0) dy = np.arange(-t, t + 0.1, 1.0) if rank == 0: print("x = {}".format(x)) print("dx = {}".format(dx)) all_rel_pos_bias = [] for i in range(num_attn_heads): z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() f = interpolate.interp2d(x, y, z, kind='cubic') all_rel_pos_bias.append( torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) state_dict[key] = new_rel_pos_bias if 'pos_embed' in state_dict: pos_embed_checkpoint = state_dict['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.patch_embed.num_patches num_extra_tokens = model.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) # height (== width) for the new position embedding new_size = int(num_patches ** 0.5) # class_token and dist_token are kept unchanged if orig_size != new_size: if rank == 0: print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) state_dict['pos_embed'] = new_pos_embed # interpolate position bias table if needed relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] for table_key in relative_position_bias_table_keys: table_pretrained = state_dict[table_key] table_current = model.state_dict()[table_key] L1, nH1 = table_pretrained.size() L2, nH2 = table_current.size() if nH1 != nH2: logger.warning(f"Error in loading {table_key}, pass") else: if L1 != L2: S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) table_pretrained_resized = F.interpolate( table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), mode='bicubic') state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0) # load state_dict load_state_dict(model, state_dict, strict, logger) return checkpoint def weights_to_cpu(state_dict): """Copy a model state_dict to cpu. Args: state_dict (OrderedDict): Model weights on GPU. Returns: OrderedDict: Model weights on CPU. """ state_dict_cpu = OrderedDict() for key, val in state_dict.items(): state_dict_cpu[key] = val.cpu() return state_dict_cpu def _save_to_state_dict(module, destination, prefix, keep_vars): """Saves module state to `destination` dictionary. This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. Args: module (nn.Module): The module to generate state_dict. destination (dict): A dict where state will be stored. prefix (str): The prefix for parameters and buffers used in this module. keep_vars (bool): Whether to keep the Variable property of the parameters.
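Example (an illustrative sketch added for this documentation):

    >>> import torch.nn as nn
    >>> dest = {}
    >>> _save_to_state_dict(nn.BatchNorm2d(4), dest, 'bn.', keep_vars=False)
    >>> sorted(dest)  # buffers such as num_batches_tracked are kept here
    ['bn.bias', 'bn.num_batches_tracked', 'bn.running_mean', 'bn.running_var', 'bn.weight']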
""" for name, param in module._parameters.items(): if param is not None: destination[prefix + name] = param if keep_vars else param.detach() for name, buf in module._buffers.items(): # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d if buf is not None: destination[prefix + name] = buf if keep_vars else buf.detach() def get_state_dict(module, destination=None, prefix='', keep_vars=False): """Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. This method is modified from :meth:`torch.nn.Module.state_dict` to recursively check parallel module in case that the model has a complicated structure, e.g., nn.Module(nn.Module(DDP)). Args: module (nn.Module): The module to generate state_dict. destination (OrderedDict): Returned dict for the state of the module. prefix (str): Prefix of the key. keep_vars (bool): Whether to keep the variable property of the parameters. Default: False. Returns: dict: A dictionary containing a whole state of the module. """ # recursively check parallel module in case that the model has a # complicated structure, e.g., nn.Module(nn.Module(DDP)) if is_module_wrapper(module): module = module.module # below is the same as torch.nn.Module.state_dict() if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() destination._metadata[prefix[:-1]] = local_metadata = dict( version=module._version) _save_to_state_dict(module, destination, prefix, keep_vars) for name, child in module._modules.items(): if child is not None: get_state_dict( child, destination, prefix + name + '.', keep_vars=keep_vars) for hook in module._state_dict_hooks.values(): hook_result = hook(module, destination, prefix, local_metadata) if hook_result is not None: destination = hook_result return destination def save_checkpoint(model, filename, optimizer=None, meta=None): """Save checkpoint to file. The checkpoint will have 3 fields: ``meta``, ``state_dict`` and ``optimizer``. By default ``meta`` will contain version and time info. Args: model (Module): Module whose params are to be saved. filename (str): Checkpoint filename. optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. meta (dict, optional): Metadata to be saved in checkpoint. 
""" if meta is None: meta = {} elif not isinstance(meta, dict): raise TypeError(f'meta must be a dict or None, but got {type(meta)}') meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) if is_module_wrapper(model): model = model.module if hasattr(model, 'CLASSES') and model.CLASSES is not None: # save class name to the meta meta.update(CLASSES=model.CLASSES) checkpoint = { 'meta': meta, 'state_dict': weights_to_cpu(get_state_dict(model)) } # save optimizer state dict in the checkpoint if isinstance(optimizer, Optimizer): checkpoint['optimizer'] = optimizer.state_dict() elif isinstance(optimizer, dict): checkpoint['optimizer'] = {} for name, optim in optimizer.items(): checkpoint['optimizer'][name] = optim.state_dict() if filename.startswith('pavi://'): try: from pavi import modelcloud from pavi.exception import NodeNotFoundError except ImportError: raise ImportError( 'Please install pavi to load checkpoint from modelcloud.') model_path = filename[7:] root = modelcloud.Folder() model_dir, model_name = osp.split(model_path) try: model = modelcloud.get(model_dir) except NodeNotFoundError: model = root.create_training_model(model_dir) with TemporaryDirectory() as tmp_dir: checkpoint_file = osp.join(tmp_dir, model_name) with open(checkpoint_file, 'wb') as f: torch.save(checkpoint, f) f.flush() model.create_file(checkpoint_file, name=model_name) else: mmcv.mkdir_or_exist(osp.dirname(filename)) # immediately flush buffer with open(filename, 'wb') as f: torch.save(checkpoint, f) f.flush() ================================================ FILE: downstream_tasks/semantic_segmentation/mmcv_custom/layer_decay_optimizer_constructor.py ================================================ import json from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor from mmcv.runner import get_dist_info def get_num_layer_for_vit(var_name, num_max_layer): if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"): return 0 elif var_name.startswith("backbone.patch_embed"): return 0 elif var_name.startswith("backbone.blocks"): layer_id = int(var_name.split('.')[2]) return layer_id + 1 else: return num_max_layer - 1 @OPTIMIZER_BUILDERS.register_module() class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor): def add_params(self, params, module, prefix='', is_dcn_module=None): """Add all parameters of module to the params list. The parameters of the given module will be added to the list of param groups, with specific rules defined by paramwise_cfg. Args: params (list[dict]): A list of param groups, it will be modified in place. module (nn.Module): The module to be added. prefix (str): The prefix of the module is_dcn_module (int|float|None): If the current module is a submodule of DCN, `is_dcn_module` will be passed to control conv_offset layer's learning rate. Defaults to None. """ parameter_groups = {} print(self.paramwise_cfg) num_layers = self.paramwise_cfg.get('num_layers') + 2 layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate') print("Build LayerDecayOptimizerConstructor %f - %d" % (layer_decay_rate, num_layers)) weight_decay = self.base_wd for name, param in module.named_parameters(): if not param.requires_grad: continue # frozen weights if len(param.shape) == 1 or name.endswith(".bias") or name in ('pos_embed', 'cls_token'): group_name = "no_decay" this_weight_decay = 0. 
else: group_name = "decay" this_weight_decay = weight_decay layer_id = get_num_layer_for_vit(name, num_layers) group_name = "layer_%d_%s" % (layer_id, group_name) if group_name not in parameter_groups: scale = layer_decay_rate ** (num_layers - layer_id - 1) parameter_groups[group_name] = { "weight_decay": this_weight_decay, "params": [], "param_names": [], "lr_scale": scale, "group_name": group_name, "lr": scale * self.base_lr, } parameter_groups[group_name]["params"].append(param) parameter_groups[group_name]["param_names"].append(name) rank, _ = get_dist_info() if rank == 0: to_display = {} for key in parameter_groups: to_display[key] = { "param_names": parameter_groups[key]["param_names"], "lr_scale": parameter_groups[key]["lr_scale"], "lr": parameter_groups[key]["lr"], "weight_decay": parameter_groups[key]["weight_decay"], } print("Param groups = %s" % json.dumps(to_display, indent=2)) # state_dict = module.state_dict() # for group_name in parameter_groups: # group = parameter_groups[group_name] # for name in group["param_names"]: # group["params"].append(state_dict[name]) params.extend(parameter_groups.values()) ================================================ FILE: downstream_tasks/semantic_segmentation/mmcv_custom/resize_transform.py ================================================ import mmcv import numpy as np from mmseg.datasets.builder import PIPELINES @PIPELINES.register_module() class SETR_Resize(object): """Resize images & seg. This transform resizes the input image to some scale. If the input dict contains the key "scale", then the scale in the input dict is used, otherwise the specified scale in the init method is used. ``img_scale`` can either be a tuple (single-scale) or a list of tuple (multi-scale). There are 3 multiscale modes: - ``ratio_range is not None``: randomly sample a ratio from the ratio range and multiply it with the image scale. - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a scale from the a range. - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a scale from multiple scales. Args: img_scale (tuple or list[tuple]): Images scales for resizing. multiscale_mode (str): Either "range" or "value". ratio_range (tuple[float]): (min_ratio, max_ratio) keep_ratio (bool): Whether to keep the aspect ratio when resizing the image. """ def __init__(self, img_scale=None, multiscale_mode='range', ratio_range=None, keep_ratio=True, crop_size=None, setr_multi_scale=False): if img_scale is None: self.img_scale = None else: if isinstance(img_scale, list): self.img_scale = img_scale else: self.img_scale = [img_scale] # assert mmcv.is_list_of(self.img_scale, tuple) if ratio_range is not None: # mode 1: given a scale and a range of image ratio assert len(self.img_scale) == 1 else: # mode 2: given multiple scales or a range of scales assert multiscale_mode in ['value', 'range'] self.multiscale_mode = multiscale_mode self.ratio_range = ratio_range self.keep_ratio = keep_ratio self.crop_size = crop_size self.setr_multi_scale = setr_multi_scale @staticmethod def random_select(img_scales): """Randomly select an img_scale from given candidates. Args: img_scales (list[tuple]): Images scales for selection. Returns: (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, where ``img_scale`` is the selected image scale and ``scale_idx`` is the selected index in the given candidates. 
""" assert mmcv.is_list_of(img_scales, tuple) scale_idx = np.random.randint(len(img_scales)) img_scale = img_scales[scale_idx] return img_scale, scale_idx @staticmethod def random_sample(img_scales): """Randomly sample an img_scale when ``multiscale_mode=='range'``. Args: img_scales (list[tuple]): Images scale range for sampling. There must be two tuples in img_scales, which specify the lower and uper bound of image scales. Returns: (tuple, None): Returns a tuple ``(img_scale, None)``, where ``img_scale`` is sampled scale and None is just a placeholder to be consistent with :func:`random_select`. """ assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 img_scale_long = [max(s) for s in img_scales] img_scale_short = [min(s) for s in img_scales] long_edge = np.random.randint( min(img_scale_long), max(img_scale_long) + 1) short_edge = np.random.randint( min(img_scale_short), max(img_scale_short) + 1) img_scale = (long_edge, short_edge) return img_scale, None @staticmethod def random_sample_ratio(img_scale, ratio_range): """Randomly sample an img_scale when ``ratio_range`` is specified. A ratio will be randomly sampled from the range specified by ``ratio_range``. Then it would be multiplied with ``img_scale`` to generate sampled scale. Args: img_scale (tuple): Images scale base to multiply with ratio. ratio_range (tuple[float]): The minimum and maximum ratio to scale the ``img_scale``. Returns: (tuple, None): Returns a tuple ``(scale, None)``, where ``scale`` is sampled ratio multiplied with ``img_scale`` and None is just a placeholder to be consistent with :func:`random_select`. """ assert isinstance(img_scale, tuple) and len(img_scale) == 2 min_ratio, max_ratio = ratio_range assert min_ratio <= max_ratio ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) return scale, None def _random_scale(self, results): """Randomly sample an img_scale according to ``ratio_range`` and ``multiscale_mode``. If ``ratio_range`` is specified, a ratio will be sampled and be multiplied with ``img_scale``. If multiple scales are specified by ``img_scale``, a scale will be sampled according to ``multiscale_mode``. Otherwise, single scale will be used. Args: results (dict): Result dict from :obj:`dataset`. Returns: dict: Two new keys 'scale` and 'scale_idx` are added into ``results``, which would be used by subsequent pipelines. 
""" if self.ratio_range is not None: scale, scale_idx = self.random_sample_ratio( self.img_scale[0], self.ratio_range) elif len(self.img_scale) == 1: scale, scale_idx = self.img_scale[0], 0 elif self.multiscale_mode == 'range': scale, scale_idx = self.random_sample(self.img_scale) elif self.multiscale_mode == 'value': scale, scale_idx = self.random_select(self.img_scale) else: raise NotImplementedError results['scale'] = scale results['scale_idx'] = scale_idx def _resize_img(self, results): """Resize images with ``results['scale']``.""" if self.keep_ratio: if self.setr_multi_scale: if min(results['scale']) < self.crop_size[0]: new_short = self.crop_size[0] else: new_short = min(results['scale']) h, w = results['img'].shape[:2] if h > w: new_h, new_w = new_short * h / w, new_short else: new_h, new_w = new_short, new_short * w / h results['scale'] = (new_h, new_w) img, scale_factor = mmcv.imrescale( results['img'], results['scale'], return_scale=True) # the w_scale and h_scale has minor difference # a real fix should be done in the mmcv.imrescale in the future new_h, new_w = img.shape[:2] h, w = results['img'].shape[:2] w_scale = new_w / w h_scale = new_h / h else: img, w_scale, h_scale = mmcv.imresize( results['img'], results['scale'], return_scale=True) scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32) results['img'] = img results['img_shape'] = img.shape results['pad_shape'] = img.shape # in case that there is no padding results['scale_factor'] = scale_factor results['keep_ratio'] = self.keep_ratio def _resize_seg(self, results): """Resize semantic segmentation map with ``results['scale']``.""" for key in results.get('seg_fields', []): if self.keep_ratio: gt_seg = mmcv.imrescale( results[key], results['scale'], interpolation='nearest') else: gt_seg = mmcv.imresize( results[key], results['scale'], interpolation='nearest') results['gt_semantic_seg'] = gt_seg def __call__(self, results): """Call function to resize images, bounding boxes, masks, semantic segmentation map. Args: results (dict): Result dict from loading pipeline. Returns: dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', 'keep_ratio' keys are added into result dict. """ if 'scale' not in results: self._random_scale(results) self._resize_img(results) self._resize_seg(results) return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += (f'(img_scale={self.img_scale}, ' f'multiscale_mode={self.multiscale_mode}, ' f'ratio_range={self.ratio_range}, ' f'keep_ratio={self.keep_ratio})') return repr_str ================================================ FILE: downstream_tasks/semantic_segmentation/mmcv_custom/train_api.py ================================================ import random import warnings import numpy as np import torch from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import build_optimizer, build_runner from mmseg.core import DistEvalHook, EvalHook from mmseg.datasets import build_dataloader, build_dataset from mmseg.utils import get_root_logger try: import apex except: print('apex is not installed') def set_random_seed(seed, deterministic=False): """Set random seed. Args: seed (int): Seed to be used. deterministic (bool): Whether to set the deterministic option for CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` to True and `torch.backends.cudnn.benchmark` to False. Default: False. 
""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) if deterministic: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def train_segmentor(model, dataset, cfg, distributed=False, validate=False, timestamp=None, meta=None): """Launch segmentor training.""" logger = get_root_logger(cfg.log_level) # prepare data loaders dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] data_loaders = [ build_dataloader( ds, cfg.data.samples_per_gpu, cfg.data.workers_per_gpu, # cfg.gpus will be ignored if distributed len(cfg.gpu_ids), dist=distributed, seed=cfg.seed, drop_last=True) for ds in dataset ] # build optimizer optimizer = build_optimizer(model, cfg.optimizer) # use apex fp16 optimizer if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook": if cfg.optimizer_config.get("use_fp16", False): model, optimizer = apex.amp.initialize( model.cuda(), optimizer, opt_level="O1") for m in model.modules(): if hasattr(m, "fp16_enabled"): m.fp16_enabled = True # put model on gpus if distributed: find_unused_parameters = cfg.get('find_unused_parameters', False) # Sets the `find_unused_parameters` parameter in # torch.nn.parallel.DistributedDataParallel model = MMDistributedDataParallel( model.cuda(), device_ids=[torch.cuda.current_device()], broadcast_buffers=False, find_unused_parameters=find_unused_parameters) else: model = MMDataParallel( model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) if cfg.get('runner') is None: cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} warnings.warn( 'config is now expected to have a `runner` section, ' 'please set `runner` in your config.', UserWarning) runner = build_runner( cfg.runner, default_args=dict( model=model, batch_processor=None, optimizer=optimizer, work_dir=cfg.work_dir, logger=logger, meta=meta)) # register hooks runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, cfg.checkpoint_config, cfg.log_config, cfg.get('momentum_config', None)) # an ugly walkaround to make the .log and .log.json filenames the same runner.timestamp = timestamp # register eval hooks if validate: val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) val_dataloader = build_dataloader( val_dataset, samples_per_gpu=1, workers_per_gpu=cfg.data.workers_per_gpu, dist=distributed, shuffle=False) eval_cfg = cfg.get('evaluation', {}) eval_cfg['by_epoch'] = 'IterBasedRunner' not in cfg.runner['type'] eval_hook = DistEvalHook if distributed else EvalHook runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) if cfg.resume_from: runner.resume(cfg.resume_from) elif cfg.load_from: runner.load_checkpoint(cfg.load_from) runner.run(data_loaders, cfg.workflow) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/__init__.py ================================================ import mmcv from .version import __version__, version_info MMCV_MIN = '1.1.4' MMCV_MAX = '1.3.0' def digit_version(version_str): digit_version = [] for x in version_str.split('.'): if x.isdigit(): digit_version.append(int(x)) elif x.find('rc') != -1: patch_version = x.split('rc') digit_version.append(int(patch_version[0]) - 1) digit_version.append(int(patch_version[1])) return digit_version mmcv_min_version = digit_version(MMCV_MIN) mmcv_max_version = digit_version(MMCV_MAX) mmcv_version = digit_version(mmcv.__version__) assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ 
f'MMCV=={mmcv.__version__} is used but incompatible. ' \ f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.' __all__ = ['__version__', 'version_info'] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/apis/__init__.py ================================================ from .inference import inference_segmentor, init_segmentor, show_result_pyplot from .test import multi_gpu_test, single_gpu_test from .train import get_root_logger, set_random_seed, train_segmentor __all__ = [ 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot' ] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/apis/inference.py ================================================ import matplotlib.pyplot as plt import mmcv import torch from mmcv.parallel import collate, scatter from mmcv.runner import load_checkpoint from mmseg.datasets.pipelines import Compose from mmseg.models import build_segmentor def init_segmentor(config, checkpoint=None, device='cuda:0'): """Initialize a segmentor from config file. Args: config (str or :obj:`mmcv.Config`): Config file path or the config object. checkpoint (str, optional): Checkpoint path. If left as None, the model will not load any weights. device (str, optional): CPU/CUDA device option. Default: 'cuda:0'. Use 'cpu' for loading model on CPU. Returns: nn.Module: The constructed segmentor. """ if isinstance(config, str): config = mmcv.Config.fromfile(config) elif not isinstance(config, mmcv.Config): raise TypeError('config must be a filename or Config object, ' 'but got {}'.format(type(config))) config.model.pretrained = None config.model.train_cfg = None model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) if checkpoint is not None: checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') model.CLASSES = checkpoint['meta']['CLASSES'] model.PALETTE = checkpoint['meta']['PALETTE'] model.cfg = config # save the config in the model for convenience model.to(device) model.eval() return model class LoadImage: """A simple pipeline to load image.""" def __call__(self, results): """Call function to load images into results. Args: results (dict): A result dict containing the file name of the image to be read. Returns: dict: ``results`` will be returned containing loaded image. """ if isinstance(results['img'], str): results['filename'] = results['img'] results['ori_filename'] = results['img'] else: results['filename'] = None results['ori_filename'] = None img = mmcv.imread(results['img']) results['img'] = img results['img_shape'] = img.shape results['ori_shape'] = img.shape return results def inference_segmentor(model, img): """Inference image(s) with the segmentor. Args: model (nn.Module): The loaded segmentor. img (str/ndarray or list[str/ndarray]): Either image files or loaded images. Returns: (list[Tensor]): The segmentation result.
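Example (an illustrative sketch added for this documentation; the config, checkpoint and image paths are hypothetical):

    >>> model = init_segmentor('configs/my_cfg.py',      # hypothetical config
    ...                        'work_dirs/latest.pth',   # hypothetical checkpoint
    ...                        device='cuda:0')
    >>> result = inference_segmentor(model, 'demo.png')  # hypothetical image
    >>> result[0].shape  # per-pixel class ids with shape (H, W)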
""" cfg = model.cfg device = next(model.parameters()).device # model device # build the data pipeline test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] test_pipeline = Compose(test_pipeline) # prepare data data = dict(img=img) data = test_pipeline(data) data = collate([data], samples_per_gpu=1) if next(model.parameters()).is_cuda: # scatter to specified GPU data = scatter(data, [device])[0] else: data['img_metas'] = [i.data[0] for i in data['img_metas']] # forward the model with torch.no_grad(): result = model(return_loss=False, rescale=True, **data) return result def show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10)): """Visualize the segmentation results on the image. Args: model (nn.Module): The loaded segmentor. img (str or np.ndarray): Image filename or loaded image. result (list): The segmentation result. palette (list[list[int]]] | None): The palette of segmentation map. If None is given, random palette will be generated. Default: None fig_size (tuple): Figure size of the pyplot figure. """ if hasattr(model, 'module'): model = model.module img = model.show_result(img, result, palette=palette, show=False) plt.figure(figsize=fig_size) plt.imshow(mmcv.bgr2rgb(img)) plt.show() ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/apis/test.py ================================================ import os.path as osp import pickle import shutil import tempfile import mmcv import numpy as np import torch import torch.distributed as dist from mmcv.image import tensor2imgs from mmcv.runner import get_dist_info def np2tmp(array, temp_file_name=None): """Save ndarray to local numpy file. Args: array (ndarray): Ndarray to save. temp_file_name (str): Numpy file name. If 'temp_file_name=None', this function will generate a file name with tempfile.NamedTemporaryFile to save ndarray. Default: None. Returns: str: The numpy file name. """ if temp_file_name is None: temp_file_name = tempfile.NamedTemporaryFile( suffix='.npy', delete=False).name np.save(temp_file_name, array) return temp_file_name def single_gpu_test(model, data_loader, show=False, out_dir=None, efficient_test=False): """Test with single GPU. Args: model (nn.Module): Model to be tested. data_loader (utils.data.Dataloader): Pytorch data loader. show (bool): Whether show results during infernece. Default: False. out_dir (str, optional): If specified, the results will be dumped into the directory to save output results. efficient_test (bool): Whether save the results as local numpy files to save CPU memory during evaluation. Default: False. Returns: list: The prediction results. 
""" model.eval() results = [] dataset = data_loader.dataset prog_bar = mmcv.ProgressBar(len(dataset)) for i, data in enumerate(data_loader): with torch.no_grad(): result = model(return_loss=False, **data) if show or out_dir: img_tensor = data['img'][0] img_metas = data['img_metas'][0].data[0] imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) assert len(imgs) == len(img_metas) for img, img_meta in zip(imgs, img_metas): h, w, _ = img_meta['img_shape'] img_show = img[:h, :w, :] ori_h, ori_w = img_meta['ori_shape'][:-1] img_show = mmcv.imresize(img_show, (ori_w, ori_h)) if out_dir: out_file = osp.join(out_dir, img_meta['ori_filename']) else: out_file = None model.module.show_result( img_show, result, palette=dataset.PALETTE, show=show, out_file=out_file) if isinstance(result, list): if efficient_test: result = [np2tmp(_) for _ in result] results.extend(result) else: if efficient_test: result = np2tmp(result) results.append(result) batch_size = data['img'][0].size(0) for _ in range(batch_size): prog_bar.update() return results def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False, efficient_test=False): """Test model with multiple gpus. This method tests model with multiple gpus and collects the results under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' it encodes results to gpu tensors and use gpu communication for results collection. On cpu mode it saves the results on different gpus to 'tmpdir' and collects them by the rank 0 worker. Args: model (nn.Module): Model to be tested. data_loader (utils.data.Dataloader): Pytorch data loader. tmpdir (str): Path of directory to save the temporary results from different gpus under cpu mode. gpu_collect (bool): Option to use either gpu or cpu to collect results. efficient_test (bool): Whether save the results as local numpy files to save CPU memory during evaluation. Default: False. Returns: list: The prediction results. 
""" model.eval() results = [] dataset = data_loader.dataset rank, world_size = get_dist_info() if rank == 0: prog_bar = mmcv.ProgressBar(len(dataset)) for i, data in enumerate(data_loader): with torch.no_grad(): result = model(return_loss=False, rescale=True, **data) if isinstance(result, list): if efficient_test: result = [np2tmp(_) for _ in result] results.extend(result) else: if efficient_test: result = np2tmp(result) results.append(result) if rank == 0: batch_size = data['img'][0].size(0) for _ in range(batch_size * world_size): prog_bar.update() # collect results from all ranks if gpu_collect: results = collect_results_gpu(results, len(dataset)) else: results = collect_results_cpu(results, len(dataset), tmpdir) return results def collect_results_cpu(result_part, size, tmpdir=None): """Collect results with CPU.""" rank, world_size = get_dist_info() # create a tmp dir if it is not specified if tmpdir is None: MAX_LEN = 512 # 32 is whitespace dir_tensor = torch.full((MAX_LEN, ), 32, dtype=torch.uint8, device='cuda') if rank == 0: tmpdir = tempfile.mkdtemp() tmpdir = torch.tensor( bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') dir_tensor[:len(tmpdir)] = tmpdir dist.broadcast(dir_tensor, 0) tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() else: mmcv.mkdir_or_exist(tmpdir) # dump the part result to the dir mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank))) dist.barrier() # collect all parts if rank != 0: return None else: # load results of all parts from tmp dir part_list = [] for i in range(world_size): part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i)) part_list.append(mmcv.load(part_file)) # sort the results ordered_results = [] for res in zip(*part_list): ordered_results.extend(list(res)) # the dataloader may pad some samples ordered_results = ordered_results[:size] # remove tmp dir shutil.rmtree(tmpdir) return ordered_results def collect_results_gpu(result_part, size): """Collect results with GPU.""" rank, world_size = get_dist_info() # dump result part to tensor with pickle part_tensor = torch.tensor( bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') # gather all result part tensor shape shape_tensor = torch.tensor(part_tensor.shape, device='cuda') shape_list = [shape_tensor.clone() for _ in range(world_size)] dist.all_gather(shape_list, shape_tensor) # padding result part tensor to max length shape_max = torch.tensor(shape_list).max() part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') part_send[:shape_tensor[0]] = part_tensor part_recv_list = [ part_tensor.new_zeros(shape_max) for _ in range(world_size) ] # gather all result part dist.all_gather(part_recv_list, part_send) if rank == 0: part_list = [] for recv, shape in zip(part_recv_list, shape_list): part_list.append( pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())) # sort the results ordered_results = [] for res in zip(*part_list): ordered_results.extend(list(res)) # the dataloader may pad some samples ordered_results = ordered_results[:size] return ordered_results ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/apis/train.py ================================================ import random import warnings import numpy as np import torch from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import build_optimizer, build_runner from mmseg.core import DistEvalHook, EvalHook from mmseg.datasets import build_dataloader, build_dataset from mmseg.utils import 
get_root_logger def set_random_seed(seed, deterministic=False): """Set random seed. Args: seed (int): Seed to be used. deterministic (bool): Whether to set the deterministic option for CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` to True and `torch.backends.cudnn.benchmark` to False. Default: False. """ random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) if deterministic: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def train_segmentor(model, dataset, cfg, distributed=False, validate=False, timestamp=None, meta=None): """Launch segmentor training.""" logger = get_root_logger(cfg.log_level) # prepare data loaders dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] data_loaders = [ build_dataloader( ds, cfg.data.samples_per_gpu, cfg.data.workers_per_gpu, # cfg.gpus will be ignored if distributed len(cfg.gpu_ids), dist=distributed, seed=cfg.seed, drop_last=True) for ds in dataset ] # put model on gpus if distributed: find_unused_parameters = cfg.get('find_unused_parameters', False) # Sets the `find_unused_parameters` parameter in # torch.nn.parallel.DistributedDataParallel model = MMDistributedDataParallel( model.cuda(), device_ids=[torch.cuda.current_device()], broadcast_buffers=False, find_unused_parameters=find_unused_parameters) else: model = MMDataParallel( model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) # build runner optimizer = build_optimizer(model, cfg.optimizer) if cfg.get('runner') is None: cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} warnings.warn( 'config is now expected to have a `runner` section, ' 'please set `runner` in your config.', UserWarning) runner = build_runner( cfg.runner, default_args=dict( model=model, batch_processor=None, optimizer=optimizer, work_dir=cfg.work_dir, logger=logger, meta=meta)) # register hooks runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, cfg.checkpoint_config, cfg.log_config, cfg.get('momentum_config', None)) # an ugly walkaround to make the .log and .log.json filenames the same runner.timestamp = timestamp # register eval hooks if validate: val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) val_dataloader = build_dataloader( val_dataset, samples_per_gpu=1, workers_per_gpu=cfg.data.workers_per_gpu, dist=distributed, shuffle=False) eval_cfg = cfg.get('evaluation', {}) eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' eval_hook = DistEvalHook if distributed else EvalHook runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) if cfg.resume_from: runner.resume(cfg.resume_from) elif cfg.load_from: runner.load_checkpoint(cfg.load_from) runner.run(data_loaders, cfg.workflow) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/core/__init__.py ================================================ from .evaluation import * # noqa: F401, F403 from .seg import * # noqa: F401, F403 from .utils import * # noqa: F401, F403 ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/core/evaluation/__init__.py ================================================ from .class_names import get_classes, get_palette from .eval_hooks import DistEvalHook, EvalHook from .metrics import eval_metrics, mean_dice, mean_iou __all__ = [ 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics', 'get_classes', 'get_palette' ] ================================================ FILE: 
downstream_tasks/semantic_segmentation/mmseg/core/evaluation/class_names.py ================================================ import mmcv def cityscapes_classes(): """Cityscapes class names for external use.""" return [ 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle' ] def ade_classes(): """ADE20K class names for external use.""" return [ 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag' ] def voc_classes(): """Pascal VOC class names for external use.""" return [ 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' ] def cityscapes_palette(): """Cityscapes palette for external use.""" return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]] def ade_palette(): """ADE20K palette for external use.""" return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255]] def voc_palette(): """Pascal VOC palette for external use.""" return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] dataset_aliases = { 'cityscapes': ['cityscapes'], 'ade': ['ade', 'ade20k'], 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'] } def get_classes(dataset): """Get class names of a dataset.""" alias2name = {} for name, aliases in dataset_aliases.items(): for alias in aliases: alias2name[alias] = name if mmcv.is_str(dataset): if dataset in alias2name: labels = eval(alias2name[dataset] + '_classes()') else: raise ValueError(f'Unrecognized dataset: {dataset}') else: raise TypeError(f'dataset must a str, but got {type(dataset)}') return labels def get_palette(dataset): """Get class palette (RGB) of a dataset.""" alias2name = {} for name, aliases in dataset_aliases.items(): for alias in aliases: alias2name[alias] = name if mmcv.is_str(dataset): if dataset in alias2name: labels = eval(alias2name[dataset] + '_palette()') else: raise ValueError(f'Unrecognized dataset: {dataset}') else: raise TypeError(f'dataset must a str, but got {type(dataset)}') return labels ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/core/evaluation/eval_hooks.py ================================================ import os.path as osp from mmcv.runner import Hook from torch.utils.data import DataLoader class EvalHook(Hook): """Evaluation hook. Attributes: dataloader (DataLoader): A PyTorch dataloader. interval (int): Evaluation interval (by epochs). Default: 1. 
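
    Example (a sketch; ``runner`` is an existing mmcv runner and
    ``val_loader`` a validation dataloader built with ``dist=False``):
        >>> runner.register_hook(EvalHook(val_loader, interval=4000))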
""" def __init__(self, dataloader, interval=1, by_epoch=False, **eval_kwargs): if not isinstance(dataloader, DataLoader): raise TypeError('dataloader must be a pytorch DataLoader, but got ' f'{type(dataloader)}') self.dataloader = dataloader self.interval = interval self.by_epoch = by_epoch self.eval_kwargs = eval_kwargs def after_train_iter(self, runner): """After train epoch hook.""" if self.by_epoch or not self.every_n_iters(runner, self.interval): return from mmseg.apis import single_gpu_test runner.log_buffer.clear() results = single_gpu_test(runner.model, self.dataloader, show=False) self.evaluate(runner, results) def after_train_epoch(self, runner): """After train epoch hook.""" if not self.by_epoch or not self.every_n_epochs(runner, self.interval): return from mmseg.apis import single_gpu_test runner.log_buffer.clear() results = single_gpu_test(runner.model, self.dataloader, show=False) self.evaluate(runner, results) def evaluate(self, runner, results): """Call evaluate function of dataset.""" eval_res = self.dataloader.dataset.evaluate( results, logger=runner.logger, **self.eval_kwargs) for name, val in eval_res.items(): runner.log_buffer.output[name] = val runner.log_buffer.ready = True class DistEvalHook(EvalHook): """Distributed evaluation hook. Attributes: dataloader (DataLoader): A PyTorch dataloader. interval (int): Evaluation interval (by epochs). Default: 1. tmpdir (str | None): Temporary directory to save the results of all processes. Default: None. gpu_collect (bool): Whether to use gpu or cpu to collect results. Default: False. """ def __init__(self, dataloader, interval=1, gpu_collect=False, by_epoch=False, **eval_kwargs): if not isinstance(dataloader, DataLoader): raise TypeError( 'dataloader must be a pytorch DataLoader, but got {}'.format( type(dataloader))) self.dataloader = dataloader self.interval = interval self.gpu_collect = gpu_collect self.by_epoch = by_epoch self.eval_kwargs = eval_kwargs def after_train_iter(self, runner): """After train epoch hook.""" if self.by_epoch or not self.every_n_iters(runner, self.interval): return from mmseg.apis import multi_gpu_test runner.log_buffer.clear() results = multi_gpu_test( runner.model, self.dataloader, tmpdir=osp.join(runner.work_dir, '.eval_hook'), gpu_collect=self.gpu_collect) if runner.rank == 0: print('\n') self.evaluate(runner, results) def after_train_epoch(self, runner): """After train epoch hook.""" if not self.by_epoch or not self.every_n_epochs(runner, self.interval): return from mmseg.apis import multi_gpu_test runner.log_buffer.clear() results = multi_gpu_test( runner.model, self.dataloader, tmpdir=osp.join(runner.work_dir, '.eval_hook'), gpu_collect=self.gpu_collect) if runner.rank == 0: print('\n') self.evaluate(runner, results) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/core/evaluation/metrics.py ================================================ import mmcv import numpy as np def intersect_and_union(pred_label, label, num_classes, ignore_index, label_map=dict(), reduce_zero_label=False): """Calculate intersection and Union. Args: pred_label (ndarray): Prediction segmentation map. label (ndarray): Ground truth segmentation map. num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. label_map (dict): Mapping old labels to new labels. The parameter will work only when label is str. Default: dict(). reduce_zero_label (bool): Wether ignore zero label. The parameter will work only when label is str. 
            Default: False.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes.
        ndarray: The union of prediction and ground truth histogram on all
            classes.
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """
    if isinstance(pred_label, str):
        pred_label = np.load(pred_label)

    if isinstance(label, str):
        label = mmcv.imread(label, flag='unchanged', backend='pillow')
    # modify if custom classes
    if label_map is not None:
        for old_id, new_id in label_map.items():
            label[label == old_id] = new_id
    if reduce_zero_label:
        # avoid using underflow conversion
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect, _ = np.histogram(
        intersect, bins=np.arange(num_classes + 1))
    area_pred_label, _ = np.histogram(
        pred_label, bins=np.arange(num_classes + 1))
    area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(results,
                              gt_seg_maps,
                              num_classes,
                              ignore_index,
                              label_map=dict(),
                              reduce_zero_label=False):
    """Calculate total intersection and union over a list of images.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes.
        ndarray: The union of prediction and ground truth histogram on all
            classes.
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    # note: `np.float` is a removed alias in recent numpy; use the builtin
    total_area_intersect = np.zeros((num_classes, ), dtype=float)
    total_area_union = np.zeros((num_classes, ), dtype=float)
    total_area_pred_label = np.zeros((num_classes, ), dtype=float)
    total_area_label = np.zeros((num_classes, ), dtype=float)
    for i in range(num_imgs):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(results[i], gt_seg_maps[i], num_classes,
                                ignore_index, label_map, reduce_zero_label)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    return total_area_intersect, total_area_union, \
        total_area_pred_label, total_area_label


def mean_iou(results,
             gt_seg_maps,
             num_classes,
             ignore_index,
             nan_to_num=None,
             label_map=dict(),
             reduce_zero_label=False):
    """Calculate Mean Intersection over Union (mIoU).

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category IoU, shape (num_classes, ).
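
    Example (a toy sketch with a single 2x2 map and two classes):
        >>> import numpy as np
        >>> pred = [np.array([[0, 1], [1, 1]])]
        >>> gt = [np.array([[0, 1], [0, 1]])]
        >>> all_acc, acc, iou = mean_iou(pred, gt, num_classes=2,
        ...                              ignore_index=255)
        >>> # all_acc == 0.75, iou is approximately [0.5, 0.667]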
""" all_acc, acc, iou = eval_metrics( results=results, gt_seg_maps=gt_seg_maps, num_classes=num_classes, ignore_index=ignore_index, metrics=['mIoU'], nan_to_num=nan_to_num, label_map=label_map, reduce_zero_label=reduce_zero_label) return all_acc, acc, iou def mean_dice(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None, label_map=dict(), reduce_zero_label=False): """Calculate Mean Dice (mDice) Args: results (list[ndarray]): List of prediction segmentation maps. gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. nan_to_num (int, optional): If specified, NaN values will be replaced by the numbers defined by the user. Default: None. label_map (dict): Mapping old labels to new labels. Default: dict(). reduce_zero_label (bool): Wether ignore zero label. Default: False. Returns: float: Overall accuracy on all images. ndarray: Per category accuracy, shape (num_classes, ). ndarray: Per category dice, shape (num_classes, ). """ all_acc, acc, dice = eval_metrics( results=results, gt_seg_maps=gt_seg_maps, num_classes=num_classes, ignore_index=ignore_index, metrics=['mDice'], nan_to_num=nan_to_num, label_map=label_map, reduce_zero_label=reduce_zero_label) return all_acc, acc, dice def eval_metrics(results, gt_seg_maps, num_classes, ignore_index, metrics=['mIoU'], nan_to_num=None, label_map=dict(), reduce_zero_label=False): """Calculate evaluation metrics Args: results (list[ndarray]): List of prediction segmentation maps. gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. nan_to_num (int, optional): If specified, NaN values will be replaced by the numbers defined by the user. Default: None. label_map (dict): Mapping old labels to new labels. Default: dict(). reduce_zero_label (bool): Wether ignore zero label. Default: False. Returns: float: Overall accuracy on all images. ndarray: Per category accuracy, shape (num_classes, ). ndarray: Per category evalution metrics, shape (num_classes, ). 
""" if isinstance(metrics, str): metrics = [metrics] allowed_metrics = ['mIoU', 'mDice'] if not set(metrics).issubset(set(allowed_metrics)): raise KeyError('metrics {} is not supported'.format(metrics)) total_area_intersect, total_area_union, total_area_pred_label, \ total_area_label = total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index, label_map, reduce_zero_label) all_acc = total_area_intersect.sum() / total_area_label.sum() acc = total_area_intersect / total_area_label ret_metrics = [all_acc, acc] for metric in metrics: if metric == 'mIoU': iou = total_area_intersect / total_area_union ret_metrics.append(iou) elif metric == 'mDice': dice = 2 * total_area_intersect / ( total_area_pred_label + total_area_label) ret_metrics.append(dice) if nan_to_num is not None: ret_metrics = [ np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics ] return ret_metrics ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/core/seg/__init__.py ================================================ from .builder import build_pixel_sampler from .sampler import BasePixelSampler, OHEMPixelSampler __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/core/seg/builder.py ================================================ from mmcv.utils import Registry, build_from_cfg PIXEL_SAMPLERS = Registry('pixel sampler') def build_pixel_sampler(cfg, **default_args): """Build pixel sampler for segmentation map.""" return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/core/seg/sampler/__init__.py ================================================ from .base_pixel_sampler import BasePixelSampler from .ohem_pixel_sampler import OHEMPixelSampler __all__ = ['BasePixelSampler', 'OHEMPixelSampler'] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/core/seg/sampler/base_pixel_sampler.py ================================================ from abc import ABCMeta, abstractmethod class BasePixelSampler(metaclass=ABCMeta): """Base class of pixel sampler.""" def __init__(self, **kwargs): pass @abstractmethod def sample(self, seg_logit, seg_label): """Placeholder for sample function.""" pass ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/core/seg/sampler/ohem_pixel_sampler.py ================================================ import torch import torch.nn.functional as F from ..builder import PIXEL_SAMPLERS from .base_pixel_sampler import BasePixelSampler @PIXEL_SAMPLERS.register_module() class OHEMPixelSampler(BasePixelSampler): """Online Hard Example Mining Sampler for segmentation. Args: context (nn.Module): The context of sampler, subclass of :obj:`BaseDecodeHead`. thresh (float, optional): The threshold for hard example selection. Below which, are prediction with low confidence. If not specified, the hard examples will be pixels of top ``min_kept`` loss. Default: None. min_kept (int, optional): The minimum number of predictions to keep. Default: 100000. """ def __init__(self, context, thresh=None, min_kept=100000): super(OHEMPixelSampler, self).__init__() self.context = context assert min_kept > 1 self.thresh = thresh self.min_kept = min_kept def sample(self, seg_logit, seg_label): """Sample pixels that have high loss or with low prediction confidence. 
Args: seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) Returns: torch.Tensor: segmentation weight, shape (N, H, W) """ with torch.no_grad(): assert seg_logit.shape[2:] == seg_label.shape[2:] assert seg_label.shape[1] == 1 seg_label = seg_label.squeeze(1).long() batch_kept = self.min_kept * seg_label.size(0) valid_mask = seg_label != self.context.ignore_index seg_weight = seg_logit.new_zeros(size=seg_label.size()) valid_seg_weight = seg_weight[valid_mask] if self.thresh is not None: seg_prob = F.softmax(seg_logit, dim=1) tmp_seg_label = seg_label.clone().unsqueeze(1) tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) sort_prob, sort_indices = seg_prob[valid_mask].sort() if sort_prob.numel() > 0: min_threshold = sort_prob[min(batch_kept, sort_prob.numel() - 1)] else: min_threshold = 0.0 threshold = max(min_threshold, self.thresh) valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. else: losses = self.context.loss_decode( seg_logit, seg_label, weight=None, ignore_index=self.context.ignore_index, reduction_override='none') # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa _, sort_indices = losses[valid_mask].sort(descending=True) valid_seg_weight[sort_indices[:batch_kept]] = 1. seg_weight[valid_mask] = valid_seg_weight return seg_weight ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/core/utils/__init__.py ================================================ from .misc import add_prefix __all__ = ['add_prefix'] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/core/utils/misc.py ================================================ def add_prefix(inputs, prefix): """Add prefix for dict. Args: inputs (dict): The input dict with str keys. prefix (str): The prefix to add. Returns: dict: The dict with keys updated with ``prefix``. """ outputs = dict() for name, value in inputs.items(): outputs[f'{prefix}.{name}'] = value return outputs ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/__init__.py ================================================ from .ade import ADE20KDataset from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset from .chase_db1 import ChaseDB1Dataset from .cityscapes import CityscapesDataset from .custom import CustomDataset from .dataset_wrappers import ConcatDataset, RepeatDataset from .drive import DRIVEDataset from .hrf import HRFDataset from .pascal_context import PascalContextDataset from .stare import STAREDataset from .voc import PascalVOCDataset from .coco_stuff import COCOStuffDataset __all__ = [ 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'COCOStuffDataset', ] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/ade.py ================================================ from .builder import DATASETS from .custom import CustomDataset @DATASETS.register_module() class ADE20KDataset(CustomDataset): """ADE20K dataset. In segmentation map annotation for ADE20K, 0 stands for background, which is not included in 150 categories. 
``reduce_zero_label`` is fixed to True. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to '.png'. """ CLASSES = ( 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag') PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 
0], [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255]] def __init__(self, **kwargs): super(ADE20KDataset, self).__init__( img_suffix='.jpg', seg_map_suffix='.png', reduce_zero_label=True, **kwargs) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/builder.py ================================================ import copy import platform import random from functools import partial import numpy as np from mmcv.parallel import collate from mmcv.runner import get_dist_info from mmcv.utils import Registry, build_from_cfg from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader from torch.utils.data import DistributedSampler if platform.system() != 'Windows': # https://github.com/pytorch/pytorch/issues/973 import resource rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) hard_limit = rlimit[1] soft_limit = min(4096, hard_limit) resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) DATASETS = Registry('dataset') PIPELINES = Registry('pipeline') def _concat_dataset(cfg, default_args=None): """Build :obj:`ConcatDataset by.""" from .dataset_wrappers import ConcatDataset img_dir = cfg['img_dir'] ann_dir = cfg.get('ann_dir', None) split = cfg.get('split', None) num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1 if ann_dir is not None: num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1 else: num_ann_dir = 0 if split is not None: num_split = len(split) if isinstance(split, (list, tuple)) else 1 else: num_split = 0 if num_img_dir > 1: assert num_img_dir == num_ann_dir or num_ann_dir == 0 assert num_img_dir == num_split or num_split == 0 else: assert num_split == num_ann_dir or num_ann_dir <= 1 num_dset = max(num_split, num_img_dir) datasets = [] for i in range(num_dset): data_cfg = copy.deepcopy(cfg) if isinstance(img_dir, (list, tuple)): data_cfg['img_dir'] = img_dir[i] if isinstance(ann_dir, (list, tuple)): data_cfg['ann_dir'] = ann_dir[i] if isinstance(split, (list, tuple)): data_cfg['split'] = split[i] datasets.append(build_dataset(data_cfg, default_args)) return ConcatDataset(datasets) def build_dataset(cfg, default_args=None): """Build datasets.""" from .dataset_wrappers import ConcatDataset, RepeatDataset if isinstance(cfg, (list, tuple)): dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) elif cfg['type'] == 'RepeatDataset': dataset = RepeatDataset( build_dataset(cfg['dataset'], default_args), cfg['times']) elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance( cfg.get('split', None), (list, tuple)): dataset = _concat_dataset(cfg, default_args) else: dataset = build_from_cfg(cfg, DATASETS, default_args) return dataset def build_dataloader(dataset, samples_per_gpu, workers_per_gpu, num_gpus=1, dist=True, shuffle=True, seed=None, drop_last=False, pin_memory=True, dataloader_type='PoolDataLoader', **kwargs): """Build PyTorch DataLoader. In distributed training, each GPU/process has a dataloader. In non-distributed training, there is only one dataloader for all GPUs. Args: dataset (Dataset): A PyTorch dataset. samples_per_gpu (int): Number of training samples on each GPU, i.e., batch size of each GPU. 
workers_per_gpu (int): How many subprocesses to use for data loading for each GPU. num_gpus (int): Number of GPUs. Only used in non-distributed training. dist (bool): Distributed training/test or not. Default: True. shuffle (bool): Whether to shuffle the data at every epoch. Default: True. seed (int | None): Seed to be used. Default: None. drop_last (bool): Whether to drop the last incomplete batch in epoch. Default: False pin_memory (bool): Whether to use pin_memory in DataLoader. Default: True dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader' kwargs: any keyword argument to be used to initialize DataLoader Returns: DataLoader: A PyTorch dataloader. """ rank, world_size = get_dist_info() if dist: sampler = DistributedSampler( dataset, world_size, rank, shuffle=shuffle) shuffle = False batch_size = samples_per_gpu num_workers = workers_per_gpu else: sampler = None batch_size = num_gpus * samples_per_gpu num_workers = num_gpus * workers_per_gpu init_fn = partial( worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None assert dataloader_type in ( 'DataLoader', 'PoolDataLoader'), f'unsupported dataloader {dataloader_type}' if dataloader_type == 'PoolDataLoader': dataloader = PoolDataLoader elif dataloader_type == 'DataLoader': dataloader = DataLoader data_loader = dataloader( dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), pin_memory=pin_memory, shuffle=shuffle, worker_init_fn=init_fn, drop_last=drop_last, **kwargs) return data_loader def worker_init_fn(worker_id, num_workers, rank, seed): """Worker init func for dataloader. The seed of each worker equals to num_worker * rank + worker_id + user_seed Args: worker_id (int): Worker id. num_workers (int): Number of workers. rank (int): The rank of current process. seed (int): The random seed to use. """ worker_seed = num_workers * rank + worker_id + seed np.random.seed(worker_seed) random.seed(worker_seed) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/chase_db1.py ================================================ import os.path as osp from .builder import DATASETS from .custom import CustomDataset @DATASETS.register_module() class ChaseDB1Dataset(CustomDataset): """Chase_db1 dataset. In segmentation map annotation for Chase_db1, 0 stands for background, which is included in 2 categories. ``reduce_zero_label`` is fixed to False. The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to '_1stHO.png'. """ CLASSES = ('background', 'vessel') PALETTE = [[120, 120, 120], [6, 230, 230]] def __init__(self, **kwargs): super(ChaseDB1Dataset, self).__init__( img_suffix='.png', seg_map_suffix='_1stHO.png', reduce_zero_label=False, **kwargs) assert osp.exists(self.img_dir) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/cityscapes.py ================================================ import os.path as osp import tempfile import mmcv import numpy as np from mmcv.utils import print_log from PIL import Image from .builder import DATASETS from .custom import CustomDataset @DATASETS.register_module() class CityscapesDataset(CustomDataset): """Cityscapes dataset. The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. 
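
    Example (a config sketch; ``data/cityscapes`` is a placeholder for a
    locally prepared dataset root, and ``pipeline`` is left empty here):
        >>> dataset = CityscapesDataset(
        ...     pipeline=[],
        ...     data_root='data/cityscapes',
        ...     img_dir='leftImg8bit/train',
        ...     ann_dir='gtFine/train')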
""" CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle') PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]] def __init__(self, **kwargs): super(CityscapesDataset, self).__init__( img_suffix='_leftImg8bit.png', seg_map_suffix='_gtFine_labelTrainIds.png', **kwargs) @staticmethod def _convert_to_label_id(result): """Convert trainId to id for cityscapes.""" if isinstance(result, str): result = np.load(result) import cityscapesscripts.helpers.labels as CSLabels result_copy = result.copy() for trainId, label in CSLabels.trainId2label.items(): result_copy[result == trainId] = label.id return result_copy def results2img(self, results, imgfile_prefix, to_label_id): """Write the segmentation results to images. Args: results (list[list | tuple | ndarray]): Testing results of the dataset. imgfile_prefix (str): The filename prefix of the png files. If the prefix is "somepath/xxx", the png files will be named "somepath/xxx.png". to_label_id (bool): whether convert output to label_id for submission Returns: list[str: str]: result txt files which contains corresponding semantic segmentation images. """ mmcv.mkdir_or_exist(imgfile_prefix) result_files = [] prog_bar = mmcv.ProgressBar(len(self)) for idx in range(len(self)): result = results[idx] if to_label_id: result = self._convert_to_label_id(result) filename = self.img_infos[idx]['filename'] basename = osp.splitext(osp.basename(filename))[0] png_filename = osp.join(imgfile_prefix, f'{basename}.png') output = Image.fromarray(result.astype(np.uint8)).convert('P') import cityscapesscripts.helpers.labels as CSLabels palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8) for label_id, label in CSLabels.id2label.items(): palette[label_id] = label.color output.putpalette(palette) output.save(png_filename) result_files.append(png_filename) prog_bar.update() return result_files def format_results(self, results, imgfile_prefix=None, to_label_id=True): """Format the results into dir (standard format for Cityscapes evaluation). Args: results (list): Testing results of the dataset. imgfile_prefix (str | None): The prefix of images files. It includes the file path and the prefix of filename, e.g., "a/b/prefix". If not specified, a temp file will be created. Default: None. to_label_id (bool): whether convert output to label_id for submission. Default: False Returns: tuple: (result_files, tmp_dir), result_files is a list containing the image paths, tmp_dir is the temporal directory created for saving json/png files when img_prefix is not specified. """ assert isinstance(results, list), 'results must be a list' assert len(results) == len(self), ( 'The length of results is not equal to the dataset len: ' f'{len(results)} != {len(self)}') if imgfile_prefix is None: tmp_dir = tempfile.TemporaryDirectory() imgfile_prefix = tmp_dir.name else: tmp_dir = None result_files = self.results2img(results, imgfile_prefix, to_label_id) return result_files, tmp_dir def evaluate(self, results, metric='mIoU', logger=None, imgfile_prefix=None, efficient_test=False): """Evaluation in Cityscapes/default protocol. Args: results (list): Testing results of the dataset. 
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            imgfile_prefix (str | None): The prefix of output image file,
                for cityscapes evaluation only. It includes the file path and
                the prefix of filename, e.g., "a/b/prefix".
                If results are evaluated with cityscapes protocol, it would be
                the prefix of output png files. The output files will be png
                images under folder "a/b/prefix/xxx.png", where "xxx" is the
                image name of cityscapes. If not specified, a temp file will
                be created for evaluation. Default: None.

        Returns:
            dict[str, float]: Cityscapes/default metrics.
        """
        eval_results = dict()
        metrics = metric.copy() if isinstance(metric, list) else [metric]
        if 'cityscapes' in metrics:
            eval_results.update(
                self._evaluate_cityscapes(results, logger, imgfile_prefix))
            metrics.remove('cityscapes')
        if len(metrics) > 0:
            eval_results.update(
                super(CityscapesDataset,
                      self).evaluate(results, metrics, logger, efficient_test))
        return eval_results

    def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
        """Evaluation in Cityscapes protocol.

        Args:
            results (list): Testing results of the dataset.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            imgfile_prefix (str | None): The prefix of output image file

        Returns:
            dict[str: float]: Cityscapes evaluation results.
        """
        try:
            import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa
        except ImportError:
            raise ImportError('Please run "pip install cityscapesscripts" to '
                              'install cityscapesscripts first.')
        msg = 'Evaluating in Cityscapes style'
        if logger is None:
            msg = '\n' + msg
        print_log(msg, logger=logger)

        result_files, tmp_dir = self.format_results(results, imgfile_prefix)

        if tmp_dir is None:
            result_dir = imgfile_prefix
        else:
            result_dir = tmp_dir.name

        eval_results = dict()
        print_log(f'Evaluating results under {result_dir} ...', logger=logger)

        CSEval.args.evalInstLevelScore = True
        CSEval.args.predictionPath = osp.abspath(result_dir)
        CSEval.args.evalPixelAccuracy = True
        CSEval.args.JSONOutput = False

        seg_map_list = []
        pred_list = []

        # when evaluating with official cityscapesscripts,
        # **_gtFine_labelIds.png is used
        for seg_map in mmcv.scandir(
                self.ann_dir, 'gtFine_labelIds.png', recursive=True):
            seg_map_list.append(osp.join(self.ann_dir, seg_map))
            pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))

        eval_results.update(
            CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))

        if tmp_dir is not None:
            tmp_dir.cleanup()

        return eval_results


================================================
FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/coco_stuff.py
================================================
from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class COCOStuffDataset(CustomDataset):
    """COCO-Stuff dataset.

    In segmentation map annotation for COCO-Stuff, Train-IDs of the 10k
    version are from 1 to 171, where 0 is the ignore index, and Train-IDs of
    the 164k version are from 0 to 170, where 255 is the ignore index. In
    both cases there are 171 semantic categories, so ``reduce_zero_label`` is
    set to True and False for the 10k and 164k versions, respectively. The
    ``img_suffix`` is fixed to '.jpg', and ``seg_map_suffix`` is fixed to
    '_labelTrainIds.png'.
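
    Example (a config sketch; the ``data/coco_stuff10k`` paths are
    placeholders for a locally prepared dataset):
        >>> dataset = COCOStuffDataset(
        ...     pipeline=[],
        ...     data_root='data/coco_stuff10k',
        ...     img_dir='images/train2014',
        ...     ann_dir='annotations/train2014',
        ...     reduce_zero_label=True)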
""" CLASSES = ( 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform', 'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other', 'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel', 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', 'window-blind', 'window-other', 'wood') PALETTE = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192], [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192], [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128], [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0], [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128], [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224], [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128], [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192], [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224], [0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0], [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192], [0, 64, 96], [0, 160, 128], [192, 0, 
64], [128, 64, 224], [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128], [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128], [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160], [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64], [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128], [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160], [0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192], [192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192], [0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160], [64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64], [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192], [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160], [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192], [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128], [64, 192, 96], [64, 160, 64], [64, 64, 0]] def __init__(self, **kwargs): super(COCOStuffDataset, self).__init__( img_suffix='.jpg', seg_map_suffix='_labelTrainIds.png', **kwargs) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/custom.py ================================================ import os import os.path as osp from functools import reduce import mmcv import numpy as np from mmcv.utils import print_log from terminaltables import AsciiTable from torch.utils.data import Dataset from mmseg.core import eval_metrics from mmseg.utils import get_root_logger from .builder import DATASETS from .pipelines import Compose @DATASETS.register_module() class CustomDataset(Dataset): """Custom dataset for semantic segmentation. An example of file structure is as followed. .. code-block:: none ├── data │ ├── my_dataset │ │ ├── img_dir │ │ │ ├── train │ │ │ │ ├── xxx{img_suffix} │ │ │ │ ├── yyy{img_suffix} │ │ │ │ ├── zzz{img_suffix} │ │ │ ├── val │ │ ├── ann_dir │ │ │ ├── train │ │ │ │ ├── xxx{seg_map_suffix} │ │ │ │ ├── yyy{seg_map_suffix} │ │ │ │ ├── zzz{seg_map_suffix} │ │ │ ├── val The img/gt_semantic_seg pair of CustomDataset should be of the same except suffix. A valid img/gt_semantic_seg filename pair should be like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included in the suffix). If split is given, then ``xxx`` is specified in txt file. Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded. Please refer to ``docs/tutorials/new_dataset.md`` for more details. Args: pipeline (list[dict]): Processing pipeline img_dir (str): Path to image directory img_suffix (str): Suffix of images. Default: '.jpg' ann_dir (str, optional): Path to annotation directory. Default: None seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' split (str, optional): Split txt file. If split is specified, only file with suffix in the splits will be loaded. Otherwise, all images in img_dir/ann_dir will be loaded. Default: None data_root (str, optional): Data root for img_dir/ann_dir. Default: None. test_mode (bool): If test_mode=True, gt wouldn't be loaded. ignore_index (int): The label index to be ignored. Default: 255 reduce_zero_label (bool): Whether to mark label zero as ignored. Default: False classes (str | Sequence[str], optional): Specify classes to load. If is None, ``cls.CLASSES`` will be used. Default: None. palette (Sequence[Sequence[int]]] | np.ndarray | None): The palette of segmentation map. If None is given, and self.PALETTE is None, random palette will be generated. 
Default: None """ CLASSES = None PALETTE = None def __init__(self, pipeline, img_dir, img_suffix='.jpg', ann_dir=None, seg_map_suffix='.png', split=None, data_root=None, test_mode=False, ignore_index=255, reduce_zero_label=False, classes=None, palette=None): self.pipeline = Compose(pipeline) self.img_dir = img_dir self.img_suffix = img_suffix self.ann_dir = ann_dir self.seg_map_suffix = seg_map_suffix self.split = split self.data_root = data_root self.test_mode = test_mode self.ignore_index = ignore_index self.reduce_zero_label = reduce_zero_label self.label_map = None self.CLASSES, self.PALETTE = self.get_classes_and_palette( classes, palette) # join paths if data_root is specified if self.data_root is not None: if not osp.isabs(self.img_dir): self.img_dir = osp.join(self.data_root, self.img_dir) if not (self.ann_dir is None or osp.isabs(self.ann_dir)): self.ann_dir = osp.join(self.data_root, self.ann_dir) if not (self.split is None or osp.isabs(self.split)): self.split = osp.join(self.data_root, self.split) # load annotations self.img_infos = self.load_annotations(self.img_dir, self.img_suffix, self.ann_dir, self.seg_map_suffix, self.split) def __len__(self): """Total number of samples of data.""" return len(self.img_infos) def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix, split): """Load annotation from directory. Args: img_dir (str): Path to image directory img_suffix (str): Suffix of images. ann_dir (str|None): Path to annotation directory. seg_map_suffix (str|None): Suffix of segmentation maps. split (str|None): Split txt file. If split is specified, only file with suffix in the splits will be loaded. Otherwise, all images in img_dir/ann_dir will be loaded. Default: None Returns: list[dict]: All image info of dataset. """ img_infos = [] if split is not None: with open(split) as f: for line in f: img_name = line.strip() img_info = dict(filename=img_name + img_suffix) if ann_dir is not None: seg_map = img_name + seg_map_suffix img_info['ann'] = dict(seg_map=seg_map) img_infos.append(img_info) else: for img in mmcv.scandir(img_dir, img_suffix, recursive=True): img_info = dict(filename=img) if ann_dir is not None: seg_map = img.replace(img_suffix, seg_map_suffix) img_info['ann'] = dict(seg_map=seg_map) img_infos.append(img_info) print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) return img_infos def get_ann_info(self, idx): """Get annotation by index. Args: idx (int): Index of data. Returns: dict: Annotation info of specified index. """ return self.img_infos[idx]['ann'] def pre_pipeline(self, results): """Prepare results dict for pipeline.""" results['seg_fields'] = [] results['img_prefix'] = self.img_dir results['seg_prefix'] = self.ann_dir if self.custom_classes: results['label_map'] = self.label_map def __getitem__(self, idx): """Get training/test data after pipeline. Args: idx (int): Index of data. Returns: dict: Training/test data (with annotation if `test_mode` is set False). """ if self.test_mode: return self.prepare_test_img(idx) else: return self.prepare_train_img(idx) def prepare_train_img(self, idx): """Get training data and annotations after pipeline. Args: idx (int): Index of data. Returns: dict: Training data and annotation after pipeline with new keys introduced by pipeline. """ img_info = self.img_infos[idx] ann_info = self.get_ann_info(idx) results = dict(img_info=img_info, ann_info=ann_info) self.pre_pipeline(results) return self.pipeline(results) def prepare_test_img(self, idx): """Get testing data after pipeline. 
Args: idx (int): Index of data. Returns: dict: Testing data after pipeline with new keys introduced by the pipeline. """ img_info = self.img_infos[idx] results = dict(img_info=img_info) self.pre_pipeline(results) return self.pipeline(results) def format_results(self, results, **kwargs): """Placeholder to format result to dataset specific output.""" pass def get_gt_seg_maps(self, efficient_test=False): """Get ground truth segmentation maps for evaluation.""" gt_seg_maps = [] for img_info in self.img_infos: seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map']) if efficient_test: gt_seg_map = seg_map else: gt_seg_map = mmcv.imread( seg_map, flag='unchanged', backend='pillow') gt_seg_maps.append(gt_seg_map) return gt_seg_maps def get_classes_and_palette(self, classes=None, palette=None): """Get class names of current dataset. Args: classes (Sequence[str] | str | None): If classes is None, use default CLASSES defined by builtin dataset. If classes is a string, take it as a file name. The file contains the name of classes where each line contains one class name. If classes is a tuple or list, override the CLASSES defined by the dataset. palette (Sequence[Sequence[int]] | np.ndarray | None): The palette of segmentation map. If None is given, a random palette will be generated. Default: None """ if classes is None: self.custom_classes = False return self.CLASSES, self.PALETTE self.custom_classes = True if isinstance(classes, str): # take it as a file path class_names = mmcv.list_from_file(classes) elif isinstance(classes, (tuple, list)): class_names = classes else: raise ValueError(f'Unsupported type {type(classes)} of classes.') if self.CLASSES: # compare against the resolved ``class_names`` so that ``classes`` # given as a file path is handled correctly if not set(class_names).issubset(self.CLASSES): raise ValueError('classes is not a subset of CLASSES.') # dictionary, its keys are the old label ids and its values # are the new label ids. # used for changing pixel labels in load_annotations. self.label_map = {} for i, c in enumerate(self.CLASSES): if c not in class_names: self.label_map[i] = -1 else: self.label_map[i] = class_names.index(c) palette = self.get_palette_for_custom_classes(class_names, palette) return class_names, palette def get_palette_for_custom_classes(self, class_names, palette=None): if self.label_map is not None: # return subset of palette palette = [] for old_id, new_id in sorted( self.label_map.items(), key=lambda x: x[1]): if new_id != -1: palette.append(self.PALETTE[old_id]) palette = type(self.PALETTE)(palette) elif palette is None: if self.PALETTE is None: palette = np.random.randint(0, 255, size=(len(class_names), 3)) else: palette = self.PALETTE return palette def evaluate(self, results, metric='mIoU', logger=None, efficient_test=False, **kwargs): """Evaluate the dataset. Args: results (list): Testing results of the dataset. metric (str | list[str]): Metrics to be evaluated. 'mIoU' and 'mDice' are supported. logger (logging.Logger | None | str): Logger used for printing related information during evaluation. Default: None. Returns: dict[str, float]: Default metrics.
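Example (an illustrative sketch; how ``results`` is produced is an assumption here, e.g. via ``mmseg.apis.single_gpu_test``):

.. code-block:: python

    # ``results`` holds one predicted label map per image, in dataset order.
    eval_results = dataset.evaluate(results, metric=['mIoU', 'mDice'])
    # Summary keys follow the metric names: 'mIoU', 'mDice', 'mAcc', 'aAcc'.
    print(eval_results['mIoU'], eval_results['aAcc'])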
""" if isinstance(metric, str): metric = [metric] allowed_metrics = ['mIoU', 'mDice'] if not set(metric).issubset(set(allowed_metrics)): raise KeyError('metric {} is not supported'.format(metric)) eval_results = {} gt_seg_maps = self.get_gt_seg_maps(efficient_test) if self.CLASSES is None: num_classes = len( reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) else: num_classes = len(self.CLASSES) ret_metrics = eval_metrics( results, gt_seg_maps, num_classes, self.ignore_index, metric, label_map=self.label_map, reduce_zero_label=self.reduce_zero_label) class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']] if self.CLASSES is None: class_names = tuple(range(num_classes)) else: class_names = self.CLASSES ret_metrics_round = [ np.round(ret_metric * 100, 2) for ret_metric in ret_metrics ] for i in range(num_classes): class_table_data.append([class_names[i]] + [m[i] for m in ret_metrics_round[2:]] + [ret_metrics_round[1][i]]) summary_table_data = [['Scope'] + ['m' + head for head in class_table_data[0][1:]] + ['aAcc']] ret_metrics_mean = [ np.round(np.nanmean(ret_metric) * 100, 2) for ret_metric in ret_metrics ] summary_table_data.append(['global'] + ret_metrics_mean[2:] + [ret_metrics_mean[1]] + [ret_metrics_mean[0]]) print_log('per class results:', logger) table = AsciiTable(class_table_data) print_log('\n' + table.table, logger=logger) print_log('Summary:', logger) table = AsciiTable(summary_table_data) print_log('\n' + table.table, logger=logger) for i in range(1, len(summary_table_data[0])): eval_results[summary_table_data[0] [i]] = summary_table_data[1][i] / 100.0 if mmcv.is_list_of(results, str): for file_name in results: os.remove(file_name) return eval_results ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/dataset_wrappers.py ================================================ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset from .builder import DATASETS @DATASETS.register_module() class ConcatDataset(_ConcatDataset): """A wrapper of concatenated dataset. Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but concat the group flag for image aspect ratio. Args: datasets (list[:obj:`Dataset`]): A list of datasets. """ def __init__(self, datasets): super(ConcatDataset, self).__init__(datasets) self.CLASSES = datasets[0].CLASSES self.PALETTE = datasets[0].PALETTE @DATASETS.register_module() class RepeatDataset(object): """A wrapper of repeated dataset. The length of repeated dataset will be `times` larger than the original dataset. This is useful when the data loading time is long but the dataset is small. Using RepeatDataset can reduce the data loading time between epochs. Args: dataset (:obj:`Dataset`): The dataset to be repeated. times (int): Repeat times. """ def __init__(self, dataset, times): self.dataset = dataset self.times = times self.CLASSES = dataset.CLASSES self.PALETTE = dataset.PALETTE self._ori_len = len(self.dataset) def __getitem__(self, idx): """Get item from original dataset.""" return self.dataset[idx % self._ori_len] def __len__(self): """The length is multiplied by ``times``""" return self.times * self._ori_len ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/drive.py ================================================ import os.path as osp from .builder import DATASETS from .custom import CustomDataset @DATASETS.register_module() class DRIVEDataset(CustomDataset): """DRIVE dataset. 
In segmentation map annotation for DRIVE, 0 stands for background, which is included in 2 categories. ``reduce_zero_label`` is fixed to False. The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to '_manual1.png'. """ CLASSES = ('background', 'vessel') PALETTE = [[120, 120, 120], [6, 230, 230]] def __init__(self, **kwargs): super(DRIVEDataset, self).__init__( img_suffix='.png', seg_map_suffix='_manual1.png', reduce_zero_label=False, **kwargs) assert osp.exists(self.img_dir) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/hrf.py ================================================ import os.path as osp from .builder import DATASETS from .custom import CustomDataset @DATASETS.register_module() class HRFDataset(CustomDataset): """HRF dataset. In segmentation map annotation for HRF, 0 stands for background, which is included in 2 categories. ``reduce_zero_label`` is fixed to False. The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'. """ CLASSES = ('background', 'vessel') PALETTE = [[120, 120, 120], [6, 230, 230]] def __init__(self, **kwargs): super(HRFDataset, self).__init__( img_suffix='.png', seg_map_suffix='.png', reduce_zero_label=False, **kwargs) assert osp.exists(self.img_dir) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pascal_context.py ================================================ import os.path as osp from .builder import DATASETS from .custom import CustomDataset @DATASETS.register_module() class PascalContextDataset(CustomDataset): """PascalContext dataset. In segmentation map annotation for PascalContext, 0 stands for background, which is included in 60 categories. ``reduce_zero_label`` is fixed to False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to '.png'. Args: split (str): Split txt file for PascalContext. 
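Example (an illustrative sketch; the data paths and the minimal pipeline below are assumptions, not values taken from this repository's configs):

.. code-block:: python

    dataset = PascalContextDataset(
        pipeline=[dict(type='LoadImageFromFile'),
                  dict(type='LoadAnnotations')],
        img_dir='data/VOCdevkit/VOC2010/JPEGImages',
        ann_dir='data/VOCdevkit/VOC2010/SegmentationClassContext',
        split='data/VOCdevkit/VOC2010/ImageSets/SegmentationContext/train.txt')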
""" CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'table', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor', 'bag', 'bed', 'bench', 'book', 'building', 'cabinet', 'ceiling', 'cloth', 'computer', 'cup', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'keyboard', 'light', 'mountain', 'mouse', 'curtain', 'platform', 'sign', 'plate', 'road', 'rock', 'shelves', 'sidewalk', 'sky', 'snow', 'bedclothes', 'track', 'tree', 'truck', 'wall', 'water', 'window', 'wood') PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]] def __init__(self, split, **kwargs): super(PascalContextDataset, self).__init__( img_suffix='.jpg', seg_map_suffix='.png', split=split, reduce_zero_label=False, **kwargs) assert osp.exists(self.img_dir) and self.split is not None ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/__init__.py ================================================ from .compose import Compose from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor, Transpose, to_tensor) from .loading import LoadAnnotations, LoadImageFromFile from .test_time_aug import MultiScaleFlipAug from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, PhotoMetricDistortion, RandomCrop, RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) __all__ = [ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray' ] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/compose.py ================================================ import collections from mmcv.utils import build_from_cfg from ..builder import PIPELINES @PIPELINES.register_module() class Compose(object): """Compose multiple transforms sequentially. Args: transforms (Sequence[dict | callable]): Sequence of transform object or config dict to be composed. """ def __init__(self, transforms): assert isinstance(transforms, collections.abc.Sequence) self.transforms = [] for transform in transforms: if isinstance(transform, dict): transform = build_from_cfg(transform, PIPELINES) self.transforms.append(transform) elif callable(transform): self.transforms.append(transform) else: raise TypeError('transform must be callable or a dict') def __call__(self, data): """Call function to apply transforms sequentially. 
Args: data (dict): A result dict contains the data to transform. Returns: dict: Transformed data. """ for t in self.transforms: data = t(data) if data is None: return None return data def __repr__(self): format_string = self.__class__.__name__ + '(' for t in self.transforms: format_string += '\n' format_string += f' {t}' format_string += '\n)' return format_string ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/formating.py ================================================ from collections.abc import Sequence import mmcv import numpy as np import torch from mmcv.parallel import DataContainer as DC from ..builder import PIPELINES def to_tensor(data): """Convert objects of various python types to :obj:`torch.Tensor`. Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`Sequence`, :class:`int` and :class:`float`. Args: data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to be converted. """ if isinstance(data, torch.Tensor): return data elif isinstance(data, np.ndarray): return torch.from_numpy(data) elif isinstance(data, Sequence) and not mmcv.is_str(data): return torch.tensor(data) elif isinstance(data, int): return torch.LongTensor([data]) elif isinstance(data, float): return torch.FloatTensor([data]) else: raise TypeError(f'type {type(data)} cannot be converted to tensor.') @PIPELINES.register_module() class ToTensor(object): """Convert some results to :obj:`torch.Tensor` by given keys. Args: keys (Sequence[str]): Keys that need to be converted to Tensor. """ def __init__(self, keys): self.keys = keys def __call__(self, results): """Call function to convert data in results to :obj:`torch.Tensor`. Args: results (dict): Result dict contains the data to convert. Returns: dict: The result dict contains the data converted to :obj:`torch.Tensor`. """ for key in self.keys: results[key] = to_tensor(results[key]) return results def __repr__(self): return self.__class__.__name__ + f'(keys={self.keys})' @PIPELINES.register_module() class ImageToTensor(object): """Convert image to :obj:`torch.Tensor` by given keys. The dimension order of input image is (H, W, C). The pipeline will convert it to (C, H, W). If only 2 dimension (H, W) is given, the output would be (1, H, W). Args: keys (Sequence[str]): Key of images to be converted to Tensor. """ def __init__(self, keys): self.keys = keys def __call__(self, results): """Call function to convert image in results to :obj:`torch.Tensor` and transpose the channel order. Args: results (dict): Result dict contains the image data to convert. Returns: dict: The result dict contains the image converted to :obj:`torch.Tensor` and transposed to (C, H, W) order. """ for key in self.keys: img = results[key] if len(img.shape) < 3: img = np.expand_dims(img, -1) results[key] = to_tensor(img.transpose(2, 0, 1)) return results def __repr__(self): return self.__class__.__name__ + f'(keys={self.keys})' @PIPELINES.register_module() class Transpose(object): """Transpose some results by given keys. Args: keys (Sequence[str]): Keys of results to be transposed. order (Sequence[int]): Order of transpose. """ def __init__(self, keys, order): self.keys = keys self.order = order def __call__(self, results): """Call function to convert image in results to :obj:`torch.Tensor` and transpose the channel order. Args: results (dict): Result dict contains the image data to convert. 
Returns: dict: The result dict contains the image converted to :obj:`torch.Tensor` and transposed to (C, H, W) order. """ for key in self.keys: results[key] = results[key].transpose(self.order) return results def __repr__(self): return self.__class__.__name__ + \ f'(keys={self.keys}, order={self.order})' @PIPELINES.register_module() class ToDataContainer(object): """Convert results to :obj:`mmcv.DataContainer` by given fields. Args: fields (Sequence[dict]): Each field is a dict like ``dict(key='xxx', **kwargs)``. The ``key`` in result will be converted to :obj:`mmcv.DataContainer` with ``**kwargs``. Default: ``(dict(key='img', stack=True), dict(key='gt_semantic_seg'))``. """ def __init__(self, fields=(dict(key='img', stack=True), dict(key='gt_semantic_seg'))): self.fields = fields def __call__(self, results): """Call function to convert data in results to :obj:`mmcv.DataContainer`. Args: results (dict): Result dict contains the data to convert. Returns: dict: The result dict contains the data converted to :obj:`mmcv.DataContainer`. """ for field in self.fields: field = field.copy() key = field.pop('key') results[key] = DC(results[key], **field) return results def __repr__(self): return self.__class__.__name__ + f'(fields={self.fields})' @PIPELINES.register_module() class DefaultFormatBundle(object): """Default formatting bundle. It simplifies the pipeline of formatting common fields, including "img" and "gt_semantic_seg". These fields are formatted as follows. - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, (3)to DataContainer (stack=True) """ def __call__(self, results): """Call function to transform and format common fields in results. Args: results (dict): Result dict contains the data to convert. Returns: dict: The result dict contains the data that is formatted with default bundle. """ if 'img' in results: img = results['img'] if len(img.shape) < 3: img = np.expand_dims(img, -1) img = np.ascontiguousarray(img.transpose(2, 0, 1)) results['img'] = DC(to_tensor(img), stack=True) if 'gt_semantic_seg' in results: # convert to long results['gt_semantic_seg'] = DC( to_tensor(results['gt_semantic_seg'][None, ...].astype(np.int64)), stack=True) return results def __repr__(self): return self.__class__.__name__ @PIPELINES.register_module() class Collect(object): """Collect data from the loader relevant to the specific task. This is usually the last stage of the data loader pipeline. Typically keys is set to some subset of "img", "gt_semantic_seg". The "img_meta" item is always populated. The contents of the "img_meta" dictionary depends on "meta_keys". By default this includes: - "img_shape": shape of the image input to the network as a tuple (h, w, c). Note that images may be zero padded on the bottom/right if the batch tensor is larger than this shape. - "scale_factor": a float indicating the preprocessing scale - "flip": a boolean indicating if image flip transform was used - "filename": path to the image file - "ori_shape": original shape of the image as a tuple (h, w, c) - "pad_shape": image shape after padding - "img_norm_cfg": a dict of normalization information: - mean - per channel mean subtraction - std - per channel std divisor - to_rgb - bool indicating if bgr was converted to rgb Args: keys (Sequence[str]): Keys of results to be collected in ``data``. meta_keys (Sequence[str], optional): Meta keys to be converted to ``mmcv.DataContainer`` and collected in ``data[img_metas]``. 
Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg')`` """ def __init__(self, keys, meta_keys=('filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg')): self.keys = keys self.meta_keys = meta_keys def __call__(self, results): """Call function to collect keys in results. The keys in ``meta_keys`` will be converted to :obj:mmcv.DataContainer. Args: results (dict): Result dict contains the data to collect. Returns: dict: The result dict contains the following keys - keys in``self.keys`` - ``img_metas`` """ data = {} img_meta = {} for key in self.meta_keys: img_meta[key] = results[key] data['img_metas'] = DC(img_meta, cpu_only=True) for key in self.keys: data[key] = results[key] return data def __repr__(self): return self.__class__.__name__ + \ f'(keys={self.keys}, meta_keys={self.meta_keys})' ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/loading.py ================================================ import os.path as osp import mmcv import numpy as np from ..builder import PIPELINES @PIPELINES.register_module() class LoadImageFromFile(object): """Load an image from file. Required keys are "img_prefix" and "img_info" (a dict that must contain the key "filename"). Added or updated keys are "filename", "img", "img_shape", "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). Args: to_float32 (bool): Whether to convert the loaded image to a float32 numpy array. If set to False, the loaded image is an uint8 array. Defaults to False. color_type (str): The flag argument for :func:`mmcv.imfrombytes`. Defaults to 'color'. file_client_args (dict): Arguments to instantiate a FileClient. See :class:`mmcv.fileio.FileClient` for details. Defaults to ``dict(backend='disk')``. imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: 'cv2' """ def __init__(self, to_float32=False, color_type='color', file_client_args=dict(backend='disk'), imdecode_backend='cv2'): self.to_float32 = to_float32 self.color_type = color_type self.file_client_args = file_client_args.copy() self.file_client = None self.imdecode_backend = imdecode_backend def __call__(self, results): """Call functions to load image and get image meta information. Args: results (dict): Result dict from :obj:`mmseg.CustomDataset`. Returns: dict: The dict contains loaded image and meta information. 
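Example (an illustrative sketch; the directory and file name are hypothetical):

.. code-block:: python

    load = LoadImageFromFile(to_float32=True)
    results = load(dict(img_prefix='data/imgs',
                        img_info=dict(filename='demo.jpg')))
    # 'img' is now a float32 HxWxC array; 'img_shape', 'ori_shape',
    # 'pad_shape', 'scale_factor' and 'img_norm_cfg' are initialized too.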
""" if self.file_client is None: self.file_client = mmcv.FileClient(**self.file_client_args) if results.get('img_prefix') is not None: filename = osp.join(results['img_prefix'], results['img_info']['filename']) else: filename = results['img_info']['filename'] img_bytes = self.file_client.get(filename) img = mmcv.imfrombytes( img_bytes, flag=self.color_type, backend=self.imdecode_backend) if self.to_float32: img = img.astype(np.float32) results['filename'] = filename results['ori_filename'] = results['img_info']['filename'] results['img'] = img results['img_shape'] = img.shape results['ori_shape'] = img.shape # Set initial values for default meta_keys results['pad_shape'] = img.shape results['scale_factor'] = 1.0 num_channels = 1 if len(img.shape) < 3 else img.shape[2] results['img_norm_cfg'] = dict( mean=np.zeros(num_channels, dtype=np.float32), std=np.ones(num_channels, dtype=np.float32), to_rgb=False) return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(to_float32={self.to_float32},' repr_str += f"color_type='{self.color_type}'," repr_str += f"imdecode_backend='{self.imdecode_backend}')" return repr_str @PIPELINES.register_module() class LoadAnnotations(object): """Load annotations for semantic segmentation. Args: reduce_zero_label (bool): Whether reduce all label value by 1. Usually used for datasets where 0 is background label. Default: False. file_client_args (dict): Arguments to instantiate a FileClient. See :class:`mmcv.fileio.FileClient` for details. Defaults to ``dict(backend='disk')``. imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: 'pillow' """ def __init__(self, reduce_zero_label=False, file_client_args=dict(backend='disk'), imdecode_backend='pillow'): self.reduce_zero_label = reduce_zero_label self.file_client_args = file_client_args.copy() self.file_client = None self.imdecode_backend = imdecode_backend def __call__(self, results): """Call function to load multiple types annotations. Args: results (dict): Result dict from :obj:`mmseg.CustomDataset`. Returns: dict: The dict contains loaded semantic segmentation annotations. 
""" if self.file_client is None: self.file_client = mmcv.FileClient(**self.file_client_args) if results.get('seg_prefix', None) is not None: filename = osp.join(results['seg_prefix'], results['ann_info']['seg_map']) else: filename = results['ann_info']['seg_map'] img_bytes = self.file_client.get(filename) gt_semantic_seg = mmcv.imfrombytes( img_bytes, flag='unchanged', backend=self.imdecode_backend).squeeze().astype(np.uint8) # modify if custom classes if results.get('label_map', None) is not None: for old_id, new_id in results['label_map'].items(): gt_semantic_seg[gt_semantic_seg == old_id] = new_id # reduce zero_label if self.reduce_zero_label: # avoid using underflow conversion gt_semantic_seg[gt_semantic_seg == 0] = 255 gt_semantic_seg = gt_semantic_seg - 1 gt_semantic_seg[gt_semantic_seg == 254] = 255 results['gt_semantic_seg'] = gt_semantic_seg results['seg_fields'].append('gt_semantic_seg') return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(reduce_zero_label={self.reduce_zero_label},' repr_str += f"imdecode_backend='{self.imdecode_backend}')" return repr_str ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/test_time_aug.py ================================================ import warnings import mmcv from ..builder import PIPELINES from .compose import Compose @PIPELINES.register_module() class MultiScaleFlipAug(object): """Test-time augmentation with multiple scales and flipping. An example configuration is as followed: .. code-block:: img_scale=(2048, 1024), img_ratios=[0.5, 1.0], flip=True, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size_divisor=32), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ] After MultiScaleFLipAug with above configuration, the results are wrapped into lists of the same length as followed: .. code-block:: dict( img=[...], img_shape=[...], scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)] flip=[False, True, False, True] ... ) Args: transforms (list[dict]): Transforms to apply in each augmentation. img_scale (None | tuple | list[tuple]): Images scales for resizing. img_ratios (float | list[float]): Image ratios for resizing flip (bool): Whether apply flip augmentation. Default: False. flip_direction (str | list[str]): Flip augmentation directions, options are "horizontal" and "vertical". If flip_direction is list, multiple flip augmentations will be applied. It has no effect when flip == False. Default: "horizontal". 
""" def __init__(self, transforms, img_scale, img_ratios=None, flip=False, flip_direction='horizontal'): self.transforms = Compose(transforms) if img_ratios is not None: img_ratios = img_ratios if isinstance(img_ratios, list) else [img_ratios] assert mmcv.is_list_of(img_ratios, float) if img_scale is None: # mode 1: given img_scale=None and a range of image ratio self.img_scale = None assert mmcv.is_list_of(img_ratios, float) elif isinstance(img_scale, tuple) and mmcv.is_list_of( img_ratios, float): assert len(img_scale) == 2 # mode 2: given a scale and a range of image ratio self.img_scale = [(int(img_scale[0] * ratio), int(img_scale[1] * ratio)) for ratio in img_ratios] else: # mode 3: given multiple scales self.img_scale = img_scale if isinstance(img_scale, list) else [img_scale] assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None self.flip = flip self.img_ratios = img_ratios self.flip_direction = flip_direction if isinstance( flip_direction, list) else [flip_direction] assert mmcv.is_list_of(self.flip_direction, str) if not self.flip and self.flip_direction != ['horizontal']: warnings.warn( 'flip_direction has no effect when flip is set to False') if (self.flip and not any([t['type'] == 'RandomFlip' for t in transforms])): warnings.warn( 'flip has no effect when RandomFlip is not in transforms') def __call__(self, results): """Call function to apply test time augment transforms on results. Args: results (dict): Result dict contains the data to transform. Returns: dict[str: list]: The augmented data, where each value is wrapped into a list. """ aug_data = [] if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): h, w = results['img'].shape[:2] img_scale = [(int(w * ratio), int(h * ratio)) for ratio in self.img_ratios] else: img_scale = self.img_scale flip_aug = [False, True] if self.flip else [False] for scale in img_scale: for flip in flip_aug: for direction in self.flip_direction: _results = results.copy() _results['scale'] = scale _results['flip'] = flip _results['flip_direction'] = direction data = self.transforms(_results) aug_data.append(data) # list of dict to dict of list aug_data_dict = {key: [] for key in aug_data[0]} for data in aug_data: for key, val in data.items(): aug_data_dict[key].append(val) return aug_data_dict def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(transforms={self.transforms}, ' repr_str += f'img_scale={self.img_scale}, flip={self.flip})' repr_str += f'flip_direction={self.flip_direction}' return repr_str ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/pipelines/transforms.py ================================================ import mmcv import numpy as np from mmcv.utils import deprecated_api_warning, is_tuple_of from numpy import random from ..builder import PIPELINES @PIPELINES.register_module() class Resize(object): """Resize images & seg. This transform resizes the input image to some scale. If the input dict contains the key "scale", then the scale in the input dict is used, otherwise the specified scale in the init method is used. ``img_scale`` can be Nong, a tuple (single-scale) or a list of tuple (multi-scale). There are 4 multiscale modes: - ``ratio_range is not None``: 1. When img_scale is None, img_scale is the shape of image in results (img_scale = results['img'].shape[:2]) and the image is resized based on the original size. (mode 1) 2. 
When img_scale is a tuple (single-scale), randomly sample a ratio from the ratio range and multiply it with the image scale. (mode 2) - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a scale from a range. (mode 3) - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a scale from multiple scales. (mode 4) Args: img_scale (tuple or list[tuple]): Image scales for resizing. multiscale_mode (str): Either "range" or "value". ratio_range (tuple[float]): (min_ratio, max_ratio) keep_ratio (bool): Whether to keep the aspect ratio when resizing the image. """ def __init__(self, img_scale=None, multiscale_mode='range', ratio_range=None, keep_ratio=True): if img_scale is None: self.img_scale = None else: if isinstance(img_scale, list): self.img_scale = img_scale else: self.img_scale = [img_scale] assert mmcv.is_list_of(self.img_scale, tuple) if ratio_range is not None: # mode 1: given img_scale=None and a range of image ratio # mode 2: given a scale and a range of image ratio assert self.img_scale is None or len(self.img_scale) == 1 else: # mode 3 and 4: given multiple scales or a range of scales assert multiscale_mode in ['value', 'range'] self.multiscale_mode = multiscale_mode self.ratio_range = ratio_range self.keep_ratio = keep_ratio @staticmethod def random_select(img_scales): """Randomly select an img_scale from given candidates. Args: img_scales (list[tuple]): Image scales for selection. Returns: (tuple, int): Returns a tuple ``(img_scale, scale_idx)``, where ``img_scale`` is the selected image scale and ``scale_idx`` is the selected index in the given candidates. """ assert mmcv.is_list_of(img_scales, tuple) scale_idx = np.random.randint(len(img_scales)) img_scale = img_scales[scale_idx] return img_scale, scale_idx @staticmethod def random_sample(img_scales): """Randomly sample an img_scale when ``multiscale_mode=='range'``. Args: img_scales (list[tuple]): Image scale range for sampling. There must be two tuples in img_scales, which specify the lower and upper bound of image scales. Returns: (tuple, None): Returns a tuple ``(img_scale, None)``, where ``img_scale`` is the sampled scale and None is just a placeholder to be consistent with :func:`random_select`. """ assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 img_scale_long = [max(s) for s in img_scales] img_scale_short = [min(s) for s in img_scales] long_edge = np.random.randint( min(img_scale_long), max(img_scale_long) + 1) short_edge = np.random.randint( min(img_scale_short), max(img_scale_short) + 1) img_scale = (long_edge, short_edge) return img_scale, None @staticmethod def random_sample_ratio(img_scale, ratio_range): """Randomly sample an img_scale when ``ratio_range`` is specified. A ratio will be randomly sampled from the range specified by ``ratio_range``. Then it would be multiplied with ``img_scale`` to generate the sampled scale. Args: img_scale (tuple): Image scale base to multiply with ratio. ratio_range (tuple[float]): The minimum and maximum ratio to scale the ``img_scale``. Returns: (tuple, None): Returns a tuple ``(scale, None)``, where ``scale`` is the sampled ratio multiplied with ``img_scale`` and None is just a placeholder to be consistent with :func:`random_select`.
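Example (a minimal sketch; the base scale and ratio range are hypothetical):

.. code-block:: python

    # A ratio r is drawn uniformly from [0.5, 2.0); the sampled scale is
    # (int(2048 * r), int(512 * r)), e.g. r = 1.5 yields (3072, 768).
    scale, _ = Resize.random_sample_ratio((2048, 512), (0.5, 2.0))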
""" assert isinstance(img_scale, tuple) and len(img_scale) == 2 min_ratio, max_ratio = ratio_range assert min_ratio <= max_ratio ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) return scale, None def _random_scale(self, results): """Randomly sample an img_scale according to ``ratio_range`` and ``multiscale_mode``. If ``ratio_range`` is specified, a ratio will be sampled and be multiplied with ``img_scale``. If multiple scales are specified by ``img_scale``, a scale will be sampled according to ``multiscale_mode``. Otherwise, single scale will be used. Args: results (dict): Result dict from :obj:`dataset`. Returns: dict: Two new keys 'scale` and 'scale_idx` are added into ``results``, which would be used by subsequent pipelines. """ if self.ratio_range is not None: if self.img_scale is None: h, w = results['img'].shape[:2] scale, scale_idx = self.random_sample_ratio((w, h), self.ratio_range) else: scale, scale_idx = self.random_sample_ratio( self.img_scale[0], self.ratio_range) elif len(self.img_scale) == 1: scale, scale_idx = self.img_scale[0], 0 elif self.multiscale_mode == 'range': scale, scale_idx = self.random_sample(self.img_scale) elif self.multiscale_mode == 'value': scale, scale_idx = self.random_select(self.img_scale) else: raise NotImplementedError results['scale'] = scale results['scale_idx'] = scale_idx def _resize_img(self, results): """Resize images with ``results['scale']``.""" if self.keep_ratio: img, scale_factor = mmcv.imrescale( results['img'], results['scale'], return_scale=True) # the w_scale and h_scale has minor difference # a real fix should be done in the mmcv.imrescale in the future new_h, new_w = img.shape[:2] h, w = results['img'].shape[:2] w_scale = new_w / w h_scale = new_h / h else: img, w_scale, h_scale = mmcv.imresize( results['img'], results['scale'], return_scale=True) scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32) results['img'] = img results['img_shape'] = img.shape results['pad_shape'] = img.shape # in case that there is no padding results['scale_factor'] = scale_factor results['keep_ratio'] = self.keep_ratio def _resize_seg(self, results): """Resize semantic segmentation map with ``results['scale']``.""" for key in results.get('seg_fields', []): if self.keep_ratio: gt_seg = mmcv.imrescale( results[key], results['scale'], interpolation='nearest') else: gt_seg = mmcv.imresize( results[key], results['scale'], interpolation='nearest') results[key] = gt_seg def __call__(self, results): """Call function to resize images, bounding boxes, masks, semantic segmentation map. Args: results (dict): Result dict from loading pipeline. Returns: dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', 'keep_ratio' keys are added into result dict. """ if 'scale' not in results: self._random_scale(results) self._resize_img(results) self._resize_seg(results) return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += (f'(img_scale={self.img_scale}, ' f'multiscale_mode={self.multiscale_mode}, ' f'ratio_range={self.ratio_range}, ' f'keep_ratio={self.keep_ratio})') return repr_str @PIPELINES.register_module() class RandomFlip(object): """Flip the image & seg. If the input dict contains the key "flip", then the flag will be used, otherwise it will be randomly decided by a ratio specified in the init method. Args: prob (float, optional): The flipping probability. Default: None. direction(str, optional): The flipping direction. 
Options are 'horizontal' and 'vertical'. Default: 'horizontal'. """ @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip') def __init__(self, prob=None, direction='horizontal'): self.prob = prob self.direction = direction if prob is not None: assert prob >= 0 and prob <= 1 assert direction in ['horizontal', 'vertical'] def __call__(self, results): """Call function to flip bounding boxes, masks, semantic segmentation maps. Args: results (dict): Result dict from loading pipeline. Returns: dict: Flipped results, 'flip', 'flip_direction' keys are added into result dict. """ if 'flip' not in results: flip = True if np.random.rand() < self.prob else False results['flip'] = flip if 'flip_direction' not in results: results['flip_direction'] = self.direction if results['flip']: # flip image results['img'] = mmcv.imflip( results['img'], direction=results['flip_direction']) # flip segs for key in results.get('seg_fields', []): # use copy() to make numpy stride positive results[key] = mmcv.imflip( results[key], direction=results['flip_direction']).copy() return results def __repr__(self): return self.__class__.__name__ + f'(prob={self.prob})' @PIPELINES.register_module() class Pad(object): """Pad the image & mask. There are two padding modes: (1) pad to a fixed size and (2) pad to the minimum size that is divisible by some number. Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor", Args: size (tuple, optional): Fixed padding size. size_divisor (int, optional): The divisor of padded size. pad_val (float, optional): Padding value. Default: 0. seg_pad_val (float, optional): Padding value of segmentation map. Default: 255. """ def __init__(self, size=None, size_divisor=None, pad_val=0, seg_pad_val=255): self.size = size self.size_divisor = size_divisor self.pad_val = pad_val self.seg_pad_val = seg_pad_val # only one of size and size_divisor should be valid assert size is not None or size_divisor is not None assert size is None or size_divisor is None def _pad_img(self, results): """Pad images according to ``self.size``.""" if self.size is not None: padded_img = mmcv.impad( results['img'], shape=self.size, pad_val=self.pad_val) elif self.size_divisor is not None: padded_img = mmcv.impad_to_multiple( results['img'], self.size_divisor, pad_val=self.pad_val) results['img'] = padded_img results['pad_shape'] = padded_img.shape results['pad_fixed_size'] = self.size results['pad_size_divisor'] = self.size_divisor def _pad_seg(self, results): """Pad masks according to ``results['pad_shape']``.""" for key in results.get('seg_fields', []): results[key] = mmcv.impad( results[key], shape=results['pad_shape'][:2], pad_val=self.seg_pad_val) def __call__(self, results): """Call function to pad images, masks, semantic segmentation maps. Args: results (dict): Result dict from loading pipeline. Returns: dict: Updated result dict. """ self._pad_img(results) self._pad_seg(results) return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \ f'pad_val={self.pad_val})' return repr_str @PIPELINES.register_module() class Normalize(object): """Normalize the image. Added key is "img_norm_cfg". Args: mean (sequence): Mean values of 3 channels. std (sequence): Std values of 3 channels. to_rgb (bool): Whether to convert the image from BGR to RGB, default is true. 
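Example (an illustrative sketch using the commonly used ImageNet statistics; whether a given config in this repository uses exactly these values is an assumption):

.. code-block:: python

    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    normalize = dict(type='Normalize', **img_norm_cfg)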
""" def __init__(self, mean, std, to_rgb=True): self.mean = np.array(mean, dtype=np.float32) self.std = np.array(std, dtype=np.float32) self.to_rgb = to_rgb def __call__(self, results): """Call function to normalize images. Args: results (dict): Result dict from loading pipeline. Returns: dict: Normalized results, 'img_norm_cfg' key is added into result dict. """ results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std, self.to_rgb) results['img_norm_cfg'] = dict( mean=self.mean, std=self.std, to_rgb=self.to_rgb) return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \ f'{self.to_rgb})' return repr_str @PIPELINES.register_module() class Rerange(object): """Rerange the image pixel value. Args: min_value (float or int): Minimum value of the reranged image. Default: 0. max_value (float or int): Maximum value of the reranged image. Default: 255. """ def __init__(self, min_value=0, max_value=255): assert isinstance(min_value, float) or isinstance(min_value, int) assert isinstance(max_value, float) or isinstance(max_value, int) assert min_value < max_value self.min_value = min_value self.max_value = max_value def __call__(self, results): """Call function to rerange images. Args: results (dict): Result dict from loading pipeline. Returns: dict: Reranged results. """ img = results['img'] img_min_value = np.min(img) img_max_value = np.max(img) assert img_min_value < img_max_value # rerange to [0, 1] img = (img - img_min_value) / (img_max_value - img_min_value) # rerange to [min_value, max_value] img = img * (self.max_value - self.min_value) + self.min_value results['img'] = img return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(min_value={self.min_value}, max_value={self.max_value})' return repr_str @PIPELINES.register_module() class CLAHE(object): """Use CLAHE method to process the image. See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. Graphics Gems, 1994:474-485.` for more information. Args: clip_limit (float): Threshold for contrast limiting. Default: 40.0. tile_grid_size (tuple[int]): Size of grid for histogram equalization. Input image will be divided into equally sized rectangular tiles. It defines the number of tiles in row and column. Default: (8, 8). """ def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)): assert isinstance(clip_limit, (float, int)) self.clip_limit = clip_limit assert is_tuple_of(tile_grid_size, int) assert len(tile_grid_size) == 2 self.tile_grid_size = tile_grid_size def __call__(self, results): """Call function to Use CLAHE method process images. Args: results (dict): Result dict from loading pipeline. Returns: dict: Processed results. """ for i in range(results['img'].shape[2]): results['img'][:, :, i] = mmcv.clahe( np.array(results['img'][:, :, i], dtype=np.uint8), self.clip_limit, self.tile_grid_size) return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(clip_limit={self.clip_limit}, '\ f'tile_grid_size={self.tile_grid_size})' return repr_str @PIPELINES.register_module() class RandomCrop(object): """Random crop the image & seg. Args: crop_size (tuple): Expected size after cropping, (h, w). cat_max_ratio (float): The maximum ratio that single category could occupy. 
""" def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255): assert crop_size[0] > 0 and crop_size[1] > 0 self.crop_size = crop_size self.cat_max_ratio = cat_max_ratio self.ignore_index = ignore_index def get_crop_bbox(self, img): """Randomly get a crop bounding box.""" margin_h = max(img.shape[0] - self.crop_size[0], 0) margin_w = max(img.shape[1] - self.crop_size[1], 0) offset_h = np.random.randint(0, margin_h + 1) offset_w = np.random.randint(0, margin_w + 1) crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] return crop_y1, crop_y2, crop_x1, crop_x2 def crop(self, img, crop_bbox): """Crop from ``img``""" crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] return img def __call__(self, results): """Call function to randomly crop images, semantic segmentation maps. Args: results (dict): Result dict from loading pipeline. Returns: dict: Randomly cropped results, 'img_shape' key in result dict is updated according to crop size. """ img = results['img'] crop_bbox = self.get_crop_bbox(img) if self.cat_max_ratio < 1.: # Repeat 10 times for _ in range(10): seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox) labels, cnt = np.unique(seg_temp, return_counts=True) cnt = cnt[labels != self.ignore_index] if len(cnt) > 1 and np.max(cnt) / np.sum( cnt) < self.cat_max_ratio: break crop_bbox = self.get_crop_bbox(img) # crop the image img = self.crop(img, crop_bbox) img_shape = img.shape results['img'] = img results['img_shape'] = img_shape # crop semantic seg for key in results.get('seg_fields', []): results[key] = self.crop(results[key], crop_bbox) return results def __repr__(self): return self.__class__.__name__ + f'(crop_size={self.crop_size})' @PIPELINES.register_module() class RandomRotate(object): """Rotate the image & seg. Args: prob (float): The rotation probability. degree (float, tuple[float]): Range of degrees to select from. If degree is a number instead of tuple like (min, max), the range of degree will be (``-degree``, ``+degree``) pad_val (float, optional): Padding value of image. Default: 0. seg_pad_val (float, optional): Padding value of segmentation map. Default: 255. center (tuple[float], optional): Center point (w, h) of the rotation in the source image. If not specified, the center of the image will be used. Default: None. auto_bound (bool): Whether to adjust the image size to cover the whole rotated image. Default: False """ def __init__(self, prob, degree, pad_val=0, seg_pad_val=255, center=None, auto_bound=False): self.prob = prob assert prob >= 0 and prob <= 1 if isinstance(degree, (float, int)): assert degree > 0, f'degree {degree} should be positive' self.degree = (-degree, degree) else: self.degree = degree assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ f'tuple of (min, max)' self.pal_val = pad_val self.seg_pad_val = seg_pad_val self.center = center self.auto_bound = auto_bound def __call__(self, results): """Call function to rotate image, semantic segmentation maps. Args: results (dict): Result dict from loading pipeline. Returns: dict: Rotated results. 
""" rotate = True if np.random.rand() < self.prob else False degree = np.random.uniform(min(*self.degree), max(*self.degree)) if rotate: # rotate image results['img'] = mmcv.imrotate( results['img'], angle=degree, border_value=self.pal_val, center=self.center, auto_bound=self.auto_bound) # rotate segs for key in results.get('seg_fields', []): results[key] = mmcv.imrotate( results[key], angle=degree, border_value=self.seg_pad_val, center=self.center, auto_bound=self.auto_bound, interpolation='nearest') return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(prob={self.prob}, ' \ f'degree={self.degree}, ' \ f'pad_val={self.pal_val}, ' \ f'seg_pad_val={self.seg_pad_val}, ' \ f'center={self.center}, ' \ f'auto_bound={self.auto_bound})' return repr_str @PIPELINES.register_module() class RGB2Gray(object): """Convert RGB image to grayscale image. This transform calculate the weighted mean of input image channels with ``weights`` and then expand the channels to ``out_channels``. When ``out_channels`` is None, the number of output channels is the same as input channels. Args: out_channels (int): Expected number of output channels after transforming. Default: None. weights (tuple[float]): The weights to calculate the weighted mean. Default: (0.299, 0.587, 0.114). """ def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)): assert out_channels is None or out_channels > 0 self.out_channels = out_channels assert isinstance(weights, tuple) for item in weights: assert isinstance(item, (float, int)) self.weights = weights def __call__(self, results): """Call function to convert RGB image to grayscale image. Args: results (dict): Result dict from loading pipeline. Returns: dict: Result dict with grayscale image. """ img = results['img'] assert len(img.shape) == 3 assert img.shape[2] == len(self.weights) weights = np.array(self.weights).reshape((1, 1, -1)) img = (img * weights).sum(2, keepdims=True) if self.out_channels is None: img = img.repeat(weights.shape[2], axis=2) else: img = img.repeat(self.out_channels, axis=2) results['img'] = img results['img_shape'] = img.shape return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(out_channels={self.out_channels}, ' \ f'weights={self.weights})' return repr_str @PIPELINES.register_module() class AdjustGamma(object): """Using gamma correction to process the image. Args: gamma (float or int): Gamma value used in gamma correction. Default: 1.0. """ def __init__(self, gamma=1.0): assert isinstance(gamma, float) or isinstance(gamma, int) assert gamma > 0 self.gamma = gamma inv_gamma = 1.0 / gamma self.table = np.array([(i / 255.0)**inv_gamma * 255 for i in np.arange(256)]).astype('uint8') def __call__(self, results): """Call function to process the image with gamma correction. Args: results (dict): Result dict from loading pipeline. Returns: dict: Processed results. """ results['img'] = mmcv.lut_transform( np.array(results['img'], dtype=np.uint8), self.table) return results def __repr__(self): return self.__class__.__name__ + f'(gamma={self.gamma})' @PIPELINES.register_module() class SegRescale(object): """Rescale semantic segmentation maps. Args: scale_factor (float): The scale factor of the final output. """ def __init__(self, scale_factor=1): self.scale_factor = scale_factor def __call__(self, results): """Call function to scale the semantic segmentation map. Args: results (dict): Result dict from loading pipeline. Returns: dict: Result dict with semantic segmentation map scaled. 
""" for key in results.get('seg_fields', []): if self.scale_factor != 1: results[key] = mmcv.imrescale( results[key], self.scale_factor, interpolation='nearest') return results def __repr__(self): return self.__class__.__name__ + f'(scale_factor={self.scale_factor})' @PIPELINES.register_module() class PhotoMetricDistortion(object): """Apply photometric distortion to image sequentially, every transformation is applied with a probability of 0.5. The position of random contrast is in second or second to last. 1. random brightness 2. random contrast (mode 0) 3. convert color from BGR to HSV 4. random saturation 5. random hue 6. convert color from HSV to BGR 7. random contrast (mode 1) 8. randomly swap channels Args: brightness_delta (int): delta of brightness. contrast_range (tuple): range of contrast. saturation_range (tuple): range of saturation. hue_delta (int): delta of hue. """ def __init__(self, brightness_delta=32, contrast_range=(0.5, 1.5), saturation_range=(0.5, 1.5), hue_delta=18): self.brightness_delta = brightness_delta self.contrast_lower, self.contrast_upper = contrast_range self.saturation_lower, self.saturation_upper = saturation_range self.hue_delta = hue_delta def convert(self, img, alpha=1, beta=0): """Multiple with alpha and add beat with clip.""" img = img.astype(np.float32) * alpha + beta img = np.clip(img, 0, 255) return img.astype(np.uint8) def brightness(self, img): """Brightness distortion.""" if random.randint(2): return self.convert( img, beta=random.uniform(-self.brightness_delta, self.brightness_delta)) return img def contrast(self, img): """Contrast distortion.""" if random.randint(2): return self.convert( img, alpha=random.uniform(self.contrast_lower, self.contrast_upper)) return img def saturation(self, img): """Saturation distortion.""" if random.randint(2): img = mmcv.bgr2hsv(img) img[:, :, 1] = self.convert( img[:, :, 1], alpha=random.uniform(self.saturation_lower, self.saturation_upper)) img = mmcv.hsv2bgr(img) return img def hue(self, img): """Hue distortion.""" if random.randint(2): img = mmcv.bgr2hsv(img) img[:, :, 0] = (img[:, :, 0].astype(int) + random.randint(-self.hue_delta, self.hue_delta)) % 180 img = mmcv.hsv2bgr(img) return img def __call__(self, results): """Call function to perform photometric distortion on images. Args: results (dict): Result dict from loading pipeline. Returns: dict: Result dict with images distorted. """ img = results['img'] # random brightness img = self.brightness(img) # mode == 0 --> do random contrast first # mode == 1 --> do random contrast last mode = random.randint(2) if mode == 1: img = self.contrast(img) # random saturation img = self.saturation(img) # random hue img = self.hue(img) # random contrast if mode == 0: img = self.contrast(img) results['img'] = img return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += (f'(brightness_delta={self.brightness_delta}, ' f'contrast_range=({self.contrast_lower}, ' f'{self.contrast_upper}), ' f'saturation_range=({self.saturation_lower}, ' f'{self.saturation_upper}), ' f'hue_delta={self.hue_delta})') return repr_str ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/stare.py ================================================ import os.path as osp from .builder import DATASETS from .custom import CustomDataset @DATASETS.register_module() class STAREDataset(CustomDataset): """STARE dataset. In segmentation map annotation for STARE, 0 stands for background, which is included in 2 categories. 
``reduce_zero_label`` is fixed to False. The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to '.ah.png'. """ CLASSES = ('background', 'vessel') PALETTE = [[120, 120, 120], [6, 230, 230]] def __init__(self, **kwargs): super(STAREDataset, self).__init__( img_suffix='.png', seg_map_suffix='.ah.png', reduce_zero_label=False, **kwargs) assert osp.exists(self.img_dir) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/datasets/voc.py ================================================ import os.path as osp from .builder import DATASETS from .custom import CustomDataset @DATASETS.register_module() class PascalVOCDataset(CustomDataset): """Pascal VOC dataset. Args: split (str): Split txt file for Pascal VOC. """ CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor') PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] def __init__(self, split, **kwargs): super(PascalVOCDataset, self).__init__( img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) assert osp.exists(self.img_dir) and self.split is not None ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/__init__.py ================================================ from .backbones import * # noqa: F401,F403 from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, build_head, build_loss, build_segmentor) from .decode_heads import * # noqa: F401,F403 from .losses import * # noqa: F401,F403 from .necks import * # noqa: F401,F403 from .segmentors import * # noqa: F401,F403 __all__ = [ 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', 'build_head', 'build_loss', 'build_segmentor' ] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/__init__.py ================================================ from .cgnet import CGNet from .fast_scnn import FastSCNN from .hrnet import HRNet from .mobilenet_v2 import MobileNetV2 from .mobilenet_v3 import MobileNetV3 from .resnest import ResNeSt from .resnet import ResNet, ResNetV1c, ResNetV1d from .resnext import ResNeXt from .unet import UNet __all__ = [ 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3' ] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/cgnet.py ================================================ import torch import torch.nn as nn import torch.utils.checkpoint as cp from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer, constant_init, kaiming_init) from mmcv.runner import load_checkpoint from mmcv.utils.parrots_wrapper import _BatchNorm from mmseg.utils import get_root_logger from ..builder import BACKBONES class GlobalContextExtractor(nn.Module): """Global Context Extractor for CGNet. This class is employed to refine the joint feature of both local feature and surrounding context. Args: channel (int): Number of input feature channels. reduction (int): Reductions for global context extractor. Default: 16.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. """ def __init__(self, channel, reduction=16, with_cp=False): super(GlobalContextExtractor, self).__init__() self.channel = channel self.reduction = reduction assert reduction >= 1 and channel >= reduction self.with_cp = with_cp self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), nn.Linear(channel // reduction, channel), nn.Sigmoid()) def forward(self, x): def _inner_forward(x): num_batch, num_channel = x.size()[:2] y = self.avg_pool(x).view(num_batch, num_channel) y = self.fc(y).view(num_batch, num_channel, 1, 1) return x * y if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) return out class ContextGuidedBlock(nn.Module): """Context Guided Block for CGNet. This class consists of four components: local feature extractor, surrounding feature extractor, joint feature extractor and global context extractor. Args: in_channels (int): Number of input feature channels. out_channels (int): Number of output feature channels. dilation (int): Dilation rate for surrounding context extractor. Default: 2. reduction (int): Reduction for global context extractor. Default: 16. skip_connect (bool): Add input to output or not. Default: True. downsample (bool): Downsample the input to 1/2 or not. Default: False. conv_cfg (dict): Config dict for convolution layer. Default: None, which means using conv2d. norm_cfg (dict): Config dict for normalization layer. Default: dict(type='BN', requires_grad=True). act_cfg (dict): Config dict for activation layer. Default: dict(type='PReLU'). with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. 
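Example (illustrative; with the default ``skip_connect=True`` and ``downsample=False``, ``in_channels`` must equal ``out_channels`` so the residual addition is valid):

    >>> import torch
    >>> from mmseg.models.backbones.cgnet import ContextGuidedBlock
    >>> block = ContextGuidedBlock(64, 64)
    >>> x = torch.rand(1, 64, 8, 8)
    >>> print(tuple(block(x).shape))
    (1, 64, 8, 8)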
""" def __init__(self, in_channels, out_channels, dilation=2, reduction=16, skip_connect=True, downsample=False, conv_cfg=None, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=dict(type='PReLU'), with_cp=False): super(ContextGuidedBlock, self).__init__() self.with_cp = with_cp self.downsample = downsample channels = out_channels if downsample else out_channels // 2 if 'type' in act_cfg and act_cfg['type'] == 'PReLU': act_cfg['num_parameters'] = channels kernel_size = 3 if downsample else 1 stride = 2 if downsample else 1 padding = (kernel_size - 1) // 2 self.conv1x1 = ConvModule( in_channels, channels, kernel_size, stride, padding, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) self.f_loc = build_conv_layer( conv_cfg, channels, channels, kernel_size=3, padding=1, groups=channels, bias=False) self.f_sur = build_conv_layer( conv_cfg, channels, channels, kernel_size=3, padding=dilation, groups=channels, dilation=dilation, bias=False) self.bn = build_norm_layer(norm_cfg, 2 * channels)[1] self.activate = nn.PReLU(2 * channels) if downsample: self.bottleneck = build_conv_layer( conv_cfg, 2 * channels, out_channels, kernel_size=1, bias=False) self.skip_connect = skip_connect and not downsample self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp) def forward(self, x): def _inner_forward(x): out = self.conv1x1(x) loc = self.f_loc(out) sur = self.f_sur(out) joi_feat = torch.cat([loc, sur], 1) # the joint feature joi_feat = self.bn(joi_feat) joi_feat = self.activate(joi_feat) if self.downsample: joi_feat = self.bottleneck(joi_feat) # channel = out_channels # f_glo is employed to refine the joint feature out = self.f_glo(joi_feat) if self.skip_connect: return x + out else: return out if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) return out class InputInjection(nn.Module): """Downsampling module for CGNet.""" def __init__(self, num_downsampling): super(InputInjection, self).__init__() self.pool = nn.ModuleList() for i in range(num_downsampling): self.pool.append(nn.AvgPool2d(3, stride=2, padding=1)) def forward(self, x): for pool in self.pool: x = pool(x) return x @BACKBONES.register_module() class CGNet(nn.Module): """CGNet backbone. A Light-weight Context Guided Network for Semantic Segmentation arXiv: https://arxiv.org/abs/1811.08201 Args: in_channels (int): Number of input image channels. Normally 3. num_channels (tuple[int]): Numbers of feature channels at each stages. Default: (32, 64, 128). num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2. Default: (3, 21). dilations (tuple[int]): Dilation rate for surrounding context extractors at stage 1 and stage 2. Default: (2, 4). reductions (tuple[int]): Reductions for global context extractors at stage 1 and stage 2. Default: (8, 16). conv_cfg (dict): Config dict for convolution layer. Default: None, which means using conv2d. norm_cfg (dict): Config dict for normalization layer. Default: dict(type='BN', requires_grad=True). act_cfg (dict): Config dict for activation layer. Default: dict(type='PReLU'). norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Default: False. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. 
""" def __init__(self, in_channels=3, num_channels=(32, 64, 128), num_blocks=(3, 21), dilations=(2, 4), reductions=(8, 16), conv_cfg=None, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=dict(type='PReLU'), norm_eval=False, with_cp=False): super(CGNet, self).__init__() self.in_channels = in_channels self.num_channels = num_channels assert isinstance(self.num_channels, tuple) and len( self.num_channels) == 3 self.num_blocks = num_blocks assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2 self.dilations = dilations assert isinstance(self.dilations, tuple) and len(self.dilations) == 2 self.reductions = reductions assert isinstance(self.reductions, tuple) and len(self.reductions) == 2 self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU': self.act_cfg['num_parameters'] = num_channels[0] self.norm_eval = norm_eval self.with_cp = with_cp cur_channels = in_channels self.stem = nn.ModuleList() for i in range(3): self.stem.append( ConvModule( cur_channels, num_channels[0], 3, 2 if i == 0 else 1, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)) cur_channels = num_channels[0] self.inject_2x = InputInjection(1) # down-sample for Input, factor=2 self.inject_4x = InputInjection(2) # down-sample for Input, factor=4 cur_channels += in_channels self.norm_prelu_0 = nn.Sequential( build_norm_layer(norm_cfg, cur_channels)[1], nn.PReLU(cur_channels)) # stage 1 self.level1 = nn.ModuleList() for i in range(num_blocks[0]): self.level1.append( ContextGuidedBlock( cur_channels if i == 0 else num_channels[1], num_channels[1], dilations[0], reductions[0], downsample=(i == 0), conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, with_cp=with_cp)) # CG block cur_channels = 2 * num_channels[1] + in_channels self.norm_prelu_1 = nn.Sequential( build_norm_layer(norm_cfg, cur_channels)[1], nn.PReLU(cur_channels)) # stage 2 self.level2 = nn.ModuleList() for i in range(num_blocks[1]): self.level2.append( ContextGuidedBlock( cur_channels if i == 0 else num_channels[2], num_channels[2], dilations[1], reductions[1], downsample=(i == 0), conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, with_cp=with_cp)) # CG block cur_channels = 2 * num_channels[2] self.norm_prelu_2 = nn.Sequential( build_norm_layer(norm_cfg, cur_channels)[1], nn.PReLU(cur_channels)) def forward(self, x): output = [] # stage 0 inp_2x = self.inject_2x(x) inp_4x = self.inject_4x(x) for layer in self.stem: x = layer(x) x = self.norm_prelu_0(torch.cat([x, inp_2x], 1)) output.append(x) # stage 1 for i, layer in enumerate(self.level1): x = layer(x) if i == 0: down1 = x x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1)) output.append(x) # stage 2 for i, layer in enumerate(self.level2): x = layer(x) if i == 0: down2 = x x = self.norm_prelu_2(torch.cat([down2, x], 1)) output.append(x) return output def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. 
""" if isinstance(pretrained, str): logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: for m in self.modules(): if isinstance(m, (nn.Conv2d, nn.Linear)): kaiming_init(m) elif isinstance(m, (_BatchNorm, nn.GroupNorm)): constant_init(m, 1) elif isinstance(m, nn.PReLU): constant_init(m, 0) else: raise TypeError('pretrained must be a str or None') def train(self, mode=True): """Convert the model into training mode whill keeping the normalization layer freezed.""" super(CGNet, self).train(mode) if mode and self.norm_eval: for m in self.modules(): # trick: eval have effect on BatchNorm only if isinstance(m, _BatchNorm): m.eval() ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/fast_scnn.py ================================================ import torch import torch.nn as nn from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init, kaiming_init) from torch.nn.modules.batchnorm import _BatchNorm from mmseg.models.decode_heads.psp_head import PPM from mmseg.ops import resize from ..builder import BACKBONES from ..utils.inverted_residual import InvertedResidual class LearningToDownsample(nn.Module): """Learning to downsample module. Args: in_channels (int): Number of input channels. dw_channels (tuple[int]): Number of output channels of the first and the second depthwise conv (dwconv) layers. out_channels (int): Number of output channels of the whole 'learning to downsample' module. conv_cfg (dict | None): Config of conv layers. Default: None norm_cfg (dict | None): Config of norm layers. Default: dict(type='BN') act_cfg (dict): Config of activation layers. Default: dict(type='ReLU') """ def __init__(self, in_channels, dw_channels, out_channels, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU')): super(LearningToDownsample, self).__init__() self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg dw_channels1 = dw_channels[0] dw_channels2 = dw_channels[1] self.conv = ConvModule( in_channels, dw_channels1, 3, stride=2, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.dsconv1 = DepthwiseSeparableConvModule( dw_channels1, dw_channels2, kernel_size=3, stride=2, padding=1, norm_cfg=self.norm_cfg) self.dsconv2 = DepthwiseSeparableConvModule( dw_channels2, out_channels, kernel_size=3, stride=2, padding=1, norm_cfg=self.norm_cfg) def forward(self, x): x = self.conv(x) x = self.dsconv1(x) x = self.dsconv2(x) return x class GlobalFeatureExtractor(nn.Module): """Global feature extractor module. Args: in_channels (int): Number of input channels of the GFE module. Default: 64 block_channels (tuple[int]): Tuple of ints. Each int specifies the number of output channels of each Inverted Residual module. Default: (64, 96, 128) out_channels(int): Number of output channels of the GFE module. Default: 128 expand_ratio (int): Adjusts number of channels of the hidden layer in InvertedResidual by this amount. Default: 6 num_blocks (tuple[int]): Tuple of ints. Each int specifies the number of times each Inverted Residual module is repeated. The repeated Inverted Residual modules are called a 'group'. Default: (3, 3, 3) strides (tuple[int]): Tuple of ints. Each int specifies the downsampling factor of each 'group'. Default: (2, 2, 1) pool_scales (tuple[int]): Tuple of ints. Each int specifies the parameter required in 'global average pooling' within PPM. 
Default: (1, 2, 3, 6) conv_cfg (dict | None): Config of conv layers. Default: None norm_cfg (dict | None): Config of norm layers. Default: dict(type='BN') act_cfg (dict): Config of activation layers. Default: dict(type='ReLU') align_corners (bool): align_corners argument of F.interpolate. Default: False """ def __init__(self, in_channels=64, block_channels=(64, 96, 128), out_channels=128, expand_ratio=6, num_blocks=(3, 3, 3), strides=(2, 2, 1), pool_scales=(1, 2, 3, 6), conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), align_corners=False): super(GlobalFeatureExtractor, self).__init__() self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg assert len(block_channels) == len(num_blocks) == 3 self.bottleneck1 = self._make_layer(in_channels, block_channels[0], num_blocks[0], strides[0], expand_ratio) self.bottleneck2 = self._make_layer(block_channels[0], block_channels[1], num_blocks[1], strides[1], expand_ratio) self.bottleneck3 = self._make_layer(block_channels[1], block_channels[2], num_blocks[2], strides[2], expand_ratio) self.ppm = PPM( pool_scales, block_channels[2], block_channels[2] // 4, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, align_corners=align_corners) self.out = ConvModule( block_channels[2] * 2, out_channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def _make_layer(self, in_channels, out_channels, blocks, stride=1, expand_ratio=6): layers = [ InvertedResidual( in_channels, out_channels, stride, expand_ratio, norm_cfg=self.norm_cfg) ] for i in range(1, blocks): layers.append( InvertedResidual( out_channels, out_channels, 1, expand_ratio, norm_cfg=self.norm_cfg)) return nn.Sequential(*layers) def forward(self, x): x = self.bottleneck1(x) x = self.bottleneck2(x) x = self.bottleneck3(x) x = torch.cat([x, *self.ppm(x)], dim=1) x = self.out(x) return x class FeatureFusionModule(nn.Module): """Feature fusion module. Args: higher_in_channels (int): Number of input channels of the higher-resolution branch. lower_in_channels (int): Number of input channels of the lower-resolution branch. out_channels (int): Number of output channels. conv_cfg (dict | None): Config of conv layers. Default: None norm_cfg (dict | None): Config of norm layers. Default: dict(type='BN') act_cfg (dict): Config of activation layers. Default: dict(type='ReLU') align_corners (bool): align_corners argument of F.interpolate. 
Default: False """ def __init__(self, higher_in_channels, lower_in_channels, out_channels, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), align_corners=False): super(FeatureFusionModule, self).__init__() self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg self.align_corners = align_corners self.dwconv = ConvModule( lower_in_channels, out_channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.conv_lower_res = ConvModule( out_channels, out_channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=None) self.conv_higher_res = ConvModule( higher_in_channels, out_channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=None) self.relu = nn.ReLU(True) def forward(self, higher_res_feature, lower_res_feature): lower_res_feature = resize( lower_res_feature, size=higher_res_feature.size()[2:], mode='bilinear', align_corners=self.align_corners) lower_res_feature = self.dwconv(lower_res_feature) lower_res_feature = self.conv_lower_res(lower_res_feature) higher_res_feature = self.conv_higher_res(higher_res_feature) out = higher_res_feature + lower_res_feature return self.relu(out) @BACKBONES.register_module() class FastSCNN(nn.Module): """Fast-SCNN Backbone. Args: in_channels (int): Number of input image channels. Default: 3. downsample_dw_channels (tuple[int]): Number of output channels after the first conv layer & the second conv layer in Learning-To-Downsample (LTD) module. Default: (32, 48). global_in_channels (int): Number of input channels of Global Feature Extractor(GFE). Equal to number of output channels of LTD. Default: 64. global_block_channels (tuple[int]): Tuple of integers that describe the output channels for each of the MobileNet-v2 bottleneck residual blocks in GFE. Default: (64, 96, 128). global_block_strides (tuple[int]): Tuple of integers that describe the strides (downsampling factors) for each of the MobileNet-v2 bottleneck residual blocks in GFE. Default: (2, 2, 1). global_out_channels (int): Number of output channels of GFE. Default: 128. higher_in_channels (int): Number of input channels of the higher resolution branch in FFM. Equal to global_in_channels. Default: 64. lower_in_channels (int): Number of input channels of the lower resolution branch in FFM. Equal to global_out_channels. Default: 128. fusion_out_channels (int): Number of output channels of FFM. Default: 128. out_indices (tuple): Tuple of indices of list [higher_res_features, lower_res_features, fusion_output]. Often set to (0,1,2) to enable aux. heads. Default: (0, 1, 2). conv_cfg (dict | None): Config of conv layers. Default: None norm_cfg (dict | None): Config of norm layers. Default: dict(type='BN') act_cfg (dict): Config of activation layers. Default: dict(type='ReLU') align_corners (bool): align_corners argument of F.interpolate. 
Default: False """ def __init__(self, in_channels=3, downsample_dw_channels=(32, 48), global_in_channels=64, global_block_channels=(64, 96, 128), global_block_strides=(2, 2, 1), global_out_channels=128, higher_in_channels=64, lower_in_channels=128, fusion_out_channels=128, out_indices=(0, 1, 2), conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), align_corners=False): super(FastSCNN, self).__init__() if global_in_channels != higher_in_channels: raise AssertionError('Global Input Channels must be the same \ with Higher Input Channels!') elif global_out_channels != lower_in_channels: raise AssertionError('Global Output Channels must be the same \ with Lower Input Channels!') self.in_channels = in_channels self.downsample_dw_channels1 = downsample_dw_channels[0] self.downsample_dw_channels2 = downsample_dw_channels[1] self.global_in_channels = global_in_channels self.global_block_channels = global_block_channels self.global_block_strides = global_block_strides self.global_out_channels = global_out_channels self.higher_in_channels = higher_in_channels self.lower_in_channels = lower_in_channels self.fusion_out_channels = fusion_out_channels self.out_indices = out_indices self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg self.align_corners = align_corners self.learning_to_downsample = LearningToDownsample( in_channels, downsample_dw_channels, global_in_channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.global_feature_extractor = GlobalFeatureExtractor( global_in_channels, global_block_channels, global_out_channels, strides=self.global_block_strides, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, align_corners=self.align_corners) self.feature_fusion = FeatureFusionModule( higher_in_channels, lower_in_channels, fusion_out_channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, align_corners=self.align_corners) def init_weights(self, pretrained=None): for m in self.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m) elif isinstance(m, (_BatchNorm, nn.GroupNorm)): constant_init(m, 1) def forward(self, x): higher_res_features = self.learning_to_downsample(x) lower_res_features = self.global_feature_extractor(higher_res_features) fusion_output = self.feature_fusion(higher_res_features, lower_res_features) outs = [higher_res_features, lower_res_features, fusion_output] outs = [outs[i] for i in self.out_indices] return tuple(outs) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/hrnet.py ================================================ import torch.nn as nn from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init, kaiming_init) from mmcv.runner import load_checkpoint from mmcv.utils.parrots_wrapper import _BatchNorm from mmseg.ops import Upsample, resize from mmseg.utils import get_root_logger from ..builder import BACKBONES from .resnet import BasicBlock, Bottleneck class HRModule(nn.Module): """High-Resolution Module for HRNet. In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange is in this module. 
""" def __init__(self, num_branches, blocks, num_blocks, in_channels, num_channels, multiscale_output=True, with_cp=False, conv_cfg=None, norm_cfg=dict(type='BN', requires_grad=True)): super(HRModule, self).__init__() self._check_branches(num_branches, num_blocks, in_channels, num_channels) self.in_channels = in_channels self.num_branches = num_branches self.multiscale_output = multiscale_output self.norm_cfg = norm_cfg self.conv_cfg = conv_cfg self.with_cp = with_cp self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels) self.fuse_layers = self._make_fuse_layers() self.relu = nn.ReLU(inplace=False) def _check_branches(self, num_branches, num_blocks, in_channels, num_channels): """Check branches configuration.""" if num_branches != len(num_blocks): error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \ f'{len(num_blocks)})' raise ValueError(error_msg) if num_branches != len(num_channels): error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \ f'{len(num_channels)})' raise ValueError(error_msg) if num_branches != len(in_channels): error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \ f'{len(in_channels)})' raise ValueError(error_msg) def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): """Build one branch.""" downsample = None if stride != 1 or \ self.in_channels[branch_index] != \ num_channels[branch_index] * block.expansion: downsample = nn.Sequential( build_conv_layer( self.conv_cfg, self.in_channels[branch_index], num_channels[branch_index] * block.expansion, kernel_size=1, stride=stride, bias=False), build_norm_layer(self.norm_cfg, num_channels[branch_index] * block.expansion)[1]) layers = [] layers.append( block( self.in_channels[branch_index], num_channels[branch_index], stride, downsample=downsample, with_cp=self.with_cp, norm_cfg=self.norm_cfg, conv_cfg=self.conv_cfg)) self.in_channels[branch_index] = \ num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): layers.append( block( self.in_channels[branch_index], num_channels[branch_index], with_cp=self.with_cp, norm_cfg=self.norm_cfg, conv_cfg=self.conv_cfg)) return nn.Sequential(*layers) def _make_branches(self, num_branches, block, num_blocks, num_channels): """Build multiple branch.""" branches = [] for i in range(num_branches): branches.append( self._make_one_branch(i, block, num_blocks, num_channels)) return nn.ModuleList(branches) def _make_fuse_layers(self): """Build fuse layer.""" if self.num_branches == 1: return None num_branches = self.num_branches in_channels = self.in_channels fuse_layers = [] num_out_branches = num_branches if self.multiscale_output else 1 for i in range(num_out_branches): fuse_layer = [] for j in range(num_branches): if j > i: fuse_layer.append( nn.Sequential( build_conv_layer( self.conv_cfg, in_channels[j], in_channels[i], kernel_size=1, stride=1, padding=0, bias=False), build_norm_layer(self.norm_cfg, in_channels[i])[1], # we set align_corners=False for HRNet Upsample( scale_factor=2**(j - i), mode='bilinear', align_corners=False))) elif j == i: fuse_layer.append(None) else: conv_downsamples = [] for k in range(i - j): if k == i - j - 1: conv_downsamples.append( nn.Sequential( build_conv_layer( self.conv_cfg, in_channels[j], in_channels[i], kernel_size=3, stride=2, padding=1, bias=False), build_norm_layer(self.norm_cfg, in_channels[i])[1])) else: conv_downsamples.append( nn.Sequential( build_conv_layer( self.conv_cfg, in_channels[j], in_channels[j], kernel_size=3, stride=2, 
padding=1, bias=False), build_norm_layer(self.norm_cfg, in_channels[j])[1], nn.ReLU(inplace=False))) fuse_layer.append(nn.Sequential(*conv_downsamples)) fuse_layers.append(nn.ModuleList(fuse_layer)) return nn.ModuleList(fuse_layers) def forward(self, x): """Forward function.""" if self.num_branches == 1: return [self.branches[0](x[0])] for i in range(self.num_branches): x[i] = self.branches[i](x[i]) x_fuse = [] for i in range(len(self.fuse_layers)): y = 0 for j in range(self.num_branches): if i == j: y += x[j] elif j > i: y = y + resize( self.fuse_layers[i][j](x[j]), size=x[i].shape[2:], mode='bilinear', align_corners=False) else: y += self.fuse_layers[i][j](x[j]) x_fuse.append(self.relu(y)) return x_fuse @BACKBONES.register_module() class HRNet(nn.Module): """HRNet backbone. High-Resolution Representations for Labeling Pixels and Regions arXiv: https://arxiv.org/abs/1904.04514 Args: extra (dict): detailed configuration for each stage of HRNet. in_channels (int): Number of input image channels. Normally 3. conv_cfg (dict): dictionary to construct and config conv layer. norm_cfg (dict): dictionary to construct and config norm layer. norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. zero_init_residual (bool): whether to use zero init for last norm layer in resblocks to let them behave as identity. Example: >>> from mmseg.models import HRNet >>> import torch >>> extra = dict( >>> stage1=dict( >>> num_modules=1, >>> num_branches=1, >>> block='BOTTLENECK', >>> num_blocks=(4, ), >>> num_channels=(64, )), >>> stage2=dict( >>> num_modules=1, >>> num_branches=2, >>> block='BASIC', >>> num_blocks=(4, 4), >>> num_channels=(32, 64)), >>> stage3=dict( >>> num_modules=4, >>> num_branches=3, >>> block='BASIC', >>> num_blocks=(4, 4, 4), >>> num_channels=(32, 64, 128)), >>> stage4=dict( >>> num_modules=3, >>> num_branches=4, >>> block='BASIC', >>> num_blocks=(4, 4, 4, 4), >>> num_channels=(32, 64, 128, 256))) >>> self = HRNet(extra, in_channels=1) >>> self.eval() >>> inputs = torch.rand(1, 1, 32, 32) >>> level_outputs = self.forward(inputs) >>> for level_out in level_outputs: ... 
print(tuple(level_out.shape)) (1, 32, 8, 8) (1, 64, 4, 4) (1, 128, 2, 2) (1, 256, 1, 1) """ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} def __init__(self, extra, in_channels=3, conv_cfg=None, norm_cfg=dict(type='BN', requires_grad=True), norm_eval=False, with_cp=False, zero_init_residual=False): super(HRNet, self).__init__() self.extra = extra self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.norm_eval = norm_eval self.with_cp = with_cp self.zero_init_residual = zero_init_residual # stem net self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2) self.conv1 = build_conv_layer( self.conv_cfg, in_channels, 64, kernel_size=3, stride=2, padding=1, bias=False) self.add_module(self.norm1_name, norm1) self.conv2 = build_conv_layer( self.conv_cfg, 64, 64, kernel_size=3, stride=2, padding=1, bias=False) self.add_module(self.norm2_name, norm2) self.relu = nn.ReLU(inplace=True) # stage 1 self.stage1_cfg = self.extra['stage1'] num_channels = self.stage1_cfg['num_channels'][0] block_type = self.stage1_cfg['block'] num_blocks = self.stage1_cfg['num_blocks'][0] block = self.blocks_dict[block_type] stage1_out_channels = num_channels * block.expansion self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) # stage 2 self.stage2_cfg = self.extra['stage2'] num_channels = self.stage2_cfg['num_channels'] block_type = self.stage2_cfg['block'] block = self.blocks_dict[block_type] num_channels = [channel * block.expansion for channel in num_channels] self.transition1 = self._make_transition_layer([stage1_out_channels], num_channels) self.stage2, pre_stage_channels = self._make_stage( self.stage2_cfg, num_channels) # stage 3 self.stage3_cfg = self.extra['stage3'] num_channels = self.stage3_cfg['num_channels'] block_type = self.stage3_cfg['block'] block = self.blocks_dict[block_type] num_channels = [channel * block.expansion for channel in num_channels] self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) self.stage3, pre_stage_channels = self._make_stage( self.stage3_cfg, num_channels) # stage 4 self.stage4_cfg = self.extra['stage4'] num_channels = self.stage4_cfg['num_channels'] block_type = self.stage4_cfg['block'] block = self.blocks_dict[block_type] num_channels = [channel * block.expansion for channel in num_channels] self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) self.stage4, pre_stage_channels = self._make_stage( self.stage4_cfg, num_channels) @property def norm1(self): """nn.Module: the normalization layer named "norm1" """ return getattr(self, self.norm1_name) @property def norm2(self): """nn.Module: the normalization layer named "norm2" """ return getattr(self, self.norm2_name) def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): """Make transition layer.""" num_branches_cur = len(num_channels_cur_layer) num_branches_pre = len(num_channels_pre_layer) transition_layers = [] for i in range(num_branches_cur): if i < num_branches_pre: if num_channels_cur_layer[i] != num_channels_pre_layer[i]: transition_layers.append( nn.Sequential( build_conv_layer( self.conv_cfg, num_channels_pre_layer[i], num_channels_cur_layer[i], kernel_size=3, stride=1, padding=1, bias=False), build_norm_layer(self.norm_cfg, num_channels_cur_layer[i])[1], nn.ReLU(inplace=True))) else: transition_layers.append(None) else: conv_downsamples = [] for j in range(i + 1 - num_branches_pre): in_channels = num_channels_pre_layer[-1] 
out_channels = num_channels_cur_layer[i] \ if j == i - num_branches_pre else in_channels conv_downsamples.append( nn.Sequential( build_conv_layer( self.conv_cfg, in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False), build_norm_layer(self.norm_cfg, out_channels)[1], nn.ReLU(inplace=True))) transition_layers.append(nn.Sequential(*conv_downsamples)) return nn.ModuleList(transition_layers) def _make_layer(self, block, inplanes, planes, blocks, stride=1): """Make each layer.""" downsample = None if stride != 1 or inplanes != planes * block.expansion: downsample = nn.Sequential( build_conv_layer( self.conv_cfg, inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), build_norm_layer(self.norm_cfg, planes * block.expansion)[1]) layers = [] layers.append( block( inplanes, planes, stride, downsample=downsample, with_cp=self.with_cp, norm_cfg=self.norm_cfg, conv_cfg=self.conv_cfg)) inplanes = planes * block.expansion for i in range(1, blocks): layers.append( block( inplanes, planes, with_cp=self.with_cp, norm_cfg=self.norm_cfg, conv_cfg=self.conv_cfg)) return nn.Sequential(*layers) def _make_stage(self, layer_config, in_channels, multiscale_output=True): """Make each stage.""" num_modules = layer_config['num_modules'] num_branches = layer_config['num_branches'] num_blocks = layer_config['num_blocks'] num_channels = layer_config['num_channels'] block = self.blocks_dict[layer_config['block']] hr_modules = [] for i in range(num_modules): # multi_scale_output is only used for the last module if not multiscale_output and i == num_modules - 1: reset_multiscale_output = False else: reset_multiscale_output = True hr_modules.append( HRModule( num_branches, block, num_blocks, in_channels, num_channels, reset_multiscale_output, with_cp=self.with_cp, norm_cfg=self.norm_cfg, conv_cfg=self.conv_cfg)) return nn.Sequential(*hr_modules), in_channels def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. 
""" if isinstance(pretrained, str): logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: for m in self.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m) elif isinstance(m, (_BatchNorm, nn.GroupNorm)): constant_init(m, 1) if self.zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): constant_init(m.norm3, 0) elif isinstance(m, BasicBlock): constant_init(m.norm2, 0) else: raise TypeError('pretrained must be a str or None') def forward(self, x): """Forward function.""" x = self.conv1(x) x = self.norm1(x) x = self.relu(x) x = self.conv2(x) x = self.norm2(x) x = self.relu(x) x = self.layer1(x) x_list = [] for i in range(self.stage2_cfg['num_branches']): if self.transition1[i] is not None: x_list.append(self.transition1[i](x)) else: x_list.append(x) y_list = self.stage2(x_list) x_list = [] for i in range(self.stage3_cfg['num_branches']): if self.transition2[i] is not None: x_list.append(self.transition2[i](y_list[-1])) else: x_list.append(y_list[i]) y_list = self.stage3(x_list) x_list = [] for i in range(self.stage4_cfg['num_branches']): if self.transition3[i] is not None: x_list.append(self.transition3[i](y_list[-1])) else: x_list.append(y_list[i]) y_list = self.stage4(x_list) return y_list def train(self, mode=True): """Convert the model into training mode whill keeping the normalization layer freezed.""" super(HRNet, self).train(mode) if mode and self.norm_eval: for m in self.modules(): # trick: eval have effect on BatchNorm only if isinstance(m, _BatchNorm): m.eval() ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/mobilenet_v2.py ================================================ import logging import torch.nn as nn from mmcv.cnn import ConvModule, constant_init, kaiming_init from mmcv.runner import load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm from ..builder import BACKBONES from ..utils import InvertedResidual, make_divisible @BACKBONES.register_module() class MobileNetV2(nn.Module): """MobileNetV2 backbone. Args: widen_factor (float): Width multiplier, multiply number of channels in each layer by this amount. Default: 1.0. strides (Sequence[int], optional): Strides of the first block of each layer. If not specified, default config in ``arch_setting`` will be used. dilations (Sequence[int]): Dilation of each layer. out_indices (None or Sequence[int]): Output from which stages. Default: (7, ). frozen_stages (int): Stages to be frozen (all param fixed). Default: -1, which means not freezing any parameters. conv_cfg (dict): Config dict for convolution layer. Default: None, which means using conv2d. norm_cfg (dict): Config dict for normalization layer. Default: dict(type='BN'). act_cfg (dict): Config dict for activation layer. Default: dict(type='ReLU6'). norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Default: False. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. """ # Parameters to build layers. 3 parameters are needed to construct a # layer, from left to right: expand_ratio, channel, num_blocks. 
arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4], [6, 96, 3], [6, 160, 3], [6, 320, 1]] def __init__(self, widen_factor=1., strides=(1, 2, 2, 2, 1, 2, 1), dilations=(1, 1, 1, 1, 1, 1, 1), out_indices=(1, 2, 4, 6), frozen_stages=-1, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU6'), norm_eval=False, with_cp=False): super(MobileNetV2, self).__init__() self.widen_factor = widen_factor self.strides = strides self.dilations = dilations assert len(strides) == len(dilations) == len(self.arch_settings) self.out_indices = out_indices for index in out_indices: if index not in range(0, 7): raise ValueError('the item in out_indices must be in ' f'range(0, 7). But received {index}') if frozen_stages not in range(-1, 7): raise ValueError('frozen_stages must be in range(-1, 7). ' f'But received {frozen_stages}') self.frozen_stages = frozen_stages self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg self.norm_eval = norm_eval self.with_cp = with_cp self.in_channels = make_divisible(32 * widen_factor, 8) self.conv1 = ConvModule( in_channels=3, out_channels=self.in_channels, kernel_size=3, stride=2, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.layers = [] for i, layer_cfg in enumerate(self.arch_settings): expand_ratio, channel, num_blocks = layer_cfg stride = self.strides[i] dilation = self.dilations[i] out_channels = make_divisible(channel * widen_factor, 8) inverted_res_layer = self.make_layer( out_channels=out_channels, num_blocks=num_blocks, stride=stride, dilation=dilation, expand_ratio=expand_ratio) layer_name = f'layer{i + 1}' self.add_module(layer_name, inverted_res_layer) self.layers.append(layer_name) def make_layer(self, out_channels, num_blocks, stride, dilation, expand_ratio): """Stack InvertedResidual blocks to build a layer for MobileNetV2. Args: out_channels (int): out_channels of block. num_blocks (int): Number of blocks. stride (int): Stride of the first block. dilation (int): Dilation of the first block. expand_ratio (int): Expand the number of channels of the hidden layer in InvertedResidual by this ratio.
""" layers = [] for i in range(num_blocks): layers.append( InvertedResidual( self.in_channels, out_channels, stride if i == 0 else 1, expand_ratio=expand_ratio, dilation=dilation if i == 0 else 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, with_cp=self.with_cp)) self.in_channels = out_channels return nn.Sequential(*layers) def init_weights(self, pretrained=None): if isinstance(pretrained, str): logger = logging.getLogger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: for m in self.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m) elif isinstance(m, (_BatchNorm, nn.GroupNorm)): constant_init(m, 1) else: raise TypeError('pretrained must be a str or None') def forward(self, x): x = self.conv1(x) outs = [] for i, layer_name in enumerate(self.layers): layer = getattr(self, layer_name) x = layer(x) if i in self.out_indices: outs.append(x) if len(outs) == 1: return outs[0] else: return tuple(outs) def _freeze_stages(self): if self.frozen_stages >= 0: for param in self.conv1.parameters(): param.requires_grad = False for i in range(1, self.frozen_stages + 1): layer = getattr(self, f'layer{i}') layer.eval() for param in layer.parameters(): param.requires_grad = False def train(self, mode=True): super(MobileNetV2, self).train(mode) self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): if isinstance(m, _BatchNorm): m.eval() ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/mobilenet_v3.py ================================================ import logging import mmcv import torch.nn as nn from mmcv.cnn import ConvModule, constant_init, kaiming_init from mmcv.cnn.bricks import Conv2dAdaptivePadding from mmcv.runner import load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm from ..builder import BACKBONES from ..utils import InvertedResidualV3 as InvertedResidual @BACKBONES.register_module() class MobileNetV3(nn.Module): """MobileNetV3 backbone. This backbone is the improved implementation of `Searching for MobileNetV3 `_. Args: arch (str): Architechture of mobilnetv3, from {'small', 'large'}. Default: 'small'. conv_cfg (dict): Config dict for convolution layer. Default: None, which means using conv2d. norm_cfg (dict): Config dict for normalization layer. Default: dict(type='BN'). out_indices (tuple[int]): Output from which layer. Default: (0, 1, 12). frozen_stages (int): Stages to be frozen (all param fixed). Defualt: -1, which means not freezing any parameters. norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Default: False. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Defualt: False. 
""" # Parameters to build each block: # [kernel size, mid channels, out channels, with_se, act type, stride] arch_settings = { 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4 [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8 [3, 88, 24, False, 'ReLU', 1], [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16 [5, 240, 40, True, 'HSwish', 1], [5, 240, 40, True, 'HSwish', 1], [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16 [5, 144, 48, True, 'HSwish', 1], [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32 [5, 576, 96, True, 'HSwish', 1], [5, 576, 96, True, 'HSwish', 1]], 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2 [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4 [3, 72, 24, False, 'ReLU', 1], [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8 [5, 120, 40, True, 'ReLU', 1], [5, 120, 40, True, 'ReLU', 1], [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16 [3, 200, 80, False, 'HSwish', 1], [3, 184, 80, False, 'HSwish', 1], [3, 184, 80, False, 'HSwish', 1], [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16 [3, 672, 112, True, 'HSwish', 1], [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32 [5, 960, 160, True, 'HSwish', 1], [5, 960, 160, True, 'HSwish', 1]] } # yapf: disable def __init__(self, arch='small', conv_cfg=None, norm_cfg=dict(type='BN'), out_indices=(0, 1, 12), frozen_stages=-1, reduction_factor=1, norm_eval=False, with_cp=False): super(MobileNetV3, self).__init__() assert arch in self.arch_settings assert isinstance(reduction_factor, int) and reduction_factor > 0 assert mmcv.is_tuple_of(out_indices, int) for index in out_indices: if index not in range(0, len(self.arch_settings[arch]) + 2): raise ValueError( 'the item in out_indices must in ' f'range(0, {len(self.arch_settings[arch])+2}). ' f'But received {index}') if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2): raise ValueError('frozen_stages must be in range(-1, ' f'{len(self.arch_settings[arch])+2}). 
' f'But received {frozen_stages}') self.arch = arch self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.out_indices = out_indices self.frozen_stages = frozen_stages self.reduction_factor = reduction_factor self.norm_eval = norm_eval self.with_cp = with_cp self.layers = self._make_layer() def _make_layer(self): layers = [] # build the first layer (layer0) in_channels = 16 layer = ConvModule( in_channels=3, out_channels=in_channels, kernel_size=3, stride=2, padding=1, conv_cfg=dict(type='Conv2dAdaptivePadding'), norm_cfg=self.norm_cfg, act_cfg=dict(type='HSwish')) self.add_module('layer0', layer) layers.append('layer0') layer_setting = self.arch_settings[self.arch] for i, params in enumerate(layer_setting): (kernel_size, mid_channels, out_channels, with_se, act, stride) = params if self.arch == 'large' and i >= 12 or self.arch == 'small' and \ i >= 8: mid_channels = mid_channels // self.reduction_factor out_channels = out_channels // self.reduction_factor if with_se: se_cfg = dict( channels=mid_channels, ratio=4, act_cfg=(dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, divisor=6.0))) else: se_cfg = None layer = InvertedResidual( in_channels=in_channels, out_channels=out_channels, mid_channels=mid_channels, kernel_size=kernel_size, stride=stride, se_cfg=se_cfg, with_expand_conv=(in_channels != mid_channels), conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=dict(type=act), with_cp=self.with_cp) in_channels = out_channels layer_name = 'layer{}'.format(i + 1) self.add_module(layer_name, layer) layers.append(layer_name) # build the last layer # block5 layer12 os=32 for small model # block6 layer16 os=32 for large model layer = ConvModule( in_channels=in_channels, out_channels=576 if self.arch == 'small' else 960, kernel_size=1, stride=1, dilation=4, padding=0, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=dict(type='HSwish')) layer_name = 'layer{}'.format(len(layer_setting) + 1) self.add_module(layer_name, layer) layers.append(layer_name) # next, convert backbone MobileNetV3 to a semantic segmentation version if self.arch == 'small': self.layer4.depthwise_conv.conv.stride = (1, 1) self.layer9.depthwise_conv.conv.stride = (1, 1) for i in range(4, len(layers)): layer = getattr(self, layers[i]) if isinstance(layer, InvertedResidual): modified_module = layer.depthwise_conv.conv else: modified_module = layer.conv if i < 9: modified_module.dilation = (2, 2) pad = 2 else: modified_module.dilation = (4, 4) pad = 4 if not isinstance(modified_module, Conv2dAdaptivePadding): # Adjust padding pad *= (modified_module.kernel_size[0] - 1) // 2 modified_module.padding = (pad, pad) else: self.layer7.depthwise_conv.conv.stride = (1, 1) self.layer13.depthwise_conv.conv.stride = (1, 1) for i in range(7, len(layers)): layer = getattr(self, layers[i]) if isinstance(layer, InvertedResidual): modified_module = layer.depthwise_conv.conv else: modified_module = layer.conv if i < 13: modified_module.dilation = (2, 2) pad = 2 else: modified_module.dilation = (4, 4) pad = 4 if not isinstance(modified_module, Conv2dAdaptivePadding): # Adjust padding pad *= (modified_module.kernel_size[0] - 1) // 2 modified_module.padding = (pad, pad) return layers def init_weights(self, pretrained=None): if isinstance(pretrained, str): logger = logging.getLogger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: for m in self.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m) elif isinstance(m, nn.BatchNorm2d): constant_init(m, 1) else: raise TypeError('pretrained must be a 
str or None') def forward(self, x): outs = [] for i, layer_name in enumerate(self.layers): layer = getattr(self, layer_name) x = layer(x) if i in self.out_indices: outs.append(x) return outs def _freeze_stages(self): for i in range(self.frozen_stages + 1): layer = getattr(self, f'layer{i}') layer.eval() for param in layer.parameters(): param.requires_grad = False def train(self, mode=True): super(MobileNetV3, self).train(mode) self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): if isinstance(m, _BatchNorm): m.eval() ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/resnest.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as cp from mmcv.cnn import build_conv_layer, build_norm_layer from ..builder import BACKBONES from ..utils import ResLayer from .resnet import Bottleneck as _Bottleneck from .resnet import ResNetV1d class RSoftmax(nn.Module): """Radix Softmax module in ``SplitAttentionConv2d``. Args: radix (int): Radix of input. groups (int): Groups of input. """ def __init__(self, radix, groups): super().__init__() self.radix = radix self.groups = groups def forward(self, x): batch = x.size(0) if self.radix > 1: x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) x = F.softmax(x, dim=1) x = x.reshape(batch, -1) else: x = torch.sigmoid(x) return x class SplitAttentionConv2d(nn.Module): """Split-Attention Conv2d in ResNeSt. Args: in_channels (int): Same as nn.Conv2d. out_channels (int): Same as nn.Conv2d. kernel_size (int | tuple[int]): Same as nn.Conv2d. stride (int | tuple[int]): Same as nn.Conv2d. padding (int | tuple[int]): Same as nn.Conv2d. dilation (int | tuple[int]): Same as nn.Conv2d. groups (int): Same as nn.Conv2d. radix (int): Radix of SpltAtConv2d. Default: 2 reduction_factor (int): Reduction factor of inter_channels. Default: 4. conv_cfg (dict): Config dict for convolution layer. Default: None, which means using conv2d. norm_cfg (dict): Config dict for normalization layer. Default: None. dcn (dict): Config dict for DCN. Default: None. 
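Example (illustrative; a 3x3 split-attention conv with padding=1 and the default radix=2 preserves the input shape):

    >>> from mmseg.models.backbones.resnest import SplitAttentionConv2d
    >>> import torch
    >>> conv = SplitAttentionConv2d(32, 32, kernel_size=3, padding=1)
    >>> x = torch.rand(1, 32, 8, 8)
    >>> print(tuple(conv(x).shape))
    (1, 32, 8, 8)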
""" def __init__(self, in_channels, channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, radix=2, reduction_factor=4, conv_cfg=None, norm_cfg=dict(type='BN'), dcn=None): super(SplitAttentionConv2d, self).__init__() inter_channels = max(in_channels * radix // reduction_factor, 32) self.radix = radix self.groups = groups self.channels = channels self.with_dcn = dcn is not None self.dcn = dcn fallback_on_stride = False if self.with_dcn: fallback_on_stride = self.dcn.pop('fallback_on_stride', False) if self.with_dcn and not fallback_on_stride: assert conv_cfg is None, 'conv_cfg must be None for DCN' conv_cfg = dcn self.conv = build_conv_layer( conv_cfg, in_channels, channels * radix, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups * radix, bias=False) self.norm0_name, norm0 = build_norm_layer( norm_cfg, channels * radix, postfix=0) self.add_module(self.norm0_name, norm0) self.relu = nn.ReLU(inplace=True) self.fc1 = build_conv_layer( None, channels, inter_channels, 1, groups=self.groups) self.norm1_name, norm1 = build_norm_layer( norm_cfg, inter_channels, postfix=1) self.add_module(self.norm1_name, norm1) self.fc2 = build_conv_layer( None, inter_channels, channels * radix, 1, groups=self.groups) self.rsoftmax = RSoftmax(radix, groups) @property def norm0(self): """nn.Module: the normalization layer named "norm0" """ return getattr(self, self.norm0_name) @property def norm1(self): """nn.Module: the normalization layer named "norm1" """ return getattr(self, self.norm1_name) def forward(self, x): x = self.conv(x) x = self.norm0(x) x = self.relu(x) batch, rchannel = x.shape[:2] batch = x.size(0) if self.radix > 1: splits = x.view(batch, self.radix, -1, *x.shape[2:]) gap = splits.sum(dim=1) else: gap = x gap = F.adaptive_avg_pool2d(gap, 1) gap = self.fc1(gap) gap = self.norm1(gap) gap = self.relu(gap) atten = self.fc2(gap) atten = self.rsoftmax(atten).view(batch, -1, 1, 1) if self.radix > 1: attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) out = torch.sum(attens * splits, dim=1) else: out = atten * x return out.contiguous() class Bottleneck(_Bottleneck): """Bottleneck block for ResNeSt. Args: inplane (int): Input planes of this block. planes (int): Middle planes of this block. groups (int): Groups of conv2. width_per_group (int): Width per group of conv2. 64x4d indicates ``groups=64, width_per_group=4`` and 32x8d indicates ``groups=32, width_per_group=8``. radix (int): Radix of SpltAtConv2d. Default: 2 reduction_factor (int): Reduction factor of inter_channels in SplitAttentionConv2d. Default: 4. avg_down_stride (bool): Whether to use average pool for stride in Bottleneck. Default: True. kwargs (dict): Key word arguments for base class. 
""" expansion = 4 def __init__(self, inplanes, planes, groups=1, base_width=4, base_channels=64, radix=2, reduction_factor=4, avg_down_stride=True, **kwargs): """Bottleneck block for ResNeSt.""" super(Bottleneck, self).__init__(inplanes, planes, **kwargs) if groups == 1: width = self.planes else: width = math.floor(self.planes * (base_width / base_channels)) * groups self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 self.norm1_name, norm1 = build_norm_layer( self.norm_cfg, width, postfix=1) self.norm3_name, norm3 = build_norm_layer( self.norm_cfg, self.planes * self.expansion, postfix=3) self.conv1 = build_conv_layer( self.conv_cfg, self.inplanes, width, kernel_size=1, stride=self.conv1_stride, bias=False) self.add_module(self.norm1_name, norm1) self.with_modulated_dcn = False self.conv2 = SplitAttentionConv2d( width, width, kernel_size=3, stride=1 if self.avg_down_stride else self.conv2_stride, padding=self.dilation, dilation=self.dilation, groups=groups, radix=radix, reduction_factor=reduction_factor, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, dcn=self.dcn) delattr(self, self.norm2_name) if self.avg_down_stride: self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) self.conv3 = build_conv_layer( self.conv_cfg, width, self.planes * self.expansion, kernel_size=1, bias=False) self.add_module(self.norm3_name, norm3) def forward(self, x): def _inner_forward(x): identity = x out = self.conv1(x) out = self.norm1(out) out = self.relu(out) if self.with_plugins: out = self.forward_plugin(out, self.after_conv1_plugin_names) out = self.conv2(out) if self.avg_down_stride: out = self.avd_layer(out) if self.with_plugins: out = self.forward_plugin(out, self.after_conv2_plugin_names) out = self.conv3(out) out = self.norm3(out) if self.with_plugins: out = self.forward_plugin(out, self.after_conv3_plugin_names) if self.downsample is not None: identity = self.downsample(x) out += identity return out if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) out = self.relu(out) return out @BACKBONES.register_module() class ResNeSt(ResNetV1d): """ResNeSt backbone. Args: groups (int): Number of groups of Bottleneck. Default: 1 base_width (int): Base width of Bottleneck. Default: 4 radix (int): Radix of SpltAtConv2d. Default: 2 reduction_factor (int): Reduction factor of inter_channels in SplitAttentionConv2d. Default: 4. avg_down_stride (bool): Whether to use average pool for stride in Bottleneck. Default: True. kwargs (dict): Keyword arguments for ResNet. 
""" arch_settings = { 50: (Bottleneck, (3, 4, 6, 3)), 101: (Bottleneck, (3, 4, 23, 3)), 152: (Bottleneck, (3, 8, 36, 3)), 200: (Bottleneck, (3, 24, 36, 3)) } def __init__(self, groups=1, base_width=4, radix=2, reduction_factor=4, avg_down_stride=True, **kwargs): self.groups = groups self.base_width = base_width self.radix = radix self.reduction_factor = reduction_factor self.avg_down_stride = avg_down_stride super(ResNeSt, self).__init__(**kwargs) def make_res_layer(self, **kwargs): """Pack all blocks in a stage into a ``ResLayer``.""" return ResLayer( groups=self.groups, base_width=self.base_width, base_channels=self.base_channels, radix=self.radix, reduction_factor=self.reduction_factor, avg_down_stride=self.avg_down_stride, **kwargs) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/resnet.py ================================================ import torch.nn as nn import torch.utils.checkpoint as cp from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer, constant_init, kaiming_init) from mmcv.runner import load_checkpoint from mmcv.utils.parrots_wrapper import _BatchNorm from mmseg.utils import get_root_logger from ..builder import BACKBONES from ..utils import ResLayer class BasicBlock(nn.Module): """Basic block for ResNet.""" expansion = 1 def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, style='pytorch', with_cp=False, conv_cfg=None, norm_cfg=dict(type='BN'), dcn=None, plugins=None): super(BasicBlock, self).__init__() assert dcn is None, 'Not implemented yet.' assert plugins is None, 'Not implemented yet.' self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) self.conv1 = build_conv_layer( conv_cfg, inplanes, planes, 3, stride=stride, padding=dilation, dilation=dilation, bias=False) self.add_module(self.norm1_name, norm1) self.conv2 = build_conv_layer( conv_cfg, planes, planes, 3, padding=1, bias=False) self.add_module(self.norm2_name, norm2) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride self.dilation = dilation self.with_cp = with_cp @property def norm1(self): """nn.Module: normalization layer after the first convolution layer""" return getattr(self, self.norm1_name) @property def norm2(self): """nn.Module: normalization layer after the second convolution layer""" return getattr(self, self.norm2_name) def forward(self, x): """Forward function.""" def _inner_forward(x): identity = x out = self.conv1(x) out = self.norm1(out) out = self.relu(out) out = self.conv2(out) out = self.norm2(out) if self.downsample is not None: identity = self.downsample(x) out += identity return out if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) out = self.relu(out) return out class Bottleneck(nn.Module): """Bottleneck block for ResNet. If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is "caffe", the stride-two layer is the first 1x1 conv layer. 
""" expansion = 4 def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, style='pytorch', with_cp=False, conv_cfg=None, norm_cfg=dict(type='BN'), dcn=None, plugins=None): super(Bottleneck, self).__init__() assert style in ['pytorch', 'caffe'] assert dcn is None or isinstance(dcn, dict) assert plugins is None or isinstance(plugins, list) if plugins is not None: allowed_position = ['after_conv1', 'after_conv2', 'after_conv3'] assert all(p['position'] in allowed_position for p in plugins) self.inplanes = inplanes self.planes = planes self.stride = stride self.dilation = dilation self.style = style self.with_cp = with_cp self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.dcn = dcn self.with_dcn = dcn is not None self.plugins = plugins self.with_plugins = plugins is not None if self.with_plugins: # collect plugins for conv1/conv2/conv3 self.after_conv1_plugins = [ plugin['cfg'] for plugin in plugins if plugin['position'] == 'after_conv1' ] self.after_conv2_plugins = [ plugin['cfg'] for plugin in plugins if plugin['position'] == 'after_conv2' ] self.after_conv3_plugins = [ plugin['cfg'] for plugin in plugins if plugin['position'] == 'after_conv3' ] if self.style == 'pytorch': self.conv1_stride = 1 self.conv2_stride = stride else: self.conv1_stride = stride self.conv2_stride = 1 self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) self.norm3_name, norm3 = build_norm_layer( norm_cfg, planes * self.expansion, postfix=3) self.conv1 = build_conv_layer( conv_cfg, inplanes, planes, kernel_size=1, stride=self.conv1_stride, bias=False) self.add_module(self.norm1_name, norm1) fallback_on_stride = False if self.with_dcn: fallback_on_stride = dcn.pop('fallback_on_stride', False) if not self.with_dcn or fallback_on_stride: self.conv2 = build_conv_layer( conv_cfg, planes, planes, kernel_size=3, stride=self.conv2_stride, padding=dilation, dilation=dilation, bias=False) else: assert self.conv_cfg is None, 'conv_cfg must be None for DCN' self.conv2 = build_conv_layer( dcn, planes, planes, kernel_size=3, stride=self.conv2_stride, padding=dilation, dilation=dilation, bias=False) self.add_module(self.norm2_name, norm2) self.conv3 = build_conv_layer( conv_cfg, planes, planes * self.expansion, kernel_size=1, bias=False) self.add_module(self.norm3_name, norm3) self.relu = nn.ReLU(inplace=True) self.downsample = downsample if self.with_plugins: self.after_conv1_plugin_names = self.make_block_plugins( planes, self.after_conv1_plugins) self.after_conv2_plugin_names = self.make_block_plugins( planes, self.after_conv2_plugins) self.after_conv3_plugin_names = self.make_block_plugins( planes * self.expansion, self.after_conv3_plugins) def make_block_plugins(self, in_channels, plugins): """make plugins for block. Args: in_channels (int): Input channels of plugin. plugins (list[dict]): List of plugins cfg to build. Returns: list[str]: List of the names of plugin. 
""" assert isinstance(plugins, list) plugin_names = [] for plugin in plugins: plugin = plugin.copy() name, layer = build_plugin_layer( plugin, in_channels=in_channels, postfix=plugin.pop('postfix', '')) assert not hasattr(self, name), f'duplicate plugin {name}' self.add_module(name, layer) plugin_names.append(name) return plugin_names def forward_plugin(self, x, plugin_names): """Forward function for plugins.""" out = x for name in plugin_names: out = getattr(self, name)(x) return out @property def norm1(self): """nn.Module: normalization layer after the first convolution layer""" return getattr(self, self.norm1_name) @property def norm2(self): """nn.Module: normalization layer after the second convolution layer""" return getattr(self, self.norm2_name) @property def norm3(self): """nn.Module: normalization layer after the third convolution layer""" return getattr(self, self.norm3_name) def forward(self, x): """Forward function.""" def _inner_forward(x): identity = x out = self.conv1(x) out = self.norm1(out) out = self.relu(out) if self.with_plugins: out = self.forward_plugin(out, self.after_conv1_plugin_names) out = self.conv2(out) out = self.norm2(out) out = self.relu(out) if self.with_plugins: out = self.forward_plugin(out, self.after_conv2_plugin_names) out = self.conv3(out) out = self.norm3(out) if self.with_plugins: out = self.forward_plugin(out, self.after_conv3_plugin_names) if self.downsample is not None: identity = self.downsample(x) out += identity return out if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) out = self.relu(out) return out @BACKBONES.register_module() class ResNet(nn.Module): """ResNet backbone. Args: depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. in_channels (int): Number of input image channels. Default" 3. stem_channels (int): Number of stem channels. Default: 64. base_channels (int): Number of base channels of res layer. Default: 64. num_stages (int): Resnet stages, normally 4. strides (Sequence[int]): Strides of the first block of each stage. dilations (Sequence[int]): Dilation of each stage. out_indices (Sequence[int]): Output from which stages. style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two layer is the 3x3 conv layer, otherwise the stride-two layer is the first 1x1 conv layer. deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv avg_down (bool): Use AvgPool instead of stride conv when downsampling in the bottleneck. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. norm_cfg (dict): Dictionary to construct and config norm layer. norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. plugins (list[dict]): List of plugins for stages, each dict contains: - cfg (dict, required): Cfg dict to build plugin. - position (str, required): Position inside block to insert plugin, options: 'after_conv1', 'after_conv2', 'after_conv3'. - stages (tuple[bool], optional): Stages to apply plugin, length should be same as 'num_stages' multi_grid (Sequence[int]|None): Multi grid dilation rates of last stage. Default: None contract_dilation (bool): Whether contract first dilation of each layer Default: False with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. 
zero_init_residual (bool): Whether to use zero init for last norm layer in resblocks to let them behave as identity. Example: >>> from mmseg.models import ResNet >>> import torch >>> self = ResNet(depth=18) >>> self.eval() >>> inputs = torch.rand(1, 3, 32, 32) >>> level_outputs = self.forward(inputs) >>> for level_out in level_outputs: ... print(tuple(level_out.shape)) (1, 64, 8, 8) (1, 128, 4, 4) (1, 256, 2, 2) (1, 512, 1, 1) """ arch_settings = { 18: (BasicBlock, (2, 2, 2, 2)), 34: (BasicBlock, (3, 4, 6, 3)), 50: (Bottleneck, (3, 4, 6, 3)), 101: (Bottleneck, (3, 4, 23, 3)), 152: (Bottleneck, (3, 8, 36, 3)) } def __init__(self, depth, in_channels=3, stem_channels=64, base_channels=64, num_stages=4, strides=(1, 2, 2, 2), dilations=(1, 1, 1, 1), out_indices=(0, 1, 2, 3), style='pytorch', deep_stem=False, avg_down=False, frozen_stages=-1, conv_cfg=None, norm_cfg=dict(type='BN', requires_grad=True), norm_eval=False, dcn=None, stage_with_dcn=(False, False, False, False), plugins=None, multi_grid=None, contract_dilation=False, with_cp=False, zero_init_residual=True): super(ResNet, self).__init__() if depth not in self.arch_settings: raise KeyError(f'invalid depth {depth} for resnet') self.depth = depth self.stem_channels = stem_channels self.base_channels = base_channels self.num_stages = num_stages assert num_stages >= 1 and num_stages <= 4 self.strides = strides self.dilations = dilations assert len(strides) == len(dilations) == num_stages self.out_indices = out_indices assert max(out_indices) < num_stages self.style = style self.deep_stem = deep_stem self.avg_down = avg_down self.frozen_stages = frozen_stages self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.with_cp = with_cp self.norm_eval = norm_eval self.dcn = dcn self.stage_with_dcn = stage_with_dcn if dcn is not None: assert len(stage_with_dcn) == num_stages self.plugins = plugins self.multi_grid = multi_grid self.contract_dilation = contract_dilation self.zero_init_residual = zero_init_residual self.block, stage_blocks = self.arch_settings[depth] self.stage_blocks = stage_blocks[:num_stages] self.inplanes = stem_channels self._make_stem_layer(in_channels, stem_channels) self.res_layers = [] for i, num_blocks in enumerate(self.stage_blocks): stride = strides[i] dilation = dilations[i] dcn = self.dcn if self.stage_with_dcn[i] else None if plugins is not None: stage_plugins = self.make_stage_plugins(plugins, i) else: stage_plugins = None # multi grid is applied to last layer only stage_multi_grid = multi_grid if i == len( self.stage_blocks) - 1 else None planes = base_channels * 2**i res_layer = self.make_res_layer( block=self.block, inplanes=self.inplanes, planes=planes, num_blocks=num_blocks, stride=stride, dilation=dilation, style=self.style, avg_down=self.avg_down, with_cp=with_cp, conv_cfg=conv_cfg, norm_cfg=norm_cfg, dcn=dcn, plugins=stage_plugins, multi_grid=stage_multi_grid, contract_dilation=contract_dilation) self.inplanes = planes * self.block.expansion layer_name = f'layer{i+1}' self.add_module(layer_name, res_layer) self.res_layers.append(layer_name) self._freeze_stages() self.feat_dim = self.block.expansion * base_channels * 2**( len(self.stage_blocks) - 1) def make_stage_plugins(self, plugins, stage_idx): """make plugins for ResNet 'stage_idx'th stage . Currently we support to insert 'context_block', 'empirical_attention_block', 'nonlocal_block' into the backbone like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of Bottleneck. An example of plugins format could be : >>> plugins=[ ... 
dict(cfg=dict(type='xxx', arg1='xxx'), ... stages=(False, True, True, True), ... position='after_conv2'), ... dict(cfg=dict(type='yyy'), ... stages=(True, True, True, True), ... position='after_conv3'), ... dict(cfg=dict(type='zzz', postfix='1'), ... stages=(True, True, True, True), ... position='after_conv3'), ... dict(cfg=dict(type='zzz', postfix='2'), ... stages=(True, True, True, True), ... position='after_conv3') ... ] >>> self = ResNet(depth=18) >>> stage_plugins = self.make_stage_plugins(plugins, 0) >>> assert len(stage_plugins) == 3 Suppose 'stage_idx=0', the structure of blocks in the stage would be: conv1-> conv2->conv3->yyy->zzz1->zzz2 Suppose 'stage_idx=1', the structure of blocks in the stage would be: conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 If stages is missing, the plugin would be applied to all stages. Args: plugins (list[dict]): List of plugins cfg to build. The postfix is required if multiple same type plugins are inserted. stage_idx (int): Index of stage to build Returns: list[dict]: Plugins for current stage """ stage_plugins = [] for plugin in plugins: plugin = plugin.copy() stages = plugin.pop('stages', None) assert stages is None or len(stages) == self.num_stages # whether to insert plugin into current stage if stages is None or stages[stage_idx]: stage_plugins.append(plugin) return stage_plugins def make_res_layer(self, **kwargs): """Pack all blocks in a stage into a ``ResLayer``.""" return ResLayer(**kwargs) @property def norm1(self): """nn.Module: the normalization layer named "norm1" """ return getattr(self, self.norm1_name) def _make_stem_layer(self, in_channels, stem_channels): """Make stem layer for ResNet.""" if self.deep_stem: self.stem = nn.Sequential( build_conv_layer( self.conv_cfg, in_channels, stem_channels // 2, kernel_size=3, stride=2, padding=1, bias=False), build_norm_layer(self.norm_cfg, stem_channels // 2)[1], nn.ReLU(inplace=True), build_conv_layer( self.conv_cfg, stem_channels // 2, stem_channels // 2, kernel_size=3, stride=1, padding=1, bias=False), build_norm_layer(self.norm_cfg, stem_channels // 2)[1], nn.ReLU(inplace=True), build_conv_layer( self.conv_cfg, stem_channels // 2, stem_channels, kernel_size=3, stride=1, padding=1, bias=False), build_norm_layer(self.norm_cfg, stem_channels)[1], nn.ReLU(inplace=True)) else: self.conv1 = build_conv_layer( self.conv_cfg, in_channels, stem_channels, kernel_size=7, stride=2, padding=3, bias=False) self.norm1_name, norm1 = build_norm_layer( self.norm_cfg, stem_channels, postfix=1) self.add_module(self.norm1_name, norm1) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) def _freeze_stages(self): """Freeze stages param and norm stats.""" if self.frozen_stages >= 0: if self.deep_stem: self.stem.eval() for param in self.stem.parameters(): param.requires_grad = False else: self.norm1.eval() for m in [self.conv1, self.norm1]: for param in m.parameters(): param.requires_grad = False for i in range(1, self.frozen_stages + 1): m = getattr(self, f'layer{i}') m.eval() for param in m.parameters(): param.requires_grad = False def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. 
""" if isinstance(pretrained, str): logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: for m in self.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m) elif isinstance(m, (_BatchNorm, nn.GroupNorm)): constant_init(m, 1) if self.dcn is not None: for m in self.modules(): if isinstance(m, Bottleneck) and hasattr( m, 'conv2_offset'): constant_init(m.conv2_offset, 0) if self.zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): constant_init(m.norm3, 0) elif isinstance(m, BasicBlock): constant_init(m.norm2, 0) else: raise TypeError('pretrained must be a str or None') def forward(self, x): """Forward function.""" if self.deep_stem: x = self.stem(x) else: x = self.conv1(x) x = self.norm1(x) x = self.relu(x) x = self.maxpool(x) outs = [] for i, layer_name in enumerate(self.res_layers): res_layer = getattr(self, layer_name) x = res_layer(x) if i in self.out_indices: outs.append(x) return tuple(outs) def train(self, mode=True): """Convert the model into training mode while keep normalization layer freezed.""" super(ResNet, self).train(mode) self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): # trick: eval have effect on BatchNorm only if isinstance(m, _BatchNorm): m.eval() @BACKBONES.register_module() class ResNetV1c(ResNet): """ResNetV1c variant described in [1]_. Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv in the input stem with three 3x3 convs. References: .. [1] https://arxiv.org/pdf/1812.01187.pdf """ def __init__(self, **kwargs): super(ResNetV1c, self).__init__( deep_stem=True, avg_down=False, **kwargs) @BACKBONES.register_module() class ResNetV1d(ResNet): """ResNetV1d variant described in [1]_. Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in the input stem with three 3x3 convs. And in the downsampling block, a 2x2 avg_pool with stride 2 is added before conv, whose stride is changed to 1. """ def __init__(self, **kwargs): super(ResNetV1d, self).__init__( deep_stem=True, avg_down=True, **kwargs) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/resnext.py ================================================ import math from mmcv.cnn import build_conv_layer, build_norm_layer from ..builder import BACKBONES from ..utils import ResLayer from .resnet import Bottleneck as _Bottleneck from .resnet import ResNet class Bottleneck(_Bottleneck): """Bottleneck block for ResNeXt. If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is "caffe", the stride-two layer is the first 1x1 conv layer. 
""" def __init__(self, inplanes, planes, groups=1, base_width=4, base_channels=64, **kwargs): super(Bottleneck, self).__init__(inplanes, planes, **kwargs) if groups == 1: width = self.planes else: width = math.floor(self.planes * (base_width / base_channels)) * groups self.norm1_name, norm1 = build_norm_layer( self.norm_cfg, width, postfix=1) self.norm2_name, norm2 = build_norm_layer( self.norm_cfg, width, postfix=2) self.norm3_name, norm3 = build_norm_layer( self.norm_cfg, self.planes * self.expansion, postfix=3) self.conv1 = build_conv_layer( self.conv_cfg, self.inplanes, width, kernel_size=1, stride=self.conv1_stride, bias=False) self.add_module(self.norm1_name, norm1) fallback_on_stride = False self.with_modulated_dcn = False if self.with_dcn: fallback_on_stride = self.dcn.pop('fallback_on_stride', False) if not self.with_dcn or fallback_on_stride: self.conv2 = build_conv_layer( self.conv_cfg, width, width, kernel_size=3, stride=self.conv2_stride, padding=self.dilation, dilation=self.dilation, groups=groups, bias=False) else: assert self.conv_cfg is None, 'conv_cfg must be None for DCN' self.conv2 = build_conv_layer( self.dcn, width, width, kernel_size=3, stride=self.conv2_stride, padding=self.dilation, dilation=self.dilation, groups=groups, bias=False) self.add_module(self.norm2_name, norm2) self.conv3 = build_conv_layer( self.conv_cfg, width, self.planes * self.expansion, kernel_size=1, bias=False) self.add_module(self.norm3_name, norm3) @BACKBONES.register_module() class ResNeXt(ResNet): """ResNeXt backbone. Args: depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. in_channels (int): Number of input image channels. Normally 3. num_stages (int): Resnet stages, normally 4. groups (int): Group of resnext. base_width (int): Base width of resnext. strides (Sequence[int]): Strides of the first block of each stage. dilations (Sequence[int]): Dilation of each stage. out_indices (Sequence[int]): Output from which stages. style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two layer is the 3x3 conv layer, otherwise the stride-two layer is the first 1x1 conv layer. frozen_stages (int): Stages to be frozen (all param fixed). -1 means not freezing any parameters. norm_cfg (dict): dictionary to construct and config norm layer. norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. zero_init_residual (bool): whether to use zero init for last norm layer in resblocks to let them behave as identity. Example: >>> from mmseg.models import ResNeXt >>> import torch >>> self = ResNeXt(depth=50) >>> self.eval() >>> inputs = torch.rand(1, 3, 32, 32) >>> level_outputs = self.forward(inputs) >>> for level_out in level_outputs: ... 
print(tuple(level_out.shape)) (1, 256, 8, 8) (1, 512, 4, 4) (1, 1024, 2, 2) (1, 2048, 1, 1) """ arch_settings = { 50: (Bottleneck, (3, 4, 6, 3)), 101: (Bottleneck, (3, 4, 23, 3)), 152: (Bottleneck, (3, 8, 36, 3)) } def __init__(self, groups=1, base_width=4, **kwargs): self.groups = groups self.base_width = base_width super(ResNeXt, self).__init__(**kwargs) def make_res_layer(self, **kwargs): """Pack all blocks in a stage into a ``ResLayer``""" return ResLayer( groups=self.groups, base_width=self.base_width, base_channels=self.base_channels, **kwargs) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/backbones/unet.py ================================================ import torch.nn as nn import torch.utils.checkpoint as cp from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, build_norm_layer, constant_init, kaiming_init) from mmcv.runner import load_checkpoint from mmcv.utils.parrots_wrapper import _BatchNorm from mmseg.utils import get_root_logger from ..builder import BACKBONES from ..utils import UpConvBlock class BasicConvBlock(nn.Module): """Basic convolutional block for UNet. This module consists of several plain convolutional layers. Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels. num_convs (int): Number of convolutional layers. Default: 2. stride (int): Whether use stride convolution to downsample the input feature map. If stride=2, it only uses stride convolution in the first convolutional layer to downsample the input feature map. Options are 1 or 2. Default: 1. dilation (int): Whether use dilated convolution to expand the receptive field. Set dilation rate of each convolutional layer and the dilation rate of the first convolutional layer is always 1. Default: 1. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. conv_cfg (dict | None): Config dict for convolution layer. Default: None. norm_cfg (dict | None): Config dict for normalization layer. Default: dict(type='BN'). act_cfg (dict | None): Config dict for activation layer in ConvModule. Default: dict(type='ReLU'). dcn (bool): Use deformable convoluton in convolutional layer or not. Default: None. plugins (dict): plugins for convolutional layers. Default: None. """ def __init__(self, in_channels, out_channels, num_convs=2, stride=1, dilation=1, with_cp=False, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), dcn=None, plugins=None): super(BasicConvBlock, self).__init__() assert dcn is None, 'Not implemented yet.' assert plugins is None, 'Not implemented yet.' self.with_cp = with_cp convs = [] for i in range(num_convs): convs.append( ConvModule( in_channels=in_channels if i == 0 else out_channels, out_channels=out_channels, kernel_size=3, stride=stride if i == 0 else 1, dilation=1 if i == 0 else dilation, padding=1 if i == 0 else dilation, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)) self.convs = nn.Sequential(*convs) def forward(self, x): """Forward function.""" if self.with_cp and x.requires_grad: out = cp.checkpoint(self.convs, x) else: out = self.convs(x) return out @UPSAMPLE_LAYERS.register_module() class DeconvModule(nn.Module): """Deconvolution upsample module in decoder for UNet (2X upsample). This module uses deconvolution to upsample feature map in the decoder of UNet. Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels. 
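The grouped Bottleneck above widens its 3x3 conv to `floor(planes * base_width / base_channels) * groups`. A quick standalone check of that formula for the common ResNeXt-50 (32x4d) setting:

import math

def resnext_width(planes, groups=32, base_width=4, base_channels=64):
    # width of the grouped 3x3 conv inside the ResNeXt Bottleneck
    return math.floor(planes * (base_width / base_channels)) * groups

# stage planes 64/128/256/512 give grouped widths 128/256/512/1024
assert [resnext_width(p) for p in (64, 128, 256, 512)] == [128, 256, 512, 1024]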
with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. norm_cfg (dict | None): Config dict for normalization layer. Default: dict(type='BN'). act_cfg (dict | None): Config dict for activation layer in ConvModule. Default: dict(type='ReLU'). kernel_size (int): Kernel size of the convolutional layer. Default: 4. """ def __init__(self, in_channels, out_channels, with_cp=False, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), *, kernel_size=4, scale_factor=2): super(DeconvModule, self).__init__() assert (kernel_size - scale_factor >= 0) and\ (kernel_size - scale_factor) % 2 == 0,\ f'kernel_size should be greater than or equal to scale_factor '\ f'and (kernel_size - scale_factor) should be even numbers, '\ f'while the kernel size is {kernel_size} and scale_factor is '\ f'{scale_factor}.' stride = scale_factor padding = (kernel_size - scale_factor) // 2 self.with_cp = with_cp deconv = nn.ConvTranspose2d( in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) norm_name, norm = build_norm_layer(norm_cfg, out_channels) activate = build_activation_layer(act_cfg) self.deconv_upsamping = nn.Sequential(deconv, norm, activate) def forward(self, x): """Forward function.""" if self.with_cp and x.requires_grad: out = cp.checkpoint(self.deconv_upsamping, x) else: out = self.deconv_upsamping(x) return out @UPSAMPLE_LAYERS.register_module() class InterpConv(nn.Module): """Interpolation upsample module in decoder for UNet. This module uses interpolation to upsample feature map in the decoder of UNet. It consists of one interpolation upsample layer and one convolutional layer. It can be one interpolation upsample layer followed by one convolutional layer (conv_first=False) or one convolutional layer followed by one interpolation upsample layer (conv_first=True). Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. norm_cfg (dict | None): Config dict for normalization layer. Default: dict(type='BN'). act_cfg (dict | None): Config dict for activation layer in ConvModule. Default: dict(type='ReLU'). conv_cfg (dict | None): Config dict for convolution layer. Default: None. conv_first (bool): Whether convolutional layer or interpolation upsample layer first. Default: False. It means interpolation upsample layer followed by one convolutional layer. kernel_size (int): Kernel size of the convolutional layer. Default: 1. stride (int): Stride of the convolutional layer. Default: 1. padding (int): Padding of the convolutional layer. Default: 1. upsampe_cfg (dict): Interpolation config of the upsample layer. Default: dict( scale_factor=2, mode='bilinear', align_corners=False). 
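The kernel/stride constraint asserted in `DeconvModule` guarantees an exact upsample: with `stride = scale_factor` and `padding = (kernel_size - scale_factor) // 2`, the transposed-conv output size `(H - 1) * stride - 2 * padding + kernel_size` reduces to `H * scale_factor`. A standalone check for the defaults:

import torch
import torch.nn as nn

deconv = nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1)
x = torch.rand(1, 16, 32, 32)
# (32 - 1) * 2 - 2 * 1 + 4 == 64: exactly doubled
assert deconv(x).shape == (1, 8, 64, 64)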
""" def __init__(self, in_channels, out_channels, with_cp=False, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), *, conv_cfg=None, conv_first=False, kernel_size=1, stride=1, padding=0, upsampe_cfg=dict( scale_factor=2, mode='bilinear', align_corners=False)): super(InterpConv, self).__init__() self.with_cp = with_cp conv = ConvModule( in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) upsample = nn.Upsample(**upsampe_cfg) if conv_first: self.interp_upsample = nn.Sequential(conv, upsample) else: self.interp_upsample = nn.Sequential(upsample, conv) def forward(self, x): """Forward function.""" if self.with_cp and x.requires_grad: out = cp.checkpoint(self.interp_upsample, x) else: out = self.interp_upsample(x) return out @BACKBONES.register_module() class UNet(nn.Module): """UNet backbone. U-Net: Convolutional Networks for Biomedical Image Segmentation. https://arxiv.org/pdf/1505.04597.pdf Args: in_channels (int): Number of input image channels. Default" 3. base_channels (int): Number of base channels of each stage. The output channels of the first stage. Default: 64. num_stages (int): Number of stages in encoder, normally 5. Default: 5. strides (Sequence[int 1 | 2]): Strides of each stage in encoder. len(strides) is equal to num_stages. Normally the stride of the first stage in encoder is 1. If strides[i]=2, it uses stride convolution to downsample in the correspondance encoder stage. Default: (1, 1, 1, 1, 1). enc_num_convs (Sequence[int]): Number of convolutional layers in the convolution block of the correspondance encoder stage. Default: (2, 2, 2, 2, 2). dec_num_convs (Sequence[int]): Number of convolutional layers in the convolution block of the correspondance decoder stage. Default: (2, 2, 2, 2). downsamples (Sequence[int]): Whether use MaxPool to downsample the feature map after the first stage of encoder (stages: [1, num_stages)). If the correspondance encoder stage use stride convolution (strides[i]=2), it will never use MaxPool to downsample, even downsamples[i-1]=True. Default: (True, True, True, True). enc_dilations (Sequence[int]): Dilation rate of each stage in encoder. Default: (1, 1, 1, 1, 1). dec_dilations (Sequence[int]): Dilation rate of each stage in decoder. Default: (1, 1, 1, 1). with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. conv_cfg (dict | None): Config dict for convolution layer. Default: None. norm_cfg (dict | None): Config dict for normalization layer. Default: dict(type='BN'). act_cfg (dict | None): Config dict for activation layer in ConvModule. Default: dict(type='ReLU'). upsample_cfg (dict): The upsample config of the upsample module in decoder. Default: dict(type='InterpConv'). norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Default: False. dcn (bool): Use deformable convoluton in convolutional layer or not. Default: None. plugins (dict): plugins for convolutional layers. Default: None. Notice: The input image size should be devisible by the whole downsample rate of the encoder. More detail of the whole downsample rate can be found in UNet._check_input_devisible. 
""" def __init__(self, in_channels=3, base_channels=64, num_stages=5, strides=(1, 1, 1, 1, 1), enc_num_convs=(2, 2, 2, 2, 2), dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), dec_dilations=(1, 1, 1, 1), with_cp=False, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), upsample_cfg=dict(type='InterpConv'), norm_eval=False, dcn=None, plugins=None): super(UNet, self).__init__() assert dcn is None, 'Not implemented yet.' assert plugins is None, 'Not implemented yet.' assert len(strides) == num_stages, \ 'The length of strides should be equal to num_stages, '\ f'while the strides is {strides}, the length of '\ f'strides is {len(strides)}, and the num_stages is '\ f'{num_stages}.' assert len(enc_num_convs) == num_stages, \ 'The length of enc_num_convs should be equal to num_stages, '\ f'while the enc_num_convs is {enc_num_convs}, the length of '\ f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\ f'{num_stages}.' assert len(dec_num_convs) == (num_stages-1), \ 'The length of dec_num_convs should be equal to (num_stages-1), '\ f'while the dec_num_convs is {dec_num_convs}, the length of '\ f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\ f'{num_stages}.' assert len(downsamples) == (num_stages-1), \ 'The length of downsamples should be equal to (num_stages-1), '\ f'while the downsamples is {downsamples}, the length of '\ f'downsamples is {len(downsamples)}, and the num_stages is '\ f'{num_stages}.' assert len(enc_dilations) == num_stages, \ 'The length of enc_dilations should be equal to num_stages, '\ f'while the enc_dilations is {enc_dilations}, the length of '\ f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\ f'{num_stages}.' assert len(dec_dilations) == (num_stages-1), \ 'The length of dec_dilations should be equal to (num_stages-1), '\ f'while the dec_dilations is {dec_dilations}, the length of '\ f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\ f'{num_stages}.' 
self.num_stages = num_stages self.strides = strides self.downsamples = downsamples self.norm_eval = norm_eval self.encoder = nn.ModuleList() self.decoder = nn.ModuleList() for i in range(num_stages): enc_conv_block = [] if i != 0: if strides[i] == 1 and downsamples[i - 1]: enc_conv_block.append(nn.MaxPool2d(kernel_size=2)) upsample = (strides[i] != 1 or downsamples[i - 1]) self.decoder.append( UpConvBlock( conv_block=BasicConvBlock, in_channels=base_channels * 2**i, skip_channels=base_channels * 2**(i - 1), out_channels=base_channels * 2**(i - 1), num_convs=dec_num_convs[i - 1], stride=1, dilation=dec_dilations[i - 1], with_cp=with_cp, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, upsample_cfg=upsample_cfg if upsample else None, dcn=None, plugins=None)) enc_conv_block.append( BasicConvBlock( in_channels=in_channels, out_channels=base_channels * 2**i, num_convs=enc_num_convs[i], stride=strides[i], dilation=enc_dilations[i], with_cp=with_cp, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, dcn=None, plugins=None)) self.encoder.append((nn.Sequential(*enc_conv_block))) in_channels = base_channels * 2**i def forward(self, x): self._check_input_devisible(x) enc_outs = [] for enc in self.encoder: x = enc(x) enc_outs.append(x) dec_outs = [x] for i in reversed(range(len(self.decoder))): x = self.decoder[i](enc_outs[i], x) dec_outs.append(x) return dec_outs def train(self, mode=True): """Convert the model into training mode while keep normalization layer freezed.""" super(UNet, self).train(mode) if mode and self.norm_eval: for m in self.modules(): # trick: eval have effect on BatchNorm only if isinstance(m, _BatchNorm): m.eval() def _check_input_devisible(self, x): h, w = x.shape[-2:] whole_downsample_rate = 1 for i in range(1, self.num_stages): if self.strides[i] == 2 or self.downsamples[i - 1]: whole_downsample_rate *= 2 assert (h % whole_downsample_rate == 0) \ and (w % whole_downsample_rate == 0),\ f'The input image size {(h, w)} should be devisible by the whole '\ f'downsample rate {whole_downsample_rate}, when num_stages is '\ f'{self.num_stages}, strides is {self.strides}, and downsamples '\ f'is {self.downsamples}.' def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ if isinstance(pretrained, str): logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: for m in self.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m) elif isinstance(m, (_BatchNorm, nn.GroupNorm)): constant_init(m, 1) else: raise TypeError('pretrained must be a str or None') ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/builder.py ================================================ import warnings from mmcv.utils import Registry, build_from_cfg from torch import nn BACKBONES = Registry('backbone') NECKS = Registry('neck') HEADS = Registry('head') LOSSES = Registry('loss') SEGMENTORS = Registry('segmentor') def build(cfg, registry, default_args=None): """Build a module. Args: cfg (dict, list[dict]): The config of modules, is is either a dict or a list of configs. registry (:obj:`Registry`): A registry the module belongs to. default_args (dict, optional): Default arguments to build the module. Defaults to None. Returns: nn.Module: A built nn module. 
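The registries declared above make model construction declarative: a config dict's `type` key names a registered class and the remaining keys become constructor kwargs. A minimal sketch assuming the mmcv 1.x API this file itself imports (the toy names are illustrative, not from this repo):

from mmcv.utils import Registry, build_from_cfg
from torch import nn

TOY_MODELS = Registry('toy_model')

@TOY_MODELS.register_module()
class ToyNet(nn.Module):
    def __init__(self, width=8):
        super().__init__()
        self.fc = nn.Linear(width, width)

# 'type' selects the class; everything else is passed to __init__
model = build_from_cfg(dict(type='ToyNet', width=4), TOY_MODELS)
assert isinstance(model, ToyNet)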
""" if isinstance(cfg, list): modules = [ build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg ] return nn.Sequential(*modules) else: return build_from_cfg(cfg, registry, default_args) def build_backbone(cfg): """Build backbone.""" return build(cfg, BACKBONES) def build_neck(cfg): """Build neck.""" return build(cfg, NECKS) def build_head(cfg): """Build head.""" return build(cfg, HEADS) def build_loss(cfg): """Build loss.""" return build(cfg, LOSSES) def build_segmentor(cfg, train_cfg=None, test_cfg=None): """Build segmentor.""" if train_cfg is not None or test_cfg is not None: warnings.warn( 'train_cfg and test_cfg is deprecated, ' 'please specify them in model', UserWarning) assert cfg.get('train_cfg') is None or train_cfg is None, \ 'train_cfg specified in both outer field and model field ' assert cfg.get('test_cfg') is None or test_cfg is None, \ 'test_cfg specified in both outer field and model field ' return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/__init__.py ================================================ from .ann_head import ANNHead from .apc_head import APCHead from .aspp_head import ASPPHead from .cc_head import CCHead from .da_head import DAHead from .dm_head import DMHead from .dnl_head import DNLHead from .ema_head import EMAHead from .enc_head import EncHead from .fcn_head import FCNHead from .fpn_head import FPNHead from .gc_head import GCHead from .lraspp_head import LRASPPHead from .nl_head import NLHead from .ocr_head import OCRHead from .point_head import PointHead from .psa_head import PSAHead from .psp_head import PSPHead from .sep_aspp_head import DepthwiseSeparableASPPHead from .sep_fcn_head import DepthwiseSeparableFCNHead from .uper_head import UPerHead __all__ = [ 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead' ] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/ann_head.py ================================================ import torch import torch.nn as nn from mmcv.cnn import ConvModule from ..builder import HEADS from ..utils import SelfAttentionBlock as _SelfAttentionBlock from .decode_head import BaseDecodeHead class PPMConcat(nn.ModuleList): """Pyramid Pooling Module that only concat the features of each layer. Args: pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid Module. """ def __init__(self, pool_scales=(1, 3, 6, 8)): super(PPMConcat, self).__init__( [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales]) def forward(self, feats): """Forward function.""" ppm_outs = [] for ppm in self: ppm_out = ppm(feats) ppm_outs.append(ppm_out.view(*feats.shape[:2], -1)) concat_outs = torch.cat(ppm_outs, dim=2) return concat_outs class SelfAttentionBlock(_SelfAttentionBlock): """Make a ANN used SelfAttentionBlock. Args: low_in_channels (int): Input channels of lower level feature, which is the key feature for self-attention. high_in_channels (int): Input channels of higher level feature, which is the query feature for self-attention. channels (int): Output channels of key/query transform. out_channels (int): Output channels. 
share_key_query (bool): Whether share projection weight between key and query projection. query_scale (int): The scale of query feature map. key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid Module of key feature. conv_cfg (dict|None): Config of conv layers. norm_cfg (dict|None): Config of norm layers. act_cfg (dict|None): Config of activation layers. """ def __init__(self, low_in_channels, high_in_channels, channels, out_channels, share_key_query, query_scale, key_pool_scales, conv_cfg, norm_cfg, act_cfg): key_psp = PPMConcat(key_pool_scales) if query_scale > 1: query_downsample = nn.MaxPool2d(kernel_size=query_scale) else: query_downsample = None super(SelfAttentionBlock, self).__init__( key_in_channels=low_in_channels, query_in_channels=high_in_channels, channels=channels, out_channels=out_channels, share_key_query=share_key_query, query_downsample=query_downsample, key_downsample=key_psp, key_query_num_convs=1, key_query_norm=True, value_out_num_convs=1, value_out_norm=False, matmul_norm=True, with_out=True, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) class AFNB(nn.Module): """Asymmetric Fusion Non-local Block(AFNB) Args: low_in_channels (int): Input channels of lower level feature, which is the key feature for self-attention. high_in_channels (int): Input channels of higher level feature, which is the query feature for self-attention. channels (int): Output channels of key/query transform. out_channels (int): Output channels. and query projection. query_scales (tuple[int]): The scales of query feature map. Default: (1,) key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid Module of key feature. conv_cfg (dict|None): Config of conv layers. norm_cfg (dict|None): Config of norm layers. act_cfg (dict|None): Config of activation layers. """ def __init__(self, low_in_channels, high_in_channels, channels, out_channels, query_scales, key_pool_scales, conv_cfg, norm_cfg, act_cfg): super(AFNB, self).__init__() self.stages = nn.ModuleList() for query_scale in query_scales: self.stages.append( SelfAttentionBlock( low_in_channels=low_in_channels, high_in_channels=high_in_channels, channels=channels, out_channels=out_channels, share_key_query=False, query_scale=query_scale, key_pool_scales=key_pool_scales, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)) self.bottleneck = ConvModule( out_channels + high_in_channels, out_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None) def forward(self, low_feats, high_feats): """Forward function.""" priors = [stage(high_feats, low_feats) for stage in self.stages] context = torch.stack(priors, dim=0).sum(dim=0) output = self.bottleneck(torch.cat([context, high_feats], 1)) return output class APNB(nn.Module): """Asymmetric Pyramid Non-local Block (APNB) Args: in_channels (int): Input channels of key/query feature, which is the key feature for self-attention. channels (int): Output channels of key/query transform. out_channels (int): Output channels. query_scales (tuple[int]): The scales of query feature map. Default: (1,) key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid Module of key feature. conv_cfg (dict|None): Config of conv layers. norm_cfg (dict|None): Config of norm layers. act_cfg (dict|None): Config of activation layers. 
""" def __init__(self, in_channels, channels, out_channels, query_scales, key_pool_scales, conv_cfg, norm_cfg, act_cfg): super(APNB, self).__init__() self.stages = nn.ModuleList() for query_scale in query_scales: self.stages.append( SelfAttentionBlock( low_in_channels=in_channels, high_in_channels=in_channels, channels=channels, out_channels=out_channels, share_key_query=True, query_scale=query_scale, key_pool_scales=key_pool_scales, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)) self.bottleneck = ConvModule( 2 * in_channels, out_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) def forward(self, feats): """Forward function.""" priors = [stage(feats, feats) for stage in self.stages] context = torch.stack(priors, dim=0).sum(dim=0) output = self.bottleneck(torch.cat([context, feats], 1)) return output @HEADS.register_module() class ANNHead(BaseDecodeHead): """Asymmetric Non-local Neural Networks for Semantic Segmentation. This head is the implementation of `ANNNet `_. Args: project_channels (int): Projection channels for Nonlocal. query_scales (tuple[int]): The scales of query feature map. Default: (1,) key_pool_scales (tuple[int]): The pooling scales of key feature map. Default: (1, 3, 6, 8). """ def __init__(self, project_channels, query_scales=(1, ), key_pool_scales=(1, 3, 6, 8), **kwargs): super(ANNHead, self).__init__( input_transform='multiple_select', **kwargs) assert len(self.in_channels) == 2 low_in_channels, high_in_channels = self.in_channels self.project_channels = project_channels self.fusion = AFNB( low_in_channels=low_in_channels, high_in_channels=high_in_channels, out_channels=high_in_channels, channels=project_channels, query_scales=query_scales, key_pool_scales=key_pool_scales, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.bottleneck = ConvModule( high_in_channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.context = APNB( in_channels=self.channels, out_channels=self.channels, channels=project_channels, query_scales=query_scales, key_pool_scales=key_pool_scales, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, inputs): """Forward function.""" low_feats, high_feats = self._transform_inputs(inputs) output = self.fusion(low_feats, high_feats) output = self.dropout(output) output = self.bottleneck(output) output = self.context(output) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/apc_head.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule from mmseg.ops import resize from ..builder import HEADS from .decode_head import BaseDecodeHead class ACM(nn.Module): """Adaptive Context Module used in APCNet. Args: pool_scale (int): Pooling scale used in Adaptive Context Module to extract region fetures. fusion (bool): Add one conv to fuse residual feature. in_channels (int): Input channels. channels (int): Channels after modules, before conv_seg. conv_cfg (dict | None): Config of conv layers. norm_cfg (dict | None): Config of norm layers. act_cfg (dict): Config of activation layers. 
""" def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, norm_cfg, act_cfg): super(ACM, self).__init__() self.pool_scale = pool_scale self.fusion = fusion self.in_channels = in_channels self.channels = channels self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg self.pooled_redu_conv = ConvModule( self.in_channels, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.input_redu_conv = ConvModule( self.in_channels, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.global_info = ConvModule( self.channels, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0) self.residual_conv = ConvModule( self.channels, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) if self.fusion: self.fusion_conv = ConvModule( self.channels, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, x): """Forward function.""" pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale) # [batch_size, channels, h, w] x = self.input_redu_conv(x) # [batch_size, channels, pool_scale, pool_scale] pooled_x = self.pooled_redu_conv(pooled_x) batch_size = x.size(0) # [batch_size, pool_scale * pool_scale, channels] pooled_x = pooled_x.view(batch_size, self.channels, -1).permute(0, 2, 1).contiguous() # [batch_size, h * w, pool_scale * pool_scale] affinity_matrix = self.gla(x + resize( self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:]) ).permute(0, 2, 3, 1).reshape( batch_size, -1, self.pool_scale**2) affinity_matrix = F.sigmoid(affinity_matrix) # [batch_size, h * w, channels] z_out = torch.matmul(affinity_matrix, pooled_x) # [batch_size, channels, h * w] z_out = z_out.permute(0, 2, 1).contiguous() # [batch_size, channels, h, w] z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3)) z_out = self.residual_conv(z_out) z_out = F.relu(z_out + x) if self.fusion: z_out = self.fusion_conv(z_out) return z_out @HEADS.register_module() class APCHead(BaseDecodeHead): """Adaptive Pyramid Context Network for Semantic Segmentation. This head is the implementation of `APCNet `_. Args: pool_scales (tuple[int]): Pooling scales used in Adaptive Context Module. Default: (1, 2, 3, 6). fusion (bool): Add one conv to fuse residual feature. 
""" def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs): super(APCHead, self).__init__(**kwargs) assert isinstance(pool_scales, (list, tuple)) self.pool_scales = pool_scales self.fusion = fusion acm_modules = [] for pool_scale in self.pool_scales: acm_modules.append( ACM(pool_scale, self.fusion, self.in_channels, self.channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) self.acm_modules = nn.ModuleList(acm_modules) self.bottleneck = ConvModule( self.in_channels + len(pool_scales) * self.channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) acm_outs = [x] for acm_module in self.acm_modules: acm_outs.append(acm_module(x)) acm_outs = torch.cat(acm_outs, dim=1) output = self.bottleneck(acm_outs) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/aspp_head.py ================================================ import torch import torch.nn as nn from mmcv.cnn import ConvModule from mmseg.ops import resize from ..builder import HEADS from .decode_head import BaseDecodeHead class ASPPModule(nn.ModuleList): """Atrous Spatial Pyramid Pooling (ASPP) Module. Args: dilations (tuple[int]): Dilation rate of each layer. in_channels (int): Input channels. channels (int): Channels after modules, before conv_seg. conv_cfg (dict|None): Config of conv layers. norm_cfg (dict|None): Config of norm layers. act_cfg (dict): Config of activation layers. """ def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, act_cfg): super(ASPPModule, self).__init__() self.dilations = dilations self.in_channels = in_channels self.channels = channels self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg for dilation in dilations: self.append( ConvModule( self.in_channels, self.channels, 1 if dilation == 1 else 3, dilation=dilation, padding=0 if dilation == 1 else dilation, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) def forward(self, x): """Forward function.""" aspp_outs = [] for aspp_module in self: aspp_outs.append(aspp_module(x)) return aspp_outs @HEADS.register_module() class ASPPHead(BaseDecodeHead): """Rethinking Atrous Convolution for Semantic Image Segmentation. This head is the implementation of `DeepLabV3 `_. Args: dilations (tuple[int]): Dilation rates for ASPP module. Default: (1, 6, 12, 18). 
""" def __init__(self, dilations=(1, 6, 12, 18), **kwargs): super(ASPPHead, self).__init__(**kwargs) assert isinstance(dilations, (list, tuple)) self.dilations = dilations self.image_pool = nn.Sequential( nn.AdaptiveAvgPool2d(1), ConvModule( self.in_channels, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) self.aspp_modules = ASPPModule( dilations, self.in_channels, self.channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.bottleneck = ConvModule( (len(dilations) + 1) * self.channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) aspp_outs = [ resize( self.image_pool(x), size=x.size()[2:], mode='bilinear', align_corners=self.align_corners) ] aspp_outs.extend(self.aspp_modules(x)) aspp_outs = torch.cat(aspp_outs, dim=1) output = self.bottleneck(aspp_outs) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/cascade_decode_head.py ================================================ from abc import ABCMeta, abstractmethod from .decode_head import BaseDecodeHead class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta): """Base class for cascade decode head used in :class:`CascadeEncoderDecoder.""" def __init__(self, *args, **kwargs): super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs) @abstractmethod def forward(self, inputs, prev_output): """Placeholder of forward function.""" pass def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, train_cfg): """Forward function for training. Args: inputs (list[Tensor]): List of multi-level img features. prev_output (Tensor): The output of previous decode head. img_metas (list[dict]): List of image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmseg/datasets/pipelines/formatting.py:Collect`. gt_semantic_seg (Tensor): Semantic segmentation masks used if the architecture supports semantic segmentation task. train_cfg (dict): The training config. Returns: dict[str, Tensor]: a dictionary of loss components """ seg_logits = self.forward(inputs, prev_output) losses = self.losses(seg_logits, gt_semantic_seg) return losses def forward_test(self, inputs, prev_output, img_metas, test_cfg): """Forward function for testing. Args: inputs (list[Tensor]): List of multi-level img features. prev_output (Tensor): The output of previous decode head. img_metas (list[dict]): List of image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmseg/datasets/pipelines/formatting.py:Collect`. test_cfg (dict): The testing config. Returns: Tensor: Output segmentation map. 
""" return self.forward(inputs, prev_output) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/cc_head.py ================================================ import torch from ..builder import HEADS from .fcn_head import FCNHead try: from mmcv.ops import CrissCrossAttention except ModuleNotFoundError: CrissCrossAttention = None @HEADS.register_module() class CCHead(FCNHead): """CCNet: Criss-Cross Attention for Semantic Segmentation. This head is the implementation of `CCNet `_. Args: recurrence (int): Number of recurrence of Criss Cross Attention module. Default: 2. """ def __init__(self, recurrence=2, **kwargs): if CrissCrossAttention is None: raise RuntimeError('Please install mmcv-full for ' 'CrissCrossAttention ops') super(CCHead, self).__init__(num_convs=2, **kwargs) self.recurrence = recurrence self.cca = CrissCrossAttention(self.channels) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) output = self.convs[0](x) for _ in range(self.recurrence): output = self.cca(output) output = self.convs[1](output) if self.concat_input: output = self.conv_cat(torch.cat([x, output], dim=1)) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/da_head.py ================================================ import torch import torch.nn.functional as F from mmcv.cnn import ConvModule, Scale from torch import nn from mmseg.core import add_prefix from ..builder import HEADS from ..utils import SelfAttentionBlock as _SelfAttentionBlock from .decode_head import BaseDecodeHead class PAM(_SelfAttentionBlock): """Position Attention Module (PAM) Args: in_channels (int): Input channels of key/query feature. channels (int): Output channels of key/query transform. """ def __init__(self, in_channels, channels): super(PAM, self).__init__( key_in_channels=in_channels, query_in_channels=in_channels, channels=channels, out_channels=in_channels, share_key_query=False, query_downsample=None, key_downsample=None, key_query_num_convs=1, key_query_norm=False, value_out_num_convs=1, value_out_norm=False, matmul_norm=False, with_out=False, conv_cfg=None, norm_cfg=None, act_cfg=None) self.gamma = Scale(0) def forward(self, x): """Forward function.""" out = super(PAM, self).forward(x, x) out = self.gamma(out) + x return out class CAM(nn.Module): """Channel Attention Module (CAM)""" def __init__(self): super(CAM, self).__init__() self.gamma = Scale(0) def forward(self, x): """Forward function.""" batch_size, channels, height, width = x.size() proj_query = x.view(batch_size, channels, -1) proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1) energy = torch.bmm(proj_query, proj_key) energy_new = torch.max( energy, -1, keepdim=True)[0].expand_as(energy) - energy attention = F.softmax(energy_new, dim=-1) proj_value = x.view(batch_size, channels, -1) out = torch.bmm(attention, proj_value) out = out.view(batch_size, channels, height, width) out = self.gamma(out) + x return out @HEADS.register_module() class DAHead(BaseDecodeHead): """Dual Attention Network for Scene Segmentation. This head is the implementation of `DANet `_. Args: pam_channels (int): The channels of Position Attention Module(PAM). 
""" def __init__(self, pam_channels, **kwargs): super(DAHead, self).__init__(**kwargs) self.pam_channels = pam_channels self.pam_in_conv = ConvModule( self.in_channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.pam = PAM(self.channels, pam_channels) self.pam_out_conv = ConvModule( self.channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.pam_conv_seg = nn.Conv2d( self.channels, self.num_classes, kernel_size=1) self.cam_in_conv = ConvModule( self.in_channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.cam = CAM() self.cam_out_conv = ConvModule( self.channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.cam_conv_seg = nn.Conv2d( self.channels, self.num_classes, kernel_size=1) def pam_cls_seg(self, feat): """PAM feature classification.""" if self.dropout is not None: feat = self.dropout(feat) output = self.pam_conv_seg(feat) return output def cam_cls_seg(self, feat): """CAM feature classification.""" if self.dropout is not None: feat = self.dropout(feat) output = self.cam_conv_seg(feat) return output def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) pam_feat = self.pam_in_conv(x) pam_feat = self.pam(pam_feat) pam_feat = self.pam_out_conv(pam_feat) pam_out = self.pam_cls_seg(pam_feat) cam_feat = self.cam_in_conv(x) cam_feat = self.cam(cam_feat) cam_feat = self.cam_out_conv(cam_feat) cam_out = self.cam_cls_seg(cam_feat) feat_sum = pam_feat + cam_feat pam_cam_out = self.cls_seg(feat_sum) return pam_cam_out, pam_out, cam_out def forward_test(self, inputs, img_metas, test_cfg): """Forward function for testing, only ``pam_cam`` is used.""" return self.forward(inputs)[0] def losses(self, seg_logit, seg_label): """Compute ``pam_cam``, ``pam``, ``cam`` loss.""" pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit loss = dict() loss.update( add_prefix( super(DAHead, self).losses(pam_cam_seg_logit, seg_label), 'pam_cam')) loss.update( add_prefix( super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam')) loss.update( add_prefix( super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam')) return loss ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/decode_head.py ================================================ from abc import ABCMeta, abstractmethod import torch import torch.nn as nn from mmcv.cnn import normal_init from mmcv.runner import auto_fp16, force_fp32 from mmseg.core import build_pixel_sampler from mmseg.ops import resize from ..builder import build_loss from ..losses import accuracy class BaseDecodeHead(nn.Module, metaclass=ABCMeta): """Base class for BaseDecodeHead. Args: in_channels (int|Sequence[int]): Input channels. channels (int): Channels after modules, before conv_seg. num_classes (int): Number of classes. dropout_ratio (float): Ratio of dropout layer. Default: 0.1. conv_cfg (dict|None): Config of conv layers. Default: None. norm_cfg (dict|None): Config of norm layers. Default: None. act_cfg (dict): Config of activation layers. Default: dict(type='ReLU') in_index (int|Sequence[int]): Input feature index. Default: -1 input_transform (str|None): Transformation type of input features. Options: 'resize_concat', 'multiple_select', None. 
'resize_concat': Multiple feature maps will be resize to the same size as first one and than concat together. Usually used in FCN head of HRNet. 'multiple_select': Multiple feature maps will be bundle into a list and passed into decode head. None: Only one select feature map is allowed. Default: None. loss_decode (dict): Config of decode loss. Default: dict(type='CrossEntropyLoss'). ignore_index (int | None): The label index to be ignored. When using masked BCE loss, ignore_index should be set to None. Default: 255 sampler (dict|None): The config of segmentation map sampler. Default: None. align_corners (bool): align_corners argument of F.interpolate. Default: False. """ def __init__(self, in_channels, channels, *, num_classes, dropout_ratio=0.1, conv_cfg=None, norm_cfg=None, act_cfg=dict(type='ReLU'), in_index=-1, input_transform=None, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), ignore_index=255, sampler=None, align_corners=False): super(BaseDecodeHead, self).__init__() self._init_inputs(in_channels, in_index, input_transform) self.channels = channels self.num_classes = num_classes self.dropout_ratio = dropout_ratio self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg self.in_index = in_index self.loss_decode = build_loss(loss_decode) self.ignore_index = ignore_index self.align_corners = align_corners if sampler is not None: self.sampler = build_pixel_sampler(sampler, context=self) else: self.sampler = None self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) if dropout_ratio > 0: self.dropout = nn.Dropout2d(dropout_ratio) else: self.dropout = None self.fp16_enabled = False def extra_repr(self): """Extra repr.""" s = f'input_transform={self.input_transform}, ' \ f'ignore_index={self.ignore_index}, ' \ f'align_corners={self.align_corners}' return s def _init_inputs(self, in_channels, in_index, input_transform): """Check and initialize input transforms. The in_channels, in_index and input_transform must match. Specifically, when input_transform is None, only single feature map will be selected. So in_channels and in_index must be of type int. When input_transform Args: in_channels (int|Sequence[int]): Input channels. in_index (int|Sequence[int]): Input feature index. input_transform (str|None): Transformation type of input features. Options: 'resize_concat', 'multiple_select', None. 'resize_concat': Multiple feature maps will be resize to the same size as first one and than concat together. Usually used in FCN head of HRNet. 'multiple_select': Multiple feature maps will be bundle into a list and passed into decode head. None: Only one select feature map is allowed. """ if input_transform is not None: assert input_transform in ['resize_concat', 'multiple_select'] self.input_transform = input_transform self.in_index = in_index if input_transform is not None: assert isinstance(in_channels, (list, tuple)) assert isinstance(in_index, (list, tuple)) assert len(in_channels) == len(in_index) if input_transform == 'resize_concat': self.in_channels = sum(in_channels) else: self.in_channels = in_channels else: assert isinstance(in_channels, int) assert isinstance(in_index, int) self.in_channels = in_channels def init_weights(self): """Initialize weights of classification layer.""" normal_init(self.conv_seg, mean=0, std=0.01) def _transform_inputs(self, inputs): """Transform inputs for decoder. Args: inputs (list[Tensor]): List of multi-level img features. 
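A standalone sketch of the two multi-input modes `_transform_inputs` handles: `resize_concat` upsamples every selected level to the first one's size and concatenates channels, while `multiple_select` simply returns the selected list:

import torch
import torch.nn.functional as F

inputs = [torch.rand(1, 16, 32, 32), torch.rand(1, 32, 16, 16)]
in_index = [0, 1]

selected = [inputs[i] for i in in_index]
resized = [F.interpolate(x, size=selected[0].shape[2:], mode='bilinear',
                         align_corners=False) for x in selected]
concat = torch.cat(resized, dim=1)
assert concat.shape == (1, 48, 32, 32)  # channels summed, spatial of level 0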
Returns: Tensor: The transformed inputs """ if self.input_transform == 'resize_concat': inputs = [inputs[i] for i in self.in_index] upsampled_inputs = [ resize( input=x, size=inputs[0].shape[2:], mode='bilinear', align_corners=self.align_corners) for x in inputs ] inputs = torch.cat(upsampled_inputs, dim=1) elif self.input_transform == 'multiple_select': inputs = [inputs[i] for i in self.in_index] else: inputs = inputs[self.in_index] return inputs @auto_fp16() @abstractmethod def forward(self, inputs): """Placeholder of forward function.""" pass def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): """Forward function for training. Args: inputs (list[Tensor]): List of multi-level img features. img_metas (list[dict]): List of image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmseg/datasets/pipelines/formatting.py:Collect`. gt_semantic_seg (Tensor): Semantic segmentation masks used if the architecture supports semantic segmentation task. train_cfg (dict): The training config. Returns: dict[str, Tensor]: a dictionary of loss components """ seg_logits = self.forward(inputs) losses = self.losses(seg_logits, gt_semantic_seg) return losses def forward_test(self, inputs, img_metas, test_cfg): """Forward function for testing. Args: inputs (list[Tensor]): List of multi-level img features. img_metas (list[dict]): List of image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmseg/datasets/pipelines/formatting.py:Collect`. test_cfg (dict): The testing config. Returns: Tensor: Output segmentation map. """ return self.forward(inputs) def cls_seg(self, feat): """Classify each pixel.""" if self.dropout is not None: feat = self.dropout(feat) output = self.conv_seg(feat) return output @force_fp32(apply_to=('seg_logit', )) def losses(self, seg_logit, seg_label): """Compute segmentation loss.""" loss = dict() seg_logit = resize( input=seg_logit, size=seg_label.shape[2:], mode='bilinear', align_corners=self.align_corners) if self.sampler is not None: seg_weight = self.sampler.sample(seg_logit, seg_label) else: seg_weight = None seg_label = seg_label.squeeze(1) loss['loss_seg'] = self.loss_decode( seg_logit, seg_label, weight=seg_weight, ignore_index=self.ignore_index) loss['acc_seg'] = accuracy(seg_logit, seg_label) return loss ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/dm_head.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer from ..builder import HEADS from .decode_head import BaseDecodeHead class DCM(nn.Module): """Dynamic Convolutional Module used in DMNet. Args: filter_size (int): The filter size of generated convolution kernel used in Dynamic Convolutional Module. fusion (bool): Add one conv to fuse DCM output feature. in_channels (int): Input channels. channels (int): Channels after modules, before conv_seg. conv_cfg (dict | None): Config of conv layers. norm_cfg (dict | None): Config of norm layers. act_cfg (dict): Config of activation layers. 
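Example (an illustrative sketch; the channel sizes and spatial
        shapes below are arbitrary assumptions):
            >>> import torch
            >>> dcm = DCM(filter_size=3, fusion=False, in_channels=64,
            ...           channels=32, conv_cfg=None, norm_cfg=None,
            ...           act_cfg=dict(type='ReLU'))
            >>> out = dcm(torch.rand(2, 64, 16, 16))
            >>> out.shape  # the padding scheme preserves the spatial size
            torch.Size([2, 32, 16, 16])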
""" def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg, norm_cfg, act_cfg): super(DCM, self).__init__() self.filter_size = filter_size self.fusion = fusion self.in_channels = in_channels self.channels = channels self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1, 0) self.input_redu_conv = ConvModule( self.in_channels, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) if self.norm_cfg is not None: self.norm = build_norm_layer(self.norm_cfg, self.channels)[1] else: self.norm = None self.activate = build_activation_layer(self.act_cfg) if self.fusion: self.fusion_conv = ConvModule( self.channels, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, x): """Forward function.""" generted_filter = self.filter_gen_conv( F.adaptive_avg_pool2d(x, self.filter_size)) x = self.input_redu_conv(x) b, c, h, w = x.shape # [1, b * c, h, w], c = self.channels x = x.view(1, b * c, h, w) # [b * c, 1, filter_size, filter_size] generted_filter = generted_filter.view(b * c, 1, self.filter_size, self.filter_size) pad = (self.filter_size - 1) // 2 if (self.filter_size - 1) % 2 == 0: p2d = (pad, pad, pad, pad) else: p2d = (pad + 1, pad, pad + 1, pad) x = F.pad(input=x, pad=p2d, mode='constant', value=0) # [1, b * c, h, w] output = F.conv2d(input=x, weight=generted_filter, groups=b * c) # [b, c, h, w] output = output.view(b, c, h, w) if self.norm is not None: output = self.norm(output) output = self.activate(output) if self.fusion: output = self.fusion_conv(output) return output @HEADS.register_module() class DMHead(BaseDecodeHead): """Dynamic Multi-scale Filters for Semantic Segmentation. This head is the implementation of `DMNet `_. Args: filter_sizes (tuple[int]): The size of generated convolutional filters used in Dynamic Convolutional Module. Default: (1, 3, 5, 7). fusion (bool): Add one conv to fuse DCM output feature. """ def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs): super(DMHead, self).__init__(**kwargs) assert isinstance(filter_sizes, (list, tuple)) self.filter_sizes = filter_sizes self.fusion = fusion dcm_modules = [] for filter_size in self.filter_sizes: dcm_modules.append( DCM(filter_size, self.fusion, self.in_channels, self.channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) self.dcm_modules = nn.ModuleList(dcm_modules) self.bottleneck = ConvModule( self.in_channels + len(filter_sizes) * self.channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) dcm_outs = [x] for dcm_module in self.dcm_modules: dcm_outs.append(dcm_module(x)) dcm_outs = torch.cat(dcm_outs, dim=1) output = self.bottleneck(dcm_outs) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/dnl_head.py ================================================ import torch from mmcv.cnn import NonLocal2d from torch import nn from ..builder import HEADS from .fcn_head import FCNHead class DisentangledNonLocal2d(NonLocal2d): """Disentangled Non-Local Blocks. Args: temperature (float): Temperature to adjust attention. 
Default: 0.05 """ def __init__(self, *arg, temperature, **kwargs): super().__init__(*arg, **kwargs) self.temperature = temperature self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1) def embedded_gaussian(self, theta_x, phi_x): """Embedded gaussian with temperature.""" # NonLocal2d pairwise_weight: [N, HxW, HxW] pairwise_weight = torch.matmul(theta_x, phi_x) if self.use_scale: # theta_x.shape[-1] is `self.inter_channels` pairwise_weight /= theta_x.shape[-1]**0.5 pairwise_weight /= self.temperature pairwise_weight = pairwise_weight.softmax(dim=-1) return pairwise_weight def forward(self, x): # x: [N, C, H, W] n = x.size(0) # g_x: [N, HxW, C] g_x = self.g(x).view(n, self.inter_channels, -1) g_x = g_x.permute(0, 2, 1) # theta_x: [N, HxW, C], phi_x: [N, C, HxW] if self.mode == 'gaussian': theta_x = x.view(n, self.in_channels, -1) theta_x = theta_x.permute(0, 2, 1) if self.sub_sample: phi_x = self.phi(x).view(n, self.in_channels, -1) else: phi_x = x.view(n, self.in_channels, -1) elif self.mode == 'concatenation': theta_x = self.theta(x).view(n, self.inter_channels, -1, 1) phi_x = self.phi(x).view(n, self.inter_channels, 1, -1) else: theta_x = self.theta(x).view(n, self.inter_channels, -1) theta_x = theta_x.permute(0, 2, 1) phi_x = self.phi(x).view(n, self.inter_channels, -1) # subtract mean theta_x -= theta_x.mean(dim=-2, keepdim=True) phi_x -= phi_x.mean(dim=-1, keepdim=True) pairwise_func = getattr(self, self.mode) # pairwise_weight: [N, HxW, HxW] pairwise_weight = pairwise_func(theta_x, phi_x) # y: [N, HxW, C] y = torch.matmul(pairwise_weight, g_x) # y: [N, C, H, W] y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels, *x.size()[2:]) # unary_mask: [N, 1, HxW] unary_mask = self.conv_mask(x) unary_mask = unary_mask.view(n, 1, -1) unary_mask = unary_mask.softmax(dim=-1) # unary_x: [N, 1, C] unary_x = torch.matmul(unary_mask, g_x) # unary_x: [N, C, 1, 1] unary_x = unary_x.permute(0, 2, 1).contiguous().reshape( n, self.inter_channels, 1, 1) output = x + self.conv_out(y + unary_x) return output @HEADS.register_module() class DNLHead(FCNHead): """Disentangled Non-Local Neural Networks. This head is the implementation of `DNLNet `_. Args: reduction (int): Reduction factor of projection transform. Default: 2. use_scale (bool): Whether to scale pairwise_weight by sqrt(1/inter_channels). Default: False. mode (str): The nonlocal mode. Options are 'embedded_gaussian', 'dot_product'. Default: 'embedded_gaussian.'. temperature (float): Temperature to adjust attention. 
Default: 0.05 """ def __init__(self, reduction=2, use_scale=True, mode='embedded_gaussian', temperature=0.05, **kwargs): super(DNLHead, self).__init__(num_convs=2, **kwargs) self.reduction = reduction self.use_scale = use_scale self.mode = mode self.temperature = temperature self.dnl_block = DisentangledNonLocal2d( in_channels=self.channels, reduction=self.reduction, use_scale=self.use_scale, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, mode=self.mode, temperature=self.temperature) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) output = self.convs[0](x) output = self.dnl_block(output) output = self.convs[1](output) if self.concat_input: output = self.conv_cat(torch.cat([x, output], dim=1)) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/ema_head.py ================================================ import math import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule from ..builder import HEADS from .decode_head import BaseDecodeHead def reduce_mean(tensor): """Reduce mean when distributed training.""" if not (dist.is_available() and dist.is_initialized()): return tensor tensor = tensor.clone() dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) return tensor class EMAModule(nn.Module): """Expectation Maximization Attention Module used in EMANet. Args: channels (int): Channels of the whole module. num_bases (int): Number of bases. num_stages (int): Number of the EM iterations. """ def __init__(self, channels, num_bases, num_stages, momentum): super(EMAModule, self).__init__() assert num_stages >= 1, 'num_stages must be at least 1!' self.num_bases = num_bases self.num_stages = num_stages self.momentum = momentum bases = torch.zeros(1, channels, self.num_bases) bases.normal_(0, math.sqrt(2. / self.num_bases)) # [1, channels, num_bases] bases = F.normalize(bases, dim=1, p=2) self.register_buffer('bases', bases) def forward(self, feats): """Forward function.""" batch_size, channels, height, width = feats.size() # [batch_size, channels, height*width] feats = feats.view(batch_size, channels, height * width) # [batch_size, channels, num_bases] bases = self.bases.repeat(batch_size, 1, 1) with torch.no_grad(): for i in range(self.num_stages): # [batch_size, height*width, num_bases] attention = torch.einsum('bcn,bck->bnk', feats, bases) attention = F.softmax(attention, dim=2) # l1 norm attention_normed = F.normalize(attention, dim=1, p=1) # [batch_size, channels, num_bases] bases = torch.einsum('bcn,bnk->bck', feats, attention_normed) # l2 norm bases = F.normalize(bases, dim=1, p=2) feats_recon = torch.einsum('bck,bnk->bcn', bases, attention) feats_recon = feats_recon.view(batch_size, channels, height, width) if self.training: bases = bases.mean(dim=0, keepdim=True) bases = reduce_mean(bases) # l2 norm bases = F.normalize(bases, dim=1, p=2) self.bases = (1 - self.momentum) * self.bases + self.momentum * bases return feats_recon @HEADS.register_module() class EMAHead(BaseDecodeHead): """Expectation Maximization Attention Networks for Semantic Segmentation. This head is the implementation of `EMANet `_. Args: ema_channels (int): EMA module channels num_bases (int): Number of bases. num_stages (int): Number of the EM iterations. concat_input (bool): Whether concat the input and output of convs before classification layer. 
Default: True momentum (float): Momentum to update the base. Default: 0.1. """ def __init__(self, ema_channels, num_bases, num_stages, concat_input=True, momentum=0.1, **kwargs): super(EMAHead, self).__init__(**kwargs) self.ema_channels = ema_channels self.num_bases = num_bases self.num_stages = num_stages self.concat_input = concat_input self.momentum = momentum self.ema_module = EMAModule(self.ema_channels, self.num_bases, self.num_stages, self.momentum) self.ema_in_conv = ConvModule( self.in_channels, self.ema_channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) # project (0, inf) -> (-inf, inf) self.ema_mid_conv = ConvModule( self.ema_channels, self.ema_channels, 1, conv_cfg=self.conv_cfg, norm_cfg=None, act_cfg=None) for param in self.ema_mid_conv.parameters(): param.requires_grad = False self.ema_out_conv = ConvModule( self.ema_channels, self.ema_channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=None) self.bottleneck = ConvModule( self.ema_channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) if self.concat_input: self.conv_cat = ConvModule( self.in_channels + self.channels, self.channels, kernel_size=3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) feats = self.ema_in_conv(x) identity = feats feats = self.ema_mid_conv(feats) recon = self.ema_module(feats) recon = F.relu(recon, inplace=True) recon = self.ema_out_conv(recon) output = F.relu(identity + recon, inplace=True) output = self.bottleneck(output) if self.concat_input: output = self.conv_cat(torch.cat([x, output], dim=1)) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/enc_head.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule, build_norm_layer from mmseg.ops import Encoding, resize from ..builder import HEADS, build_loss from .decode_head import BaseDecodeHead class EncModule(nn.Module): """Encoding Module used in EncNet. Args: in_channels (int): Input channels. num_codes (int): Number of code words. conv_cfg (dict|None): Config of conv layers. norm_cfg (dict|None): Config of norm layers. act_cfg (dict): Config of activation layers. 
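Example (an illustrative sketch; the channel sizes and shapes are
        assumptions):
            >>> import torch
            >>> enc = EncModule(in_channels=32, num_codes=8, conv_cfg=None,
            ...                 norm_cfg=None, act_cfg=dict(type='ReLU'))
            >>> encoding_feat, output = enc(torch.rand(2, 32, 16, 16))
            >>> encoding_feat.shape, output.shape
            (torch.Size([2, 32]), torch.Size([2, 32, 16, 16]))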
""" def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg): super(EncModule, self).__init__() self.encoding_project = ConvModule( in_channels, in_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) # TODO: resolve this hack # change to 1d if norm_cfg is not None: encoding_norm_cfg = norm_cfg.copy() if encoding_norm_cfg['type'] in ['BN', 'IN']: encoding_norm_cfg['type'] += '1d' else: encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace( '2d', '1d') else: # fallback to BN1d encoding_norm_cfg = dict(type='BN1d') self.encoding = nn.Sequential( Encoding(channels=in_channels, num_codes=num_codes), build_norm_layer(encoding_norm_cfg, num_codes)[1], nn.ReLU(inplace=True)) self.fc = nn.Sequential( nn.Linear(in_channels, in_channels), nn.Sigmoid()) def forward(self, x): """Forward function.""" encoding_projection = self.encoding_project(x) encoding_feat = self.encoding(encoding_projection).mean(dim=1) batch_size, channels, _, _ = x.size() gamma = self.fc(encoding_feat) y = gamma.view(batch_size, channels, 1, 1) output = F.relu_(x + x * y) return encoding_feat, output @HEADS.register_module() class EncHead(BaseDecodeHead): """Context Encoding for Semantic Segmentation. This head is the implementation of `EncNet `_. Args: num_codes (int): Number of code words. Default: 32. use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to regularize the training. Default: True. add_lateral (bool): Whether use lateral connection to fuse features. Default: False. loss_se_decode (dict): Config of decode loss. Default: dict(type='CrossEntropyLoss', use_sigmoid=True). """ def __init__(self, num_codes=32, use_se_loss=True, add_lateral=False, loss_se_decode=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2), **kwargs): super(EncHead, self).__init__( input_transform='multiple_select', **kwargs) self.use_se_loss = use_se_loss self.add_lateral = add_lateral self.num_codes = num_codes self.bottleneck = ConvModule( self.in_channels[-1], self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) if add_lateral: self.lateral_convs = nn.ModuleList() for in_channels in self.in_channels[:-1]: # skip the last one self.lateral_convs.append( ConvModule( in_channels, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) self.fusion = ConvModule( len(self.in_channels) * self.channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.enc_module = EncModule( self.channels, num_codes=num_codes, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) if self.use_se_loss: self.loss_se_decode = build_loss(loss_se_decode) self.se_layer = nn.Linear(self.channels, self.num_classes) def forward(self, inputs): """Forward function.""" inputs = self._transform_inputs(inputs) feat = self.bottleneck(inputs[-1]) if self.add_lateral: laterals = [ resize( lateral_conv(inputs[i]), size=feat.shape[2:], mode='bilinear', align_corners=self.align_corners) for i, lateral_conv in enumerate(self.lateral_convs) ] feat = self.fusion(torch.cat([feat, *laterals], 1)) encode_feat, output = self.enc_module(feat) output = self.cls_seg(output) if self.use_se_loss: se_output = self.se_layer(encode_feat) return output, se_output else: return output def forward_test(self, inputs, img_metas, test_cfg): """Forward function for testing, ignore se_loss.""" if self.use_se_loss: return self.forward(inputs)[0] else: return self.forward(inputs) @staticmethod def 
_convert_to_onehot_labels(seg_label, num_classes): """Convert segmentation label to onehot. Args: seg_label (Tensor): Segmentation label of shape (N, H, W). num_classes (int): Number of classes. Returns: Tensor: Onehot labels of shape (N, num_classes). """ batch_size = seg_label.size(0) onehot_labels = seg_label.new_zeros((batch_size, num_classes)) for i in range(batch_size): hist = seg_label[i].float().histc( bins=num_classes, min=0, max=num_classes - 1) onehot_labels[i] = hist > 0 return onehot_labels def losses(self, seg_logit, seg_label): """Compute segmentation and semantic encoding loss.""" seg_logit, se_seg_logit = seg_logit loss = dict() loss.update(super(EncHead, self).losses(seg_logit, seg_label)) se_loss = self.loss_se_decode( se_seg_logit, self._convert_to_onehot_labels(seg_label, self.num_classes)) loss['loss_se'] = se_loss return loss ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/fcn_head.py ================================================ import torch import torch.nn as nn from mmcv.cnn import ConvModule from ..builder import HEADS from .decode_head import BaseDecodeHead @HEADS.register_module() class FCNHead(BaseDecodeHead): """Fully Convolutional Networks for Semantic Segmentation. This head is the implementation of `FCNNet `_. Args: num_convs (int): Number of convs in the head. Default: 2. kernel_size (int): The kernel size for convs in the head. Default: 3. concat_input (bool): Whether to concatenate the input and output of convs before the classification layer. Default: True. """ def __init__(self, num_convs=2, kernel_size=3, concat_input=True, **kwargs): assert num_convs >= 0 self.num_convs = num_convs self.concat_input = concat_input self.kernel_size = kernel_size super(FCNHead, self).__init__(**kwargs) if num_convs == 0: assert self.in_channels == self.channels convs = [] convs.append( ConvModule( self.in_channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) for i in range(num_convs - 1): convs.append( ConvModule( self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) if num_convs == 0: self.convs = nn.Identity() else: self.convs = nn.Sequential(*convs) if self.concat_input: self.conv_cat = ConvModule( self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) output = self.convs(x) if self.concat_input: output = self.conv_cat(torch.cat([x, output], dim=1)) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/fpn_head.py ================================================ import numpy as np import torch.nn as nn from mmcv.cnn import ConvModule from mmseg.ops import resize from ..builder import HEADS from .decode_head import BaseDecodeHead @HEADS.register_module() class FPNHead(BaseDecodeHead): """Panoptic Feature Pyramid Networks. This head is the implementation of `Semantic FPN `_. Args: feature_strides (tuple[int]): The strides for input feature maps, used to stack laterals. All strides are supposed to be powers of 2. The first one is of the largest resolution.
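Example (a hypothetical configuration; the channels and strides are
        assumptions chosen only to make the shapes line up):
            >>> import torch
            >>> head = FPNHead(feature_strides=(4, 8, 16, 32),
            ...                in_channels=[32, 64, 128, 256],
            ...                in_index=[0, 1, 2, 3],
            ...                channels=64, num_classes=19)
            >>> inputs = [torch.rand(1, c, 64 // s, 64 // s)
            ...           for c, s in zip((32, 64, 128, 256), (4, 8, 16, 32))]
            >>> head(inputs).shape
            torch.Size([1, 19, 16, 16])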
""" def __init__(self, feature_strides, **kwargs): super(FPNHead, self).__init__( input_transform='multiple_select', **kwargs) assert len(feature_strides) == len(self.in_channels) assert min(feature_strides) == feature_strides[0] self.feature_strides = feature_strides self.scale_heads = nn.ModuleList() for i in range(len(feature_strides)): head_length = max( 1, int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) scale_head = [] for k in range(head_length): scale_head.append( ConvModule( self.in_channels[i] if k == 0 else self.channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) if feature_strides[i] != feature_strides[0]: scale_head.append( nn.Upsample( scale_factor=2, mode='bilinear', align_corners=self.align_corners)) self.scale_heads.append(nn.Sequential(*scale_head)) def forward(self, inputs): x = self._transform_inputs(inputs) output = self.scale_heads[0](x[0]) for i in range(1, len(self.feature_strides)): # non inplace output = output + resize( self.scale_heads[i](x[i]), size=output.shape[2:], mode='bilinear', align_corners=self.align_corners) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/gc_head.py ================================================ import torch from mmcv.cnn import ContextBlock from ..builder import HEADS from .fcn_head import FCNHead @HEADS.register_module() class GCHead(FCNHead): """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. This head is the implementation of `GCNet `_. Args: ratio (float): Multiplier of channels ratio. Default: 1/4. pooling_type (str): The pooling type of context aggregation. Options are 'att', 'avg'. Default: 'avg'. fusion_types (tuple[str]): The fusion type for feature fusion. Options are 'channel_add', 'channel_mul'. Defautl: ('channel_add',) """ def __init__(self, ratio=1 / 4., pooling_type='att', fusion_types=('channel_add', ), **kwargs): super(GCHead, self).__init__(num_convs=2, **kwargs) self.ratio = ratio self.pooling_type = pooling_type self.fusion_types = fusion_types self.gc_block = ContextBlock( in_channels=self.channels, ratio=self.ratio, pooling_type=self.pooling_type, fusion_types=self.fusion_types) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) output = self.convs[0](x) output = self.gc_block(output) output = self.convs[1](output) if self.concat_input: output = self.conv_cat(torch.cat([x, output], dim=1)) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/lraspp_head.py ================================================ import torch import torch.nn as nn from mmcv import is_tuple_of from mmcv.cnn import ConvModule from mmseg.ops import resize from ..builder import HEADS from .decode_head import BaseDecodeHead @HEADS.register_module() class LRASPPHead(BaseDecodeHead): """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3. This head is the improved implementation of `Searching for MobileNetV3 `_. Args: branch_channels (tuple[int]): The number of output channels in every each branch. Default: (32, 64). """ def __init__(self, branch_channels=(32, 64), **kwargs): super(LRASPPHead, self).__init__(**kwargs) if self.input_transform != 'multiple_select': raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' f'must be \'multiple_select\'. 
But received ' f'\'{self.input_transform}\'') assert is_tuple_of(branch_channels, int) assert len(branch_channels) == len(self.in_channels) - 1 self.branch_channels = branch_channels self.convs = nn.Sequential() self.conv_ups = nn.Sequential() for i in range(len(branch_channels)): self.convs.add_module( f'conv{i}', nn.Conv2d( self.in_channels[i], branch_channels[i], 1, bias=False)) self.conv_ups.add_module( f'conv_up{i}', ConvModule( self.channels + branch_channels[i], self.channels, 1, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, bias=False)) self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1) self.aspp_conv = ConvModule( self.in_channels[-1], self.channels, 1, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, bias=False) self.image_pool = nn.Sequential( nn.AvgPool2d(kernel_size=49, stride=(16, 20)), ConvModule( self.in_channels[2], self.channels, 1, act_cfg=dict(type='Sigmoid'), bias=False)) def forward(self, inputs): """Forward function.""" inputs = self._transform_inputs(inputs) x = inputs[-1] x = self.aspp_conv(x) * resize( self.image_pool(x), size=x.size()[2:], mode='bilinear', align_corners=self.align_corners) x = self.conv_up_input(x) for i in range(len(self.branch_channels) - 1, -1, -1): x = resize( x, size=inputs[i].size()[2:], mode='bilinear', align_corners=self.align_corners) x = torch.cat([x, self.convs[i](inputs[i])], 1) x = self.conv_ups[i](x) return self.cls_seg(x) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/nl_head.py ================================================ import torch from mmcv.cnn import NonLocal2d from ..builder import HEADS from .fcn_head import FCNHead @HEADS.register_module() class NLHead(FCNHead): """Non-local Neural Networks. This head is the implementation of `NLNet `_. Args: reduction (int): Reduction factor of projection transform. Default: 2. use_scale (bool): Whether to scale pairwise_weight by sqrt(1/inter_channels). Default: True. mode (str): The nonlocal mode. Options are 'embedded_gaussian', 'dot_product'. Default: 'embedded_gaussian.'. """ def __init__(self, reduction=2, use_scale=True, mode='embedded_gaussian', **kwargs): super(NLHead, self).__init__(num_convs=2, **kwargs) self.reduction = reduction self.use_scale = use_scale self.mode = mode self.nl_block = NonLocal2d( in_channels=self.channels, reduction=self.reduction, use_scale=self.use_scale, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, mode=self.mode) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) output = self.convs[0](x) output = self.nl_block(output) output = self.convs[1](output) if self.concat_input: output = self.conv_cat(torch.cat([x, output], dim=1)) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/ocr_head.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule from mmseg.ops import resize from ..builder import HEADS from ..utils import SelfAttentionBlock as _SelfAttentionBlock from .cascade_decode_head import BaseCascadeDecodeHead class SpatialGatherModule(nn.Module): """Aggregate the context features according to the initial predicted probability distribution. Employ the soft-weighted method to aggregate the context. 
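Example (an illustrative shape check; the sizes are assumptions):
        >>> import torch
        >>> gather = SpatialGatherModule(scale=1)
        >>> feats = torch.rand(2, 64, 16, 16)
        >>> probs = torch.rand(2, 19, 16, 16)
        >>> gather(feats, probs).shape  # one 64-d context vector per class
        torch.Size([2, 64, 19, 1])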
""" def __init__(self, scale): super(SpatialGatherModule, self).__init__() self.scale = scale def forward(self, feats, probs): """Forward function.""" batch_size, num_classes, height, width = probs.size() channels = feats.size(1) probs = probs.view(batch_size, num_classes, -1) feats = feats.view(batch_size, channels, -1) # [batch_size, height*width, num_classes] feats = feats.permute(0, 2, 1) # [batch_size, channels, height*width] probs = F.softmax(self.scale * probs, dim=2) # [batch_size, channels, num_classes] ocr_context = torch.matmul(probs, feats) ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3) return ocr_context class ObjectAttentionBlock(_SelfAttentionBlock): """Make a OCR used SelfAttentionBlock.""" def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg, act_cfg): if scale > 1: query_downsample = nn.MaxPool2d(kernel_size=scale) else: query_downsample = None super(ObjectAttentionBlock, self).__init__( key_in_channels=in_channels, query_in_channels=in_channels, channels=channels, out_channels=in_channels, share_key_query=False, query_downsample=query_downsample, key_downsample=None, key_query_num_convs=2, key_query_norm=True, value_out_num_convs=1, value_out_norm=True, matmul_norm=True, with_out=True, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) self.bottleneck = ConvModule( in_channels * 2, in_channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, query_feats, key_feats): """Forward function.""" context = super(ObjectAttentionBlock, self).forward(query_feats, key_feats) output = self.bottleneck(torch.cat([context, query_feats], dim=1)) if self.query_downsample is not None: output = resize(query_feats) return output @HEADS.register_module() class OCRHead(BaseCascadeDecodeHead): """Object-Contextual Representations for Semantic Segmentation. This head is the implementation of `OCRNet `_. Args: ocr_channels (int): The intermediate channels of OCR block. scale (int): The scale of probability map in SpatialGatherModule in Default: 1. """ def __init__(self, ocr_channels, scale=1, **kwargs): super(OCRHead, self).__init__(**kwargs) self.ocr_channels = ocr_channels self.scale = scale self.object_context_block = ObjectAttentionBlock( self.channels, self.ocr_channels, self.scale, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.spatial_gather_module = SpatialGatherModule(self.scale) self.bottleneck = ConvModule( self.in_channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, inputs, prev_output): """Forward function.""" x = self._transform_inputs(inputs) feats = self.bottleneck(x) context = self.spatial_gather_module(feats, prev_output) object_context = self.object_context_block(feats, context) output = self.cls_seg(object_context) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/point_head.py ================================================ # Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa import torch import torch.nn as nn from mmcv.cnn import ConvModule, normal_init from mmcv.ops import point_sample from mmseg.models.builder import HEADS from mmseg.ops import resize from ..losses import accuracy from .cascade_decode_head import BaseCascadeDecodeHead def calculate_uncertainty(seg_logits): """Estimate uncertainty based on seg logits. 
For each location of the prediction ``seg_logits`` we estimate uncertainty as the difference between top first and top second predicted logits. Args: seg_logits (Tensor): Semantic segmentation logits, shape (batch_size, num_classes, height, width). Returns: scores (Tensor): T uncertainty scores with the most uncertain locations having the highest uncertainty score, shape ( batch_size, 1, height, width) """ top2_scores = torch.topk(seg_logits, k=2, dim=1)[0] return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) @HEADS.register_module() class PointHead(BaseCascadeDecodeHead): """A mask point head use in PointRend. ``PointHead`` use shared multi-layer perceptron (equivalent to nn.Conv1d) to predict the logit of input points. The fine-grained feature and coarse feature will be concatenate together for predication. Args: num_fcs (int): Number of fc layers in the head. Default: 3. in_channels (int): Number of input channels. Default: 256. fc_channels (int): Number of fc channels. Default: 256. num_classes (int): Number of classes for logits. Default: 80. class_agnostic (bool): Whether use class agnostic classification. If so, the output channels of logits will be 1. Default: False. coarse_pred_each_layer (bool): Whether concatenate coarse feature with the output of each fc layer. Default: True. conv_cfg (dict|None): Dictionary to construct and config conv layer. Default: dict(type='Conv1d')) norm_cfg (dict|None): Dictionary to construct and config norm layer. Default: None. loss_point (dict): Dictionary to construct and config loss layer of point head. Default: dict(type='CrossEntropyLoss', use_mask=True, loss_weight=1.0). """ def __init__(self, num_fcs=3, coarse_pred_each_layer=True, conv_cfg=dict(type='Conv1d'), norm_cfg=None, act_cfg=dict(type='ReLU', inplace=False), **kwargs): super(PointHead, self).__init__( input_transform='multiple_select', conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, **kwargs) self.num_fcs = num_fcs self.coarse_pred_each_layer = coarse_pred_each_layer fc_in_channels = sum(self.in_channels) + self.num_classes fc_channels = self.channels self.fcs = nn.ModuleList() for k in range(num_fcs): fc = ConvModule( fc_in_channels, fc_channels, kernel_size=1, stride=1, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) self.fcs.append(fc) fc_in_channels = fc_channels fc_in_channels += self.num_classes if self.coarse_pred_each_layer \ else 0 self.fc_seg = nn.Conv1d( fc_in_channels, self.num_classes, kernel_size=1, stride=1, padding=0) if self.dropout_ratio > 0: self.dropout = nn.Dropout(self.dropout_ratio) delattr(self, 'conv_seg') def init_weights(self): """Initialize weights of classification layer.""" normal_init(self.fc_seg, std=0.001) def cls_seg(self, feat): """Classify each pixel with fc.""" if self.dropout is not None: feat = self.dropout(feat) output = self.fc_seg(feat) return output def forward(self, fine_grained_point_feats, coarse_point_feats): x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1) for fc in self.fcs: x = fc(x) if self.coarse_pred_each_layer: x = torch.cat((x, coarse_point_feats), dim=1) return self.cls_seg(x) def _get_fine_grained_point_feats(self, x, points): """Sample from fine grained features. Args: x (list[Tensor]): Feature pyramid from by neck or backbone. points (Tensor): Point coordinates, shape (batch_size, num_points, 2). Returns: fine_grained_feats (Tensor): Sampled fine grained feature, shape (batch_size, sum(channels of x), num_points). 
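Example (a shape sketch using mmcv's ``point_sample`` directly;
            the sizes are assumptions):
                >>> import torch
                >>> from mmcv.ops import point_sample
                >>> x = [torch.rand(2, 32, 16, 16), torch.rand(2, 64, 8, 8)]
                >>> points = torch.rand(2, 48, 2)
                >>> torch.cat([point_sample(f, points) for f in x], dim=1).shape
                torch.Size([2, 96, 48])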
""" fine_grained_feats_list = [ point_sample(_, points, align_corners=self.align_corners) for _ in x ] if len(fine_grained_feats_list) > 1: fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1) else: fine_grained_feats = fine_grained_feats_list[0] return fine_grained_feats def _get_coarse_point_feats(self, prev_output, points): """Sample from fine grained features. Args: prev_output (list[Tensor]): Prediction of previous decode head. points (Tensor): Point coordinates, shape (batch_size, num_points, 2). Returns: coarse_feats (Tensor): Sampled coarse feature, shape (batch_size, num_classes, num_points). """ coarse_feats = point_sample( prev_output, points, align_corners=self.align_corners) return coarse_feats def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, train_cfg): """Forward function for training. Args: inputs (list[Tensor]): List of multi-level img features. prev_output (Tensor): The output of previous decode head. img_metas (list[dict]): List of image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmseg/datasets/pipelines/formatting.py:Collect`. gt_semantic_seg (Tensor): Semantic segmentation masks used if the architecture supports semantic segmentation task. train_cfg (dict): The training config. Returns: dict[str, Tensor]: a dictionary of loss components """ x = self._transform_inputs(inputs) with torch.no_grad(): points = self.get_points_train( prev_output, calculate_uncertainty, cfg=train_cfg) fine_grained_point_feats = self._get_fine_grained_point_feats( x, points) coarse_point_feats = self._get_coarse_point_feats(prev_output, points) point_logits = self.forward(fine_grained_point_feats, coarse_point_feats) point_label = point_sample( gt_semantic_seg.float(), points, mode='nearest', align_corners=self.align_corners) point_label = point_label.squeeze(1).long() losses = self.losses(point_logits, point_label) return losses def forward_test(self, inputs, prev_output, img_metas, test_cfg): """Forward function for testing. Args: inputs (list[Tensor]): List of multi-level img features. prev_output (Tensor): The output of previous decode head. img_metas (list[dict]): List of image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmseg/datasets/pipelines/formatting.py:Collect`. test_cfg (dict): The testing config. Returns: Tensor: Output segmentation map. 
""" x = self._transform_inputs(inputs) refined_seg_logits = prev_output.clone() for _ in range(test_cfg.subdivision_steps): refined_seg_logits = resize( refined_seg_logits, scale_factor=test_cfg.scale_factor, mode='bilinear', align_corners=self.align_corners) batch_size, channels, height, width = refined_seg_logits.shape point_indices, points = self.get_points_test( refined_seg_logits, calculate_uncertainty, cfg=test_cfg) fine_grained_point_feats = self._get_fine_grained_point_feats( x, points) coarse_point_feats = self._get_coarse_point_feats( prev_output, points) point_logits = self.forward(fine_grained_point_feats, coarse_point_feats) point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1) refined_seg_logits = refined_seg_logits.reshape( batch_size, channels, height * width) refined_seg_logits = refined_seg_logits.scatter_( 2, point_indices, point_logits) refined_seg_logits = refined_seg_logits.view( batch_size, channels, height, width) return refined_seg_logits def losses(self, point_logits, point_label): """Compute segmentation loss.""" loss = dict() loss['loss_point'] = self.loss_decode( point_logits, point_label, ignore_index=self.ignore_index) loss['acc_point'] = accuracy(point_logits, point_label) return loss def get_points_train(self, seg_logits, uncertainty_func, cfg): """Sample points for training. Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties are calculated for each point using 'uncertainty_func' function that takes point's logit prediction as input. Args: seg_logits (Tensor): Semantic segmentation logits, shape ( batch_size, num_classes, height, width). uncertainty_func (func): uncertainty calculation function. cfg (dict): Training config of point head. Returns: point_coords (Tensor): A tensor of shape (batch_size, num_points, 2) that contains the coordinates of ``num_points`` sampled points. """ num_points = cfg.num_points oversample_ratio = cfg.oversample_ratio importance_sample_ratio = cfg.importance_sample_ratio assert oversample_ratio >= 1 assert 0 <= importance_sample_ratio <= 1 batch_size = seg_logits.shape[0] num_sampled = int(num_points * oversample_ratio) point_coords = torch.rand( batch_size, num_sampled, 2, device=seg_logits.device) point_logits = point_sample(seg_logits, point_coords) # It is crucial to calculate uncertainty based on the sampled # prediction value for the points. Calculating uncertainties of the # coarse predictions first and sampling them for points leads to # incorrect results. To illustrate this: assume uncertainty func( # logits)=-abs(logits), a sampled point between two coarse # predictions with -1 and 1 logits has 0 logits, and therefore 0 # uncertainty value. However, if we calculate uncertainties for the # coarse predictions first, both will have -1 uncertainty, # and sampled point will get -1 uncertainty. 
point_uncertainties = uncertainty_func(point_logits) num_uncertain_points = int(importance_sample_ratio * num_points) num_random_points = num_points - num_uncertain_points idx = torch.topk( point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] shift = num_sampled * torch.arange( batch_size, dtype=torch.long, device=seg_logits.device) idx += shift[:, None] point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( batch_size, num_uncertain_points, 2) if num_random_points > 0: rand_point_coords = torch.rand( batch_size, num_random_points, 2, device=seg_logits.device) point_coords = torch.cat((point_coords, rand_point_coords), dim=1) return point_coords def get_points_test(self, seg_logits, uncertainty_func, cfg): """Sample points for testing. Find ``num_points`` most uncertain points from ``uncertainty_map``. Args: seg_logits (Tensor): A tensor of shape (batch_size, num_classes, height, width) for class-specific or class-agnostic prediction. uncertainty_func (func): uncertainty calculation function. cfg (dict): Testing config of point head. Returns: point_indices (Tensor): A tensor of shape (batch_size, num_points) that contains indices from [0, height x width) of the most uncertain points. point_coords (Tensor): A tensor of shape (batch_size, num_points, 2) that contains [0, 1] x [0, 1] normalized coordinates of the most uncertain points from the ``height x width`` grid . """ num_points = cfg.subdivision_num_points uncertainty_map = uncertainty_func(seg_logits) batch_size, _, height, width = uncertainty_map.shape h_step = 1.0 / height w_step = 1.0 / width uncertainty_map = uncertainty_map.view(batch_size, height * width) num_points = min(height * width, num_points) point_indices = uncertainty_map.topk(num_points, dim=1)[1] point_coords = torch.zeros( batch_size, num_points, 2, dtype=torch.float, device=seg_logits.device) point_coords[:, :, 0] = w_step / 2.0 + (point_indices % width).float() * w_step point_coords[:, :, 1] = h_step / 2.0 + (point_indices // width).float() * h_step return point_indices, point_coords ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/psa_head.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule from mmseg.ops import resize from ..builder import HEADS from .decode_head import BaseDecodeHead try: from mmcv.ops import PSAMask except ModuleNotFoundError: PSAMask = None @HEADS.register_module() class PSAHead(BaseDecodeHead): """Point-wise Spatial Attention Network for Scene Parsing. This head is the implementation of `PSANet `_. Args: mask_size (tuple[int]): The PSA mask size. It usually equals input size. psa_type (str): The type of psa module. Options are 'collect', 'distribute', 'bi-direction'. Default: 'bi-direction' compact (bool): Whether use compact map for 'collect' mode. Default: True. shrink_factor (int): The downsample factors of psa mask. Default: 2. normalization_factor (float): The normalize factor of attention. psa_softmax (bool): Whether use softmax for attention. 
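Example (a construction-only sketch; requires mmcv-full for the
        PSAMask op, and the values below are assumptions modelled on
        typical PSANet configs):
            >>> head = PSAHead(mask_size=(97, 97), psa_type='bi-direction',
            ...                in_channels=2048, in_index=3, channels=512,
            ...                num_classes=19, norm_cfg=dict(type='BN'))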
""" def __init__(self, mask_size, psa_type='bi-direction', compact=False, shrink_factor=2, normalization_factor=1.0, psa_softmax=True, **kwargs): if PSAMask is None: raise RuntimeError('Please install mmcv-full for PSAMask ops') super(PSAHead, self).__init__(**kwargs) assert psa_type in ['collect', 'distribute', 'bi-direction'] self.psa_type = psa_type self.compact = compact self.shrink_factor = shrink_factor self.mask_size = mask_size mask_h, mask_w = mask_size self.psa_softmax = psa_softmax if normalization_factor is None: normalization_factor = mask_h * mask_w self.normalization_factor = normalization_factor self.reduce = ConvModule( self.in_channels, self.channels, kernel_size=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.attention = nn.Sequential( ConvModule( self.channels, self.channels, kernel_size=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg), nn.Conv2d( self.channels, mask_h * mask_w, kernel_size=1, bias=False)) if psa_type == 'bi-direction': self.reduce_p = ConvModule( self.in_channels, self.channels, kernel_size=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.attention_p = nn.Sequential( ConvModule( self.channels, self.channels, kernel_size=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg), nn.Conv2d( self.channels, mask_h * mask_w, kernel_size=1, bias=False)) self.psamask_collect = PSAMask('collect', mask_size) self.psamask_distribute = PSAMask('distribute', mask_size) else: self.psamask = PSAMask(psa_type, mask_size) self.proj = ConvModule( self.channels * (2 if psa_type == 'bi-direction' else 1), self.in_channels, kernel_size=1, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.bottleneck = ConvModule( self.in_channels * 2, self.channels, kernel_size=3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) identity = x align_corners = self.align_corners if self.psa_type in ['collect', 'distribute']: out = self.reduce(x) n, c, h, w = out.size() if self.shrink_factor != 1: if h % self.shrink_factor and w % self.shrink_factor: h = (h - 1) // self.shrink_factor + 1 w = (w - 1) // self.shrink_factor + 1 align_corners = True else: h = h // self.shrink_factor w = w // self.shrink_factor align_corners = False out = resize( out, size=(h, w), mode='bilinear', align_corners=align_corners) y = self.attention(out) if self.compact: if self.psa_type == 'collect': y = y.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w) else: y = self.psamask(y) if self.psa_softmax: y = F.softmax(y, dim=1) out = torch.bmm( out.view(n, c, h * w), y.view(n, h * w, h * w)).view( n, c, h, w) * (1.0 / self.normalization_factor) else: x_col = self.reduce(x) x_dis = self.reduce_p(x) n, c, h, w = x_col.size() if self.shrink_factor != 1: if h % self.shrink_factor and w % self.shrink_factor: h = (h - 1) // self.shrink_factor + 1 w = (w - 1) // self.shrink_factor + 1 align_corners = True else: h = h // self.shrink_factor w = w // self.shrink_factor align_corners = False x_col = resize( x_col, size=(h, w), mode='bilinear', align_corners=align_corners) x_dis = resize( x_dis, size=(h, w), mode='bilinear', align_corners=align_corners) y_col = self.attention(x_col) y_dis = self.attention_p(x_dis) if self.compact: y_dis = y_dis.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w) else: y_col = self.psamask_collect(y_col) y_dis = 
self.psamask_distribute(y_dis) if self.psa_softmax: y_col = F.softmax(y_col, dim=1) y_dis = F.softmax(y_dis, dim=1) x_col = torch.bmm( x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view( n, c, h, w) * (1.0 / self.normalization_factor) x_dis = torch.bmm( x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view( n, c, h, w) * (1.0 / self.normalization_factor) out = torch.cat([x_col, x_dis], 1) out = self.proj(out) out = resize( out, size=identity.shape[2:], mode='bilinear', align_corners=align_corners) out = self.bottleneck(torch.cat((identity, out), dim=1)) out = self.cls_seg(out) return out ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/psp_head.py ================================================ import torch import torch.nn as nn from mmcv.cnn import ConvModule from mmseg.ops import resize from ..builder import HEADS from .decode_head import BaseDecodeHead class PPM(nn.ModuleList): """Pooling Pyramid Module used in PSPNet. Args: pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid Module. in_channels (int): Input channels. channels (int): Channels after modules, before conv_seg. conv_cfg (dict|None): Config of conv layers. norm_cfg (dict|None): Config of norm layers. act_cfg (dict): Config of activation layers. align_corners (bool): align_corners argument of F.interpolate. """ def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, act_cfg, align_corners): super(PPM, self).__init__() self.pool_scales = pool_scales self.align_corners = align_corners self.in_channels = in_channels self.channels = channels self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg for pool_scale in pool_scales: self.append( nn.Sequential( nn.AdaptiveAvgPool2d(pool_scale), ConvModule( self.in_channels, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))) def forward(self, x): """Forward function.""" ppm_outs = [] for ppm in self: ppm_out = ppm(x) upsampled_ppm_out = resize( ppm_out, size=x.size()[2:], mode='bilinear', align_corners=self.align_corners) ppm_outs.append(upsampled_ppm_out) return ppm_outs @HEADS.register_module() class PSPHead(BaseDecodeHead): """Pyramid Scene Parsing Network. This head is the implementation of `PSPNet `_. Args: pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid Module. Default: (1, 2, 3, 6). 
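Example (a minimal usage sketch; the channel sizes are assumptions):
        >>> import torch
        >>> head = PSPHead(in_channels=64, in_index=0, channels=32,
        ...                num_classes=19)
        >>> head([torch.rand(2, 64, 32, 32)]).shape
        torch.Size([2, 19, 32, 32])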
""" def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): super(PSPHead, self).__init__(**kwargs) assert isinstance(pool_scales, (list, tuple)) self.pool_scales = pool_scales self.psp_modules = PPM( self.pool_scales, self.in_channels, self.channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, align_corners=self.align_corners) self.bottleneck = ConvModule( self.in_channels + len(pool_scales) * self.channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) psp_outs = [x] psp_outs.extend(self.psp_modules(x)) psp_outs = torch.cat(psp_outs, dim=1) output = self.bottleneck(psp_outs) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/sep_aspp_head.py ================================================ import torch import torch.nn as nn from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule from mmseg.ops import resize from ..builder import HEADS from .aspp_head import ASPPHead, ASPPModule class DepthwiseSeparableASPPModule(ASPPModule): """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable conv.""" def __init__(self, **kwargs): super(DepthwiseSeparableASPPModule, self).__init__(**kwargs) for i, dilation in enumerate(self.dilations): if dilation > 1: self[i] = DepthwiseSeparableConvModule( self.in_channels, self.channels, 3, dilation=dilation, padding=dilation, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) @HEADS.register_module() class DepthwiseSeparableASPPHead(ASPPHead): """Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation. This head is the implementation of `DeepLabV3+ `_. Args: c1_in_channels (int): The input channels of c1 decoder. If is 0, the no decoder will be used. c1_channels (int): The intermediate channels of c1 decoder. 
""" def __init__(self, c1_in_channels, c1_channels, **kwargs): super(DepthwiseSeparableASPPHead, self).__init__(**kwargs) assert c1_in_channels >= 0 self.aspp_modules = DepthwiseSeparableASPPModule( dilations=self.dilations, in_channels=self.in_channels, channels=self.channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) if c1_in_channels > 0: self.c1_bottleneck = ConvModule( c1_in_channels, c1_channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) else: self.c1_bottleneck = None self.sep_bottleneck = nn.Sequential( DepthwiseSeparableConvModule( self.channels + c1_channels, self.channels, 3, padding=1, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg), DepthwiseSeparableConvModule( self.channels, self.channels, 3, padding=1, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) aspp_outs = [ resize( self.image_pool(x), size=x.size()[2:], mode='bilinear', align_corners=self.align_corners) ] aspp_outs.extend(self.aspp_modules(x)) aspp_outs = torch.cat(aspp_outs, dim=1) output = self.bottleneck(aspp_outs) if self.c1_bottleneck is not None: c1_output = self.c1_bottleneck(inputs[0]) output = resize( input=output, size=c1_output.shape[2:], mode='bilinear', align_corners=self.align_corners) output = torch.cat([output, c1_output], dim=1) output = self.sep_bottleneck(output) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/sep_fcn_head.py ================================================ from mmcv.cnn import DepthwiseSeparableConvModule from ..builder import HEADS from .fcn_head import FCNHead @HEADS.register_module() class DepthwiseSeparableFCNHead(FCNHead): """Depthwise-Separable Fully Convolutional Network for Semantic Segmentation. This head is implemented according to Fast-SCNN paper. Args: in_channels(int): Number of output channels of FFM. channels(int): Number of middle-stage channels in the decode head. concat_input(bool): Whether to concatenate original decode input into the result of several consecutive convolution layers. Default: True. num_classes(int): Used to determine the dimension of final prediction tensor. in_index(int): Correspond with 'out_indices' in FastSCNN backbone. norm_cfg (dict | None): Config of norm layers. align_corners (bool): align_corners argument of F.interpolate. Default: False. loss_decode(dict): Config of loss type and some relevant additional options. 
""" def __init__(self, **kwargs): super(DepthwiseSeparableFCNHead, self).__init__(**kwargs) self.convs[0] = DepthwiseSeparableConvModule( self.in_channels, self.channels, kernel_size=self.kernel_size, padding=self.kernel_size // 2, norm_cfg=self.norm_cfg) for i in range(1, self.num_convs): self.convs[i] = DepthwiseSeparableConvModule( self.channels, self.channels, kernel_size=self.kernel_size, padding=self.kernel_size // 2, norm_cfg=self.norm_cfg) if self.concat_input: self.conv_cat = DepthwiseSeparableConvModule( self.in_channels + self.channels, self.channels, kernel_size=self.kernel_size, padding=self.kernel_size // 2, norm_cfg=self.norm_cfg) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/decode_heads/uper_head.py ================================================ import torch import torch.nn as nn from mmcv.cnn import ConvModule from mmseg.ops import resize from ..builder import HEADS from .decode_head import BaseDecodeHead from .psp_head import PPM @HEADS.register_module() class UPerHead(BaseDecodeHead): """Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet `_. Args: pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid Module applied on the last feature. Default: (1, 2, 3, 6). """ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): super(UPerHead, self).__init__( input_transform='multiple_select', **kwargs) # PSP Module self.psp_modules = PPM( pool_scales, self.in_channels[-1], self.channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, align_corners=self.align_corners) self.bottleneck = ConvModule( self.in_channels[-1] + len(pool_scales) * self.channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) # FPN Module self.lateral_convs = nn.ModuleList() self.fpn_convs = nn.ModuleList() for in_channels in self.in_channels[:-1]: # skip the top layer l_conv = ConvModule( in_channels, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, inplace=False) fpn_conv = ConvModule( self.channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, inplace=False) self.lateral_convs.append(l_conv) self.fpn_convs.append(fpn_conv) self.fpn_bottleneck = ConvModule( len(self.in_channels) * self.channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def psp_forward(self, inputs): """Forward function of PSP module.""" x = inputs[-1] psp_outs = [x] psp_outs.extend(self.psp_modules(x)) psp_outs = torch.cat(psp_outs, dim=1) output = self.bottleneck(psp_outs) return output def forward(self, inputs): """Forward function.""" inputs = self._transform_inputs(inputs) # build laterals laterals = [ lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs) ] laterals.append(self.psp_forward(inputs)) # build top-down path used_backbone_levels = len(laterals) for i in range(used_backbone_levels - 1, 0, -1): prev_shape = laterals[i - 1].shape[2:] laterals[i - 1] += resize( laterals[i], size=prev_shape, mode='bilinear', align_corners=self.align_corners) # build outputs fpn_outs = [ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1) ] # append psp feature fpn_outs.append(laterals[-1]) for i in range(used_backbone_levels - 1, 0, -1): fpn_outs[i] = resize( fpn_outs[i], size=fpn_outs[0].shape[2:], mode='bilinear', align_corners=self.align_corners) fpn_outs = 
torch.cat(fpn_outs, dim=1) output = self.fpn_bottleneck(fpn_outs) output = self.cls_seg(output) return output ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/losses/__init__.py ================================================ from .accuracy import Accuracy, accuracy from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy) from .lovasz_loss import LovaszLoss from .utils import reduce_loss, weight_reduce_loss, weighted_loss __all__ = [ 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss' ] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/losses/accuracy.py ================================================ import torch.nn as nn def accuracy(pred, target, topk=1, thresh=None): """Calculate accuracy according to the prediction and target. Args: pred (torch.Tensor): The model prediction, shape (N, num_class, ...) target (torch.Tensor): The target of each prediction, shape (N, , ...) topk (int | tuple[int], optional): If the predictions in ``topk`` matches the target, the predictions will be regarded as correct ones. Defaults to 1. thresh (float, optional): If not None, predictions with scores under this threshold are considered incorrect. Default to None. Returns: float | tuple[float]: If the input ``topk`` is a single integer, the function will return a single float as accuracy. If ``topk`` is a tuple containing multiple integers, the function will return a tuple containing accuracies of each ``topk`` number. """ assert isinstance(topk, (int, tuple)) if isinstance(topk, int): topk = (topk, ) return_single = True else: return_single = False maxk = max(topk) if pred.size(0) == 0: accu = [pred.new_tensor(0.) for i in range(len(topk))] return accu[0] if return_single else accu assert pred.ndim == target.ndim + 1 assert pred.size(0) == target.size(0) assert maxk <= pred.size(1), \ f'maxk {maxk} exceeds pred dimension {pred.size(1)}' pred_value, pred_label = pred.topk(maxk, dim=1) # transpose to shape (maxk, N, ...) pred_label = pred_label.transpose(0, 1) correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) if thresh is not None: # Only prediction values larger than thresh are counted as correct correct = correct & (pred_value > thresh).t() res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / target.numel())) return res[0] if return_single else res class Accuracy(nn.Module): """Accuracy calculation module.""" def __init__(self, topk=(1, ), thresh=None): """Module to calculate the accuracy. Args: topk (tuple, optional): The criterion used to calculate the accuracy. Defaults to (1,). thresh (float, optional): If not None, predictions with scores under this threshold are considered incorrect. Default to None. """ super().__init__() self.topk = topk self.thresh = thresh def forward(self, pred, target): """Forward function to calculate accuracy. Args: pred (torch.Tensor): Prediction of models. target (torch.Tensor): Target for each prediction. Returns: tuple[float]: The accuracies under different topk criterions. 
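# --- Usage sketch, not from the original source: `accuracy` above takes
# class logits of shape (N, num_class, ...) and integer targets of shape
# (N, ...); a tuple `topk` yields one percentage per k:
import torch

pred = torch.randn(8, 10)                         # 8 samples, 10 classes
target = torch.randint(0, 10, (8,))
top1, top5 = accuracy(pred, target, topk=(1, 5))  # tensors in [0, 100]
single = accuracy(pred, target, topk=1)           # plain single value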
""" return accuracy(pred, target, self.topk, self.thresh) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/losses/cross_entropy_loss.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from ..builder import LOSSES from .utils import weight_reduce_loss def cross_entropy(pred, label, weight=None, class_weight=None, reduction='mean', avg_factor=None, ignore_index=-100): """The wrapper function for :func:`F.cross_entropy`""" # class_weight is a manual rescaling weight given to each class. # If given, has to be a Tensor of size C element-wise losses loss = F.cross_entropy( pred, label, weight=class_weight, reduction='none', ignore_index=ignore_index) # apply weights and do the reduction if weight is not None: weight = weight.float() loss = weight_reduce_loss( loss, weight=weight, reduction=reduction, avg_factor=avg_factor) return loss def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): """Expand onehot labels to match the size of prediction.""" bin_labels = labels.new_zeros(target_shape) valid_mask = (labels >= 0) & (labels != ignore_index) inds = torch.nonzero(valid_mask, as_tuple=True) if inds[0].numel() > 0: if labels.dim() == 3: bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 else: bin_labels[inds[0], labels[valid_mask]] = 1 valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() if label_weights is None: bin_label_weights = valid_mask else: bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) bin_label_weights *= valid_mask return bin_labels, bin_label_weights def binary_cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None, class_weight=None, ignore_index=255): """Calculate the binary CrossEntropy loss. Args: pred (torch.Tensor): The prediction with shape (N, 1). label (torch.Tensor): The learning label of the prediction. weight (torch.Tensor, optional): Sample-wise loss weight. reduction (str, optional): The method used to reduce the loss. Options are "none", "mean" and "sum". avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. class_weight (list[float], optional): The weight for each class. ignore_index (int | None): The label index to be ignored. Default: 255 Returns: torch.Tensor: The calculated loss """ if pred.dim() != label.dim(): assert (pred.dim() == 2 and label.dim() == 1) or ( pred.dim() == 4 and label.dim() == 3), \ 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \ 'H, W], label shape [N, H, W] are supported' label, weight = _expand_onehot_labels(label, weight, pred.shape, ignore_index) # weighted element-wise losses if weight is not None: weight = weight.float() loss = F.binary_cross_entropy_with_logits( pred, label.float(), pos_weight=class_weight, reduction='none') # do the reduction for the weighted loss loss = weight_reduce_loss( loss, weight, reduction=reduction, avg_factor=avg_factor) return loss def mask_cross_entropy(pred, target, label, reduction='mean', avg_factor=None, class_weight=None, ignore_index=None): """Calculate the CrossEntropy loss for masks. Args: pred (torch.Tensor): The prediction with shape (N, C), C is the number of classes. target (torch.Tensor): The learning label of the prediction. label (torch.Tensor): ``label`` indicates the class label of the mask' corresponding object. 
This will be used to select the mask in the of the class which the object belongs to when the mask prediction if not class-agnostic. reduction (str, optional): The method used to reduce the loss. Options are "none", "mean" and "sum". avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. class_weight (list[float], optional): The weight for each class. ignore_index (None): Placeholder, to be consistent with other loss. Default: None. Returns: torch.Tensor: The calculated loss """ assert ignore_index is None, 'BCE loss does not support ignore_index' # TODO: handle these two reserved arguments assert reduction == 'mean' and avg_factor is None num_rois = pred.size()[0] inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) pred_slice = pred[inds, label].squeeze(1) return F.binary_cross_entropy_with_logits( pred_slice, target, weight=class_weight, reduction='mean')[None] @LOSSES.register_module() class CrossEntropyLoss(nn.Module): """CrossEntropyLoss. Args: use_sigmoid (bool, optional): Whether the prediction uses sigmoid of softmax. Defaults to False. use_mask (bool, optional): Whether to use mask cross entropy loss. Defaults to False. reduction (str, optional): . Defaults to 'mean'. Options are "none", "mean" and "sum". class_weight (list[float], optional): Weight of each class. Defaults to None. loss_weight (float, optional): Weight of the loss. Defaults to 1.0. """ def __init__(self, use_sigmoid=False, use_mask=False, reduction='mean', class_weight=None, loss_weight=1.0): super(CrossEntropyLoss, self).__init__() assert (use_sigmoid is False) or (use_mask is False) self.use_sigmoid = use_sigmoid self.use_mask = use_mask self.reduction = reduction self.loss_weight = loss_weight self.class_weight = class_weight if self.use_sigmoid: self.cls_criterion = binary_cross_entropy elif self.use_mask: self.cls_criterion = mask_cross_entropy else: self.cls_criterion = cross_entropy def forward(self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, **kwargs): """Forward function.""" assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) if self.class_weight is not None: class_weight = cls_score.new_tensor(self.class_weight) else: class_weight = None loss_cls = self.loss_weight * self.cls_criterion( cls_score, label, weight, class_weight=class_weight, reduction=reduction, avg_factor=avg_factor, **kwargs) return loss_cls ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/losses/lovasz_loss.py ================================================ """Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)""" import mmcv import torch import torch.nn as nn import torch.nn.functional as F from ..builder import LOSSES from .utils import weight_reduce_loss def lovasz_grad(gt_sorted): """Computes gradient of the Lovasz extension w.r.t sorted errors. See Alg. 1 in paper. """ p = len(gt_sorted) gts = gt_sorted.sum() intersection = gts - gt_sorted.float().cumsum(0) union = gts + (1 - gt_sorted).float().cumsum(0) jaccard = 1. 
- intersection / union if p > 1: # cover 1-pixel case jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] return jaccard def flatten_binary_logits(logits, labels, ignore_index=None): """Flattens predictions in the batch (binary case) Remove labels equal to 'ignore_index'.""" logits = logits.view(-1) labels = labels.view(-1) if ignore_index is None: return logits, labels valid = (labels != ignore_index) vlogits = logits[valid] vlabels = labels[valid] return vlogits, vlabels def flatten_probs(probs, labels, ignore_index=None): """Flattens predictions in the batch.""" if probs.dim() == 3: # assumes output of a sigmoid layer B, H, W = probs.size() probs = probs.view(B, 1, H, W) B, C, H, W = probs.size() probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C labels = labels.view(-1) if ignore_index is None: return probs, labels valid = (labels != ignore_index) vprobs = probs[valid.nonzero().squeeze()] vlabels = labels[valid] return vprobs, vlabels def lovasz_hinge_flat(logits, labels): """Binary Lovasz hinge loss. Args: logits (torch.Tensor): [P], logits at each prediction (between -infty and +infty). labels (torch.Tensor): [P], binary ground truth labels (0 or 1). Returns: torch.Tensor: The calculated loss. """ if len(labels) == 0: # only void pixels, the gradients should be 0 return logits.sum() * 0. signs = 2. * labels.float() - 1. errors = (1. - logits * signs) errors_sorted, perm = torch.sort(errors, dim=0, descending=True) perm = perm.data gt_sorted = labels[perm] grad = lovasz_grad(gt_sorted) loss = torch.dot(F.relu(errors_sorted), grad) return loss def lovasz_hinge(logits, labels, classes='present', per_image=False, class_weight=None, reduction='mean', avg_factor=None, ignore_index=255): """Binary Lovasz hinge loss. Args: logits (torch.Tensor): [B, H, W], logits at each pixel (between -infty and +infty). labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1). classes (str | list[int], optional): Placeholder, to be consistent with other loss. Default: None. per_image (bool, optional): If per_image is True, compute the loss per image instead of per batch. Default: False. class_weight (list[float], optional): Placeholder, to be consistent with other loss. Default: None. reduction (str, optional): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when per_image is True. Default: 'mean'. avg_factor (int, optional): Average factor that is used to average the loss. This parameter only works when per_image is True. Default: None. ignore_index (int | None): The label index to be ignored. Default: 255. Returns: torch.Tensor: The calculated loss. """ if per_image: loss = [ lovasz_hinge_flat(*flatten_binary_logits( logit.unsqueeze(0), label.unsqueeze(0), ignore_index)) for logit, label in zip(logits, labels) ] loss = weight_reduce_loss( torch.stack(loss), None, reduction, avg_factor) else: loss = lovasz_hinge_flat( *flatten_binary_logits(logits, labels, ignore_index)) return loss def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): """Multi-class Lovasz-Softmax loss. Args: probs (torch.Tensor): [P, C], class probabilities at each prediction (between 0 and 1). labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1). classes (str | list[int], optional): Classes chosen to calculate loss. 'all' for all classes, 'present' for classes present in labels, or a list of classes to average. Default: 'present'. class_weight (list[float], optional): The weight for each class. Default: None. 
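# --- Worked example for lovasz_grad above, not from the original source.
# For gt_sorted = [1, 1, 0, 1] (labels already sorted by decreasing error):
#   gts          = 3
#   intersection = 3 - cumsum([1, 1, 0, 1]) = [2, 1, 1, 0]
#   union        = 3 + cumsum([0, 0, 1, 0]) = [3, 3, 4, 4]
#   jaccard      = 1 - intersection/union  = [1/3, 2/3, 3/4, 1]
#   grad (first differences, first entry kept) = [1/3, 1/3, 1/12, 1/4]
# The hinge/softmax variants then take the dot product of these weights
# with the sorted error vector.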
Returns: torch.Tensor: The calculated loss. """ if probs.numel() == 0: # only void pixels, the gradients should be 0 return probs * 0. C = probs.size(1) losses = [] class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes for c in class_to_sum: fg = (labels == c).float() # foreground for class c if (classes == 'present' and fg.sum() == 0): continue if C == 1: if len(classes) > 1: raise ValueError('Sigmoid output possible only with 1 class') class_pred = probs[:, 0] else: class_pred = probs[:, c] errors = (fg - class_pred).abs() errors_sorted, perm = torch.sort(errors, 0, descending=True) perm = perm.data fg_sorted = fg[perm] loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted)) if class_weight is not None: loss *= class_weight[c] losses.append(loss) return torch.stack(losses).mean() def lovasz_softmax(probs, labels, classes='present', per_image=False, class_weight=None, reduction='mean', avg_factor=None, ignore_index=255): """Multi-class Lovasz-Softmax loss. Args: probs (torch.Tensor): [B, C, H, W], class probabilities at each prediction (between 0 and 1). labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and C - 1). classes (str | list[int], optional): Classes chosen to calculate loss. 'all' for all classes, 'present' for classes present in labels, or a list of classes to average. Default: 'present'. per_image (bool, optional): If per_image is True, compute the loss per image instead of per batch. Default: False. class_weight (list[float], optional): The weight for each class. Default: None. reduction (str, optional): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when per_image is True. Default: 'mean'. avg_factor (int, optional): Average factor that is used to average the loss. This parameter only works when per_image is True. Default: None. ignore_index (int | None): The label index to be ignored. Default: 255. Returns: torch.Tensor: The calculated loss. """ if per_image: loss = [ lovasz_softmax_flat( *flatten_probs( prob.unsqueeze(0), label.unsqueeze(0), ignore_index), classes=classes, class_weight=class_weight) for prob, label in zip(probs, labels) ] loss = weight_reduce_loss( torch.stack(loss), None, reduction, avg_factor) else: loss = lovasz_softmax_flat( *flatten_probs(probs, labels, ignore_index), classes=classes, class_weight=class_weight) return loss @LOSSES.register_module() class LovaszLoss(nn.Module): """LovaszLoss. This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate for the optimization of the intersection-over-union measure in neural networks <https://arxiv.org/abs/1705.08790>`_. Args: loss_type (str, optional): Binary or multi-class loss. Default: 'multi_class'. Options are "binary" and "multi_class". classes (str | list[int], optional): Classes chosen to calculate loss. 'all' for all classes, 'present' for classes present in labels, or a list of classes to average. Default: 'present'. per_image (bool, optional): If per_image is True, compute the loss per image instead of per batch. Default: False. reduction (str, optional): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when per_image is True. Default: 'mean'. class_weight (list[float], optional): The weight for each class. Default: None. loss_weight (float, optional): Weight of the loss. Defaults to 1.0. 
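# --- Usage sketch, not from the original source: with per_image=False the
# constructor below requires reduction='none', and for the multi-class
# variant forward() applies the softmax itself, so raw logits go in:
import torch

criterion = LovaszLoss(loss_type='multi_class', reduction='none')
logits = torch.randn(2, 21, 64, 64)          # (B, C, H, W) raw scores
labels = torch.randint(0, 21, (2, 64, 64))   # (B, H, W) class indices
loss = criterion(logits, labels)             # scalar tensor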
""" def __init__(self, loss_type='multi_class', classes='present', per_image=False, reduction='mean', class_weight=None, loss_weight=1.0): super(LovaszLoss, self).__init__() assert loss_type in ('binary', 'multi_class'), "loss_type should be \ 'binary' or 'multi_class'." if loss_type == 'binary': self.cls_criterion = lovasz_hinge else: self.cls_criterion = lovasz_softmax assert classes in ('all', 'present') or mmcv.is_list_of(classes, int) if not per_image: assert reduction == 'none', "reduction should be 'none' when \ per_image is False." self.classes = classes self.per_image = per_image self.reduction = reduction self.loss_weight = loss_weight self.class_weight = class_weight def forward(self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, **kwargs): """Forward function.""" assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) if self.class_weight is not None: class_weight = cls_score.new_tensor(self.class_weight) else: class_weight = None # if multi-class loss, transform logits to probs if self.cls_criterion == lovasz_softmax: cls_score = F.softmax(cls_score, dim=1) loss_cls = self.loss_weight * self.cls_criterion( cls_score, label, self.classes, self.per_image, class_weight=class_weight, reduction=reduction, avg_factor=avg_factor, **kwargs) return loss_cls ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/losses/utils.py ================================================ import functools import torch.nn.functional as F def reduce_loss(loss, reduction): """Reduce loss as specified. Args: loss (Tensor): Elementwise loss tensor. reduction (str): Options are "none", "mean" and "sum". Return: Tensor: Reduced loss tensor. """ reduction_enum = F._Reduction.get_enum(reduction) # none: 0, elementwise_mean:1, sum: 2 if reduction_enum == 0: return loss elif reduction_enum == 1: return loss.mean() elif reduction_enum == 2: return loss.sum() def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): """Apply element-wise weight and reduce loss. Args: loss (Tensor): Element-wise loss. weight (Tensor): Element-wise weights. reduction (str): Same as built-in losses of PyTorch. avg_factor (float): Avarage factor when computing the mean of losses. Returns: Tensor: Processed loss values. """ # if weight is specified, apply element-wise weight if weight is not None: assert weight.dim() == loss.dim() if weight.dim() > 1: assert weight.size(1) == 1 or weight.size(1) == loss.size(1) loss = loss * weight # if avg_factor is not specified, just reduce the loss if avg_factor is None: loss = reduce_loss(loss, reduction) else: # if reduction is mean, then average the loss by avg_factor if reduction == 'mean': loss = loss.sum() / avg_factor # if reduction is 'none', then do nothing, otherwise raise an error elif reduction != 'none': raise ValueError('avg_factor can not be used with reduction="sum"') return loss def weighted_loss(loss_func): """Create a weighted version of a given loss function. To use this decorator, the loss function must have the signature like `loss_func(pred, target, **kwargs)`. The function only needs to compute element-wise loss without any reduction. This decorator will add weight and reduction arguments to the function. The decorated function will have the signature like `loss_func(pred, target, weight=None, reduction='mean', avg_factor=None, **kwargs)`. 
:Example: >>> import torch >>> @weighted_loss >>> def l1_loss(pred, target): >>> return (pred - target).abs() >>> pred = torch.Tensor([0, 2, 3]) >>> target = torch.Tensor([1, 1, 1]) >>> weight = torch.Tensor([1, 0, 1]) >>> l1_loss(pred, target) tensor(1.3333) >>> l1_loss(pred, target, weight) tensor(1.) >>> l1_loss(pred, target, reduction='none') tensor([1., 1., 2.]) >>> l1_loss(pred, target, weight, avg_factor=2) tensor(1.5000) """ @functools.wraps(loss_func) def wrapper(pred, target, weight=None, reduction='mean', avg_factor=None, **kwargs): # get element-wise loss loss = loss_func(pred, target, **kwargs) loss = weight_reduce_loss(loss, weight, reduction, avg_factor) return loss return wrapper ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/necks/__init__.py ================================================ from .fpn import FPN __all__ = ['FPN'] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/necks/fpn.py ================================================ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule, xavier_init from ..builder import NECKS @NECKS.register_module() class FPN(nn.Module): """Feature Pyramid Network. This is an implementation of - Feature Pyramid Networks for Object Detection (https://arxiv.org/abs/1612.03144) Args: in_channels (List[int]): Number of input channels per scale. out_channels (int): Number of output channels (used at each scale) num_outs (int): Number of output scales. start_level (int): Index of the start input backbone level used to build the feature pyramid. Default: 0. end_level (int): Index of the end input backbone level (exclusive) to build the feature pyramid. Default: -1, which means the last level. add_extra_convs (bool | str): If bool, it decides whether to add conv layers on top of the original feature maps. Default to False. If True, its actual mode is specified by `extra_convs_on_inputs`. If str, it specifies the source feature map of the extra convs. Only the following options are allowed - 'on_input': Last feat map of neck inputs (i.e. backbone feature). - 'on_lateral': Last feature map after lateral convs. - 'on_output': The last output feature map after fpn convs. extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs on the original feature from the backbone. If True, it is equivalent to `add_extra_convs='on_input'`. If False, it is equivalent to set `add_extra_convs='on_output'`. Default to True. relu_before_extra_convs (bool): Whether to apply relu before the extra conv. Default: False. no_norm_on_lateral (bool): Whether to apply norm on lateral. Default: False. conv_cfg (dict): Config dict for convolution layer. Default: None. norm_cfg (dict): Config dict for normalization layer. Default: None. act_cfg (str): Config dict for activation layer in ConvModule. Default: None. upsample_cfg (dict): Config dict for interpolate layer. Default: `dict(mode='nearest')` Example: >>> import torch >>> in_channels = [2, 3, 5, 7] >>> scales = [340, 170, 84, 43] >>> inputs = [torch.rand(1, c, s, s) ... for c, s in zip(in_channels, scales)] >>> self = FPN(in_channels, 11, len(in_channels)).eval() >>> outputs = self.forward(inputs) >>> for i in range(len(outputs)): ... 
print(f'outputs[{i}].shape = {outputs[i].shape}') outputs[0].shape = torch.Size([1, 11, 340, 340]) outputs[1].shape = torch.Size([1, 11, 170, 170]) outputs[2].shape = torch.Size([1, 11, 84, 84]) outputs[3].shape = torch.Size([1, 11, 43, 43]) """ def __init__(self, in_channels, out_channels, num_outs, start_level=0, end_level=-1, add_extra_convs=False, extra_convs_on_inputs=False, relu_before_extra_convs=False, no_norm_on_lateral=False, conv_cfg=None, norm_cfg=None, act_cfg=None, upsample_cfg=dict(mode='nearest')): super(FPN, self).__init__() assert isinstance(in_channels, list) self.in_channels = in_channels self.out_channels = out_channels self.num_ins = len(in_channels) self.num_outs = num_outs self.relu_before_extra_convs = relu_before_extra_convs self.no_norm_on_lateral = no_norm_on_lateral self.fp16_enabled = False self.upsample_cfg = upsample_cfg.copy() if end_level == -1: self.backbone_end_level = self.num_ins assert num_outs >= self.num_ins - start_level else: # if end_level < inputs, no extra level is allowed self.backbone_end_level = end_level assert end_level <= len(in_channels) assert num_outs == end_level - start_level self.start_level = start_level self.end_level = end_level self.add_extra_convs = add_extra_convs assert isinstance(add_extra_convs, (str, bool)) if isinstance(add_extra_convs, str): # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') elif add_extra_convs: # True if extra_convs_on_inputs: # For compatibility with previous release # TODO: deprecate `extra_convs_on_inputs` self.add_extra_convs = 'on_input' else: self.add_extra_convs = 'on_output' self.lateral_convs = nn.ModuleList() self.fpn_convs = nn.ModuleList() for i in range(self.start_level, self.backbone_end_level): l_conv = ConvModule( in_channels[i], out_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, act_cfg=act_cfg, inplace=False) fpn_conv = ConvModule( out_channels, out_channels, 3, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, inplace=False) self.lateral_convs.append(l_conv) self.fpn_convs.append(fpn_conv) # add extra conv layers (e.g., RetinaNet) extra_levels = num_outs - self.backbone_end_level + self.start_level if self.add_extra_convs and extra_levels >= 1: for i in range(extra_levels): if i == 0 and self.add_extra_convs == 'on_input': in_channels = self.in_channels[self.backbone_end_level - 1] else: in_channels = out_channels extra_fpn_conv = ConvModule( in_channels, out_channels, 3, stride=2, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, inplace=False) self.fpn_convs.append(extra_fpn_conv) # default init_weights for conv(msra) and norm in ConvModule def init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): xavier_init(m, distribution='uniform') def forward(self, inputs): assert len(inputs) == len(self.in_channels) # build laterals laterals = [ lateral_conv(inputs[i + self.start_level]) for i, lateral_conv in enumerate(self.lateral_convs) ] # build top-down path used_backbone_levels = len(laterals) for i in range(used_backbone_levels - 1, 0, -1): # In some cases, fixing `scale factor` (e.g. 2) is preferred, but # it cannot co-exist with `size` in `F.interpolate`. 
if 'scale_factor' in self.upsample_cfg: laterals[i - 1] += F.interpolate(laterals[i], **self.upsample_cfg) else: prev_shape = laterals[i - 1].shape[2:] laterals[i - 1] += F.interpolate( laterals[i], size=prev_shape, **self.upsample_cfg) # build outputs # part 1: from original levels outs = [ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) ] # part 2: add extra levels if self.num_outs > len(outs): # use max pool to get more levels on top of outputs # (e.g., Faster R-CNN, Mask R-CNN) if not self.add_extra_convs: for i in range(self.num_outs - used_backbone_levels): outs.append(F.max_pool2d(outs[-1], 1, stride=2)) # add conv layers on top of original feature maps (RetinaNet) else: if self.add_extra_convs == 'on_input': extra_source = inputs[self.backbone_end_level - 1] elif self.add_extra_convs == 'on_lateral': extra_source = laterals[-1] elif self.add_extra_convs == 'on_output': extra_source = outs[-1] else: raise NotImplementedError outs.append(self.fpn_convs[used_backbone_levels](extra_source)) for i in range(used_backbone_levels + 1, self.num_outs): if self.relu_before_extra_convs: outs.append(self.fpn_convs[i](F.relu(outs[-1]))) else: outs.append(self.fpn_convs[i](outs[-1])) return tuple(outs) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/segmentors/__init__.py ================================================ from .cascade_encoder_decoder import CascadeEncoderDecoder from .encoder_decoder import EncoderDecoder __all__ = ['EncoderDecoder', 'CascadeEncoderDecoder'] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/segmentors/base.py ================================================ import logging import warnings from abc import ABCMeta, abstractmethod from collections import OrderedDict import mmcv import numpy as np import torch import torch.distributed as dist import torch.nn as nn from mmcv.runner import auto_fp16 class BaseSegmentor(nn.Module): """Base class for segmentors.""" __metaclass__ = ABCMeta def __init__(self): super(BaseSegmentor, self).__init__() self.fp16_enabled = False @property def with_neck(self): """bool: whether the segmentor has neck""" return hasattr(self, 'neck') and self.neck is not None @property def with_auxiliary_head(self): """bool: whether the segmentor has auxiliary head""" return hasattr(self, 'auxiliary_head') and self.auxiliary_head is not None @property def with_decode_head(self): """bool: whether the segmentor has decode head""" return hasattr(self, 'decode_head') and self.decode_head is not None @abstractmethod def extract_feat(self, imgs): """Placeholder for extract features from images.""" pass @abstractmethod def encode_decode(self, img, img_metas): """Placeholder for encode images with backbone and decode into a semantic segmentation map of the same size as input.""" pass @abstractmethod def forward_train(self, imgs, img_metas, **kwargs): """Placeholder for Forward function for training.""" pass @abstractmethod def simple_test(self, img, img_meta, **kwargs): """Placeholder for single image test.""" pass @abstractmethod def aug_test(self, imgs, img_metas, **kwargs): """Placeholder for augmentation test.""" pass def init_weights(self, pretrained=None): """Initialize the weights in segmentor. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. 
""" if pretrained is not None: logger = logging.getLogger() logger.info(f'load model from: {pretrained}') def forward_test(self, imgs, img_metas, **kwargs): """ Args: imgs (List[Tensor]): the outer list indicates test-time augmentations and inner Tensor should have a shape NxCxHxW, which contains all images in the batch. img_metas (List[List[dict]]): the outer list indicates test-time augs (multiscale, flip, etc.) and the inner list indicates images in a batch. """ for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: if not isinstance(var, list): raise TypeError(f'{name} must be a list, but got ' f'{type(var)}') num_augs = len(imgs) if num_augs != len(img_metas): raise ValueError(f'num of augmentations ({len(imgs)}) != ' f'num of image meta ({len(img_metas)})') # all images in the same aug batch all of the same ori_shape and pad # shape for img_meta in img_metas: ori_shapes = [_['ori_shape'] for _ in img_meta] assert all(shape == ori_shapes[0] for shape in ori_shapes) img_shapes = [_['img_shape'] for _ in img_meta] assert all(shape == img_shapes[0] for shape in img_shapes) pad_shapes = [_['pad_shape'] for _ in img_meta] assert all(shape == pad_shapes[0] for shape in pad_shapes) if num_augs == 1: return self.simple_test(imgs[0], img_metas[0], **kwargs) else: return self.aug_test(imgs, img_metas, **kwargs) @auto_fp16(apply_to=('img', )) def forward(self, img, img_metas, return_loss=True, **kwargs): """Calls either :func:`forward_train` or :func:`forward_test` depending on whether ``return_loss`` is ``True``. Note this setting will change the expected inputs. When ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor and List[dict]), and when ``resturn_loss=False``, img and img_meta should be double nested (i.e. List[Tensor], List[List[dict]]), with the outer list indicating test time augmentations. """ if return_loss: return self.forward_train(img, img_metas, **kwargs) else: return self.forward_test(img, img_metas, **kwargs) def train_step(self, data_batch, optimizer, **kwargs): """The iteration step during training. This method defines an iteration step during training, except for the back propagation and optimizer updating, which are done in an optimizer hook. Note that in some complicated cases or models, the whole process including back propagation and optimizer updating is also defined in this method, such as GAN. Args: data (dict): The output of dataloader. optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of runner is passed to ``train_step()``. This argument is unused and reserved. Returns: dict: It should contain at least 3 keys: ``loss``, ``log_vars``, ``num_samples``. ``loss`` is a tensor for back propagation, which can be a weighted sum of multiple losses. ``log_vars`` contains all the variables to be sent to the logger. ``num_samples`` indicates the batch size (when the model is DDP, it means the batch size on each GPU), which is used for averaging the logs. """ losses = self(**data_batch) loss, log_vars = self._parse_losses(losses) outputs = dict( loss=loss, log_vars=log_vars, num_samples=len(data_batch['img'].data)) return outputs def val_step(self, data_batch, **kwargs): """The iteration step during validation. This method shares the same signature as :func:`train_step`, but used during val epochs. Note that the evaluation after training epochs is not implemented with this method, but an evaluation hook. 
""" output = self(**data_batch, **kwargs) return output @staticmethod def _parse_losses(losses): """Parse the raw outputs (losses) of the network. Args: losses (dict): Raw output of the network, which usually contain losses and other necessary information. Returns: tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor which may be a weighted sum of all losses, log_vars contains all the variables to be sent to the logger. """ log_vars = OrderedDict() for loss_name, loss_value in losses.items(): if isinstance(loss_value, torch.Tensor): log_vars[loss_name] = loss_value.mean() elif isinstance(loss_value, list): log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) else: raise TypeError( f'{loss_name} is not a tensor or list of tensors') loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key) log_vars['loss'] = loss for loss_name, loss_value in log_vars.items(): # reduce loss when distributed training if dist.is_available() and dist.is_initialized(): loss_value = loss_value.data.clone() dist.all_reduce(loss_value.div_(dist.get_world_size())) log_vars[loss_name] = loss_value.item() return loss, log_vars def show_result(self, img, result, palette=None, win_name='', show=False, wait_time=0, out_file=None): """Draw `result` over `img`. Args: img (str or Tensor): The image to be displayed. result (Tensor): The semantic segmentation results to draw over `img`. palette (list[list[int]]] | np.ndarray | None): The palette of segmentation map. If None is given, random palette will be generated. Default: None win_name (str): The window name. wait_time (int): Value of waitKey param. Default: 0. show (bool): Whether to show the image. Default: False. out_file (str or None): The filename to write the image. Default: None. Returns: img (Tensor): Only if not `show` or `out_file` """ img = mmcv.imread(img) img = img.copy() seg = result[0] if palette is None: if self.PALETTE is None: palette = np.random.randint( 0, 255, size=(len(self.CLASSES), 3)) else: palette = self.PALETTE palette = np.array(palette) assert palette.shape[0] == len(self.CLASSES) assert palette.shape[1] == 3 assert len(palette.shape) == 2 color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) for label, color in enumerate(palette): color_seg[seg == label, :] = color # convert to BGR color_seg = color_seg[..., ::-1] img = img * 0.5 + color_seg * 0.5 img = img.astype(np.uint8) # if out_file specified, do not show image in window if out_file is not None: show = False if show: mmcv.imshow(img, win_name, wait_time) if out_file is not None: mmcv.imwrite(img, out_file) if not (show or out_file): warnings.warn('show==False and out_file is not specified, only ' 'result image will be returned') return img ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/segmentors/cascade_encoder_decoder.py ================================================ from torch import nn from mmseg.core import add_prefix from mmseg.ops import resize from .. import builder from ..builder import SEGMENTORS from .encoder_decoder import EncoderDecoder @SEGMENTORS.register_module() class CascadeEncoderDecoder(EncoderDecoder): """Cascade Encoder Decoder segmentors. CascadeEncoderDecoder almost the same as EncoderDecoder, while decoders of CascadeEncoderDecoder are cascaded. The output of previous decoder_head will be the input of next decoder_head. 
""" def __init__(self, num_stages, backbone, decode_head, neck=None, auxiliary_head=None, train_cfg=None, test_cfg=None, pretrained=None): self.num_stages = num_stages super(CascadeEncoderDecoder, self).__init__( backbone=backbone, decode_head=decode_head, neck=neck, auxiliary_head=auxiliary_head, train_cfg=train_cfg, test_cfg=test_cfg, pretrained=pretrained) def _init_decode_head(self, decode_head): """Initialize ``decode_head``""" assert isinstance(decode_head, list) assert len(decode_head) == self.num_stages self.decode_head = nn.ModuleList() for i in range(self.num_stages): self.decode_head.append(builder.build_head(decode_head[i])) self.align_corners = self.decode_head[-1].align_corners self.num_classes = self.decode_head[-1].num_classes def init_weights(self, pretrained=None): """Initialize the weights in backbone and heads. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ self.backbone.init_weights(pretrained=pretrained) for i in range(self.num_stages): self.decode_head[i].init_weights() if self.with_auxiliary_head: if isinstance(self.auxiliary_head, nn.ModuleList): for aux_head in self.auxiliary_head: aux_head.init_weights() else: self.auxiliary_head.init_weights() def encode_decode(self, img, img_metas): """Encode images with backbone and decode into a semantic segmentation map of the same size as input.""" x = self.extract_feat(img) out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg) for i in range(1, self.num_stages): out = self.decode_head[i].forward_test(x, out, img_metas, self.test_cfg) out = resize( input=out, size=img.shape[2:], mode='bilinear', align_corners=self.align_corners) return out def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): """Run forward function and calculate loss for decode head in training.""" losses = dict() loss_decode = self.decode_head[0].forward_train( x, img_metas, gt_semantic_seg, self.train_cfg) losses.update(add_prefix(loss_decode, 'decode_0')) for i in range(1, self.num_stages): # forward test again, maybe unnecessary for most methods. prev_outputs = self.decode_head[i - 1].forward_test( x, img_metas, self.test_cfg) loss_decode = self.decode_head[i].forward_train( x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg) losses.update(add_prefix(loss_decode, f'decode_{i}')) return losses ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/segmentors/encoder_decoder.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from mmseg.core import add_prefix from mmseg.ops import resize from .. import builder from ..builder import SEGMENTORS from .base import BaseSegmentor @SEGMENTORS.register_module() class EncoderDecoder(BaseSegmentor): """Encoder Decoder segmentors. EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. Note that auxiliary_head is only used for deep supervision during training, which could be dumped during inference. 
""" def __init__(self, backbone, decode_head, neck=None, auxiliary_head=None, train_cfg=None, test_cfg=None, pretrained=None): super(EncoderDecoder, self).__init__() self.backbone = builder.build_backbone(backbone) if neck is not None: self.neck = builder.build_neck(neck) self._init_decode_head(decode_head) self._init_auxiliary_head(auxiliary_head) self.train_cfg = train_cfg self.test_cfg = test_cfg self.init_weights(pretrained=pretrained) assert self.with_decode_head def _init_decode_head(self, decode_head): """Initialize ``decode_head``""" self.decode_head = builder.build_head(decode_head) self.align_corners = self.decode_head.align_corners self.num_classes = self.decode_head.num_classes def _init_auxiliary_head(self, auxiliary_head): """Initialize ``auxiliary_head``""" if auxiliary_head is not None: if isinstance(auxiliary_head, list): self.auxiliary_head = nn.ModuleList() for head_cfg in auxiliary_head: self.auxiliary_head.append(builder.build_head(head_cfg)) else: self.auxiliary_head = builder.build_head(auxiliary_head) def init_weights(self, pretrained=None): """Initialize the weights in backbone and heads. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ super(EncoderDecoder, self).init_weights(pretrained) self.backbone.init_weights(pretrained=pretrained) self.decode_head.init_weights() if self.with_auxiliary_head: if isinstance(self.auxiliary_head, nn.ModuleList): for aux_head in self.auxiliary_head: aux_head.init_weights() else: self.auxiliary_head.init_weights() def extract_feat(self, img): """Extract features from images.""" x = self.backbone(img) if self.with_neck: x = self.neck(x) return x def encode_decode(self, img, img_metas): """Encode images with backbone and decode into a semantic segmentation map of the same size as input.""" x = self.extract_feat(img) out = self._decode_head_forward_test(x, img_metas) out = resize( input=out, size=img.shape[2:], mode='bilinear', align_corners=self.align_corners) return out def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): """Run forward function and calculate loss for decode head in training.""" losses = dict() loss_decode = self.decode_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg) losses.update(add_prefix(loss_decode, 'decode')) return losses def _decode_head_forward_test(self, x, img_metas): """Run forward function and calculate loss for decode head in inference.""" seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) return seg_logits def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): """Run forward function and calculate loss for auxiliary head in training.""" losses = dict() if isinstance(self.auxiliary_head, nn.ModuleList): for idx, aux_head in enumerate(self.auxiliary_head): loss_aux = aux_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg) losses.update(add_prefix(loss_aux, f'aux_{idx}')) else: loss_aux = self.auxiliary_head.forward_train( x, img_metas, gt_semantic_seg, self.train_cfg) losses.update(add_prefix(loss_aux, 'aux')) return losses def forward_dummy(self, img): """Dummy forward function.""" seg_logit = self.encode_decode(img, None) return seg_logit def forward_train(self, img, img_metas, gt_semantic_seg): """Forward function for training. Args: img (Tensor): Input images. img_metas (list[dict]): List of image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 
For details on the values of these keys see `mmseg/datasets/pipelines/formatting.py:Collect`. gt_semantic_seg (Tensor): Semantic segmentation masks used if the architecture supports semantic segmentation task. Returns: dict[str, Tensor]: a dictionary of loss components """ x = self.extract_feat(img) losses = dict() loss_decode = self._decode_head_forward_train(x, img_metas, gt_semantic_seg) losses.update(loss_decode) if self.with_auxiliary_head: loss_aux = self._auxiliary_head_forward_train( x, img_metas, gt_semantic_seg) losses.update(loss_aux) return losses # TODO refactor def slide_inference(self, img, img_meta, rescale): """Inference by sliding-window with overlap. If h_crop > h_img or w_crop > w_img, the small patch will be used to decode without padding. """ h_stride, w_stride = self.test_cfg.stride h_crop, w_crop = self.test_cfg.crop_size batch_size, _, h_img, w_img = img.size() num_classes = self.num_classes h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) for h_idx in range(h_grids): for w_idx in range(w_grids): y1 = h_idx * h_stride x1 = w_idx * w_stride y2 = min(y1 + h_crop, h_img) x2 = min(x1 + w_crop, w_img) y1 = max(y2 - h_crop, 0) x1 = max(x2 - w_crop, 0) crop_img = img[:, :, y1:y2, x1:x2] crop_seg_logit = self.encode_decode(crop_img, img_meta) preds += F.pad(crop_seg_logit, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) count_mat[:, :, y1:y2, x1:x2] += 1 assert (count_mat == 0).sum() == 0 if torch.onnx.is_in_onnx_export(): # cast count_mat to constant while exporting to ONNX count_mat = torch.from_numpy( count_mat.cpu().detach().numpy()).to(device=img.device) preds = preds / count_mat if rescale: preds = resize( preds, size=img_meta[0]['ori_shape'][:2], mode='bilinear', align_corners=self.align_corners, warning=False) return preds def whole_inference(self, img, img_meta, rescale): """Inference with full image.""" seg_logit = self.encode_decode(img, img_meta) if rescale: seg_logit = resize( seg_logit, size=img_meta[0]['ori_shape'][:2], mode='bilinear', align_corners=self.align_corners, warning=False) return seg_logit def inference(self, img, img_meta, rescale): """Inference with slide/whole style. Args: img (Tensor): The input image of shape (N, 3, H, W). img_meta (dict): Image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmseg/datasets/pipelines/formatting.py:Collect`. rescale (bool): Whether rescale back to original shape. Returns: Tensor: The output segmentation map. 
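# --- Worked example for slide_inference above, not from the original
# source. With crop_size=(512, 512) and stride=(341, 341) on a 1024x1024
# input:
#   h_grids = max(1024 - 512 + 341 - 1, 0) // 341 + 1 = 3
# giving window origins y1 = 0, 341, 682; the last window is clamped to
# y2 = 1024 and its origin pulled back to y1 = 512 so every crop stays
# full-sized. count_mat counts the overlaps so the summed logits can be
# averaged before the optional rescale to ori_shape.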
""" assert self.test_cfg.mode in ['slide', 'whole'] ori_shape = img_meta[0]['ori_shape'] assert all(_['ori_shape'] == ori_shape for _ in img_meta) if self.test_cfg.mode == 'slide': seg_logit = self.slide_inference(img, img_meta, rescale) else: seg_logit = self.whole_inference(img, img_meta, rescale) output = F.softmax(seg_logit, dim=1) flip = img_meta[0]['flip'] if flip: flip_direction = img_meta[0]['flip_direction'] assert flip_direction in ['horizontal', 'vertical'] if flip_direction == 'horizontal': output = output.flip(dims=(3, )) elif flip_direction == 'vertical': output = output.flip(dims=(2, )) return output def simple_test(self, img, img_meta, rescale=True): """Simple test with single image.""" seg_logit = self.inference(img, img_meta, rescale) seg_pred = seg_logit.argmax(dim=1) if torch.onnx.is_in_onnx_export(): # our inference backend only support 4D output seg_pred = seg_pred.unsqueeze(0) return seg_pred seg_pred = seg_pred.cpu().numpy() # unravel batch dim seg_pred = list(seg_pred) return seg_pred def aug_test(self, imgs, img_metas, rescale=True): """Test with augmentations. Only rescale=True is supported. """ # aug_test rescale all imgs back to ori_shape for now assert rescale # to save memory, we get augmented seg logit inplace seg_logit = self.inference(imgs[0], img_metas[0], rescale) for i in range(1, len(imgs)): cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) seg_logit += cur_seg_logit seg_logit /= len(imgs) seg_pred = seg_logit.argmax(dim=1) seg_pred = seg_pred.cpu().numpy() # unravel batch dim seg_pred = list(seg_pred) return seg_pred ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/__init__.py ================================================ from .inverted_residual import InvertedResidual, InvertedResidualV3 from .make_divisible import make_divisible from .res_layer import ResLayer from .self_attention_block import SelfAttentionBlock from .up_conv_block import UpConvBlock __all__ = [ 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', 'UpConvBlock', 'InvertedResidualV3' ] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/inverted_residual.py ================================================ from mmcv.cnn import ConvModule from torch import nn as nn from torch.utils import checkpoint as cp from .se_layer import SELayer class InvertedResidual(nn.Module): """InvertedResidual block for MobileNetV2. Args: in_channels (int): The input channels of the InvertedResidual block. out_channels (int): The output channels of the InvertedResidual block. stride (int): Stride of the middle (first) 3x3 convolution. expand_ratio (int): Adjusts number of channels of the hidden layer in InvertedResidual by this amount. dilation (int): Dilation rate of depthwise conv. Default: 1 conv_cfg (dict): Config dict for convolution layer. Default: None, which means using conv2d. norm_cfg (dict): Config dict for normalization layer. Default: dict(type='BN'). act_cfg (dict): Config dict for activation layer. Default: dict(type='ReLU6'). with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. Returns: Tensor: The output tensor. 
""" def __init__(self, in_channels, out_channels, stride, expand_ratio, dilation=1, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU6'), with_cp=False): super(InvertedResidual, self).__init__() self.stride = stride assert stride in [1, 2], f'stride must in [1, 2]. ' \ f'But received {stride}.' self.with_cp = with_cp self.use_res_connect = self.stride == 1 and in_channels == out_channels hidden_dim = int(round(in_channels * expand_ratio)) layers = [] if expand_ratio != 1: layers.append( ConvModule( in_channels=in_channels, out_channels=hidden_dim, kernel_size=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)) layers.extend([ ConvModule( in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, groups=hidden_dim, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg), ConvModule( in_channels=hidden_dim, out_channels=out_channels, kernel_size=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None) ]) self.conv = nn.Sequential(*layers) def forward(self, x): def _inner_forward(x): if self.use_res_connect: return x + self.conv(x) else: return self.conv(x) if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) return out class InvertedResidualV3(nn.Module): """Inverted Residual Block for MobileNetV3. Args: in_channels (int): The input channels of this Module. out_channels (int): The output channels of this Module. mid_channels (int): The input channels of the depthwise convolution. kernel_size (int): The kernal size of the depthwise convolution. Default: 3. stride (int): The stride of the depthwise convolution. Default: 1. se_cfg (dict): Config dict for se layer. Defaul: None, which means no se layer. with_expand_conv (bool): Use expand conv or not. If set False, mid_channels must be the same with in_channels. Default: True. conv_cfg (dict): Config dict for convolution layer. Default: None, which means using conv2d. norm_cfg (dict): Config dict for normalization layer. Default: dict(type='BN'). act_cfg (dict): Config dict for activation layer. Default: dict(type='ReLU'). with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. Returns: Tensor: The output tensor. 
""" def __init__(self, in_channels, out_channels, mid_channels, kernel_size=3, stride=1, se_cfg=None, with_expand_conv=True, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), with_cp=False): super(InvertedResidualV3, self).__init__() self.with_res_shortcut = (stride == 1 and in_channels == out_channels) assert stride in [1, 2] self.with_cp = with_cp self.with_se = se_cfg is not None self.with_expand_conv = with_expand_conv if self.with_se: assert isinstance(se_cfg, dict) if not self.with_expand_conv: assert mid_channels == in_channels if self.with_expand_conv: self.expand_conv = ConvModule( in_channels=in_channels, out_channels=mid_channels, kernel_size=1, stride=1, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) self.depthwise_conv = ConvModule( in_channels=mid_channels, out_channels=mid_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=mid_channels, conv_cfg=dict( type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) if self.with_se: self.se = SELayer(**se_cfg) self.linear_conv = ConvModule( in_channels=mid_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None) def forward(self, x): def _inner_forward(x): out = x if self.with_expand_conv: out = self.expand_conv(out) out = self.depthwise_conv(out) if self.with_se: out = self.se(out) out = self.linear_conv(out) if self.with_res_shortcut: return x + out else: return out if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) return out ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/make_divisible.py ================================================ def make_divisible(value, divisor, min_value=None, min_ratio=0.9): """Make divisible function. This function rounds the channel number to the nearest value that can be divisible by the divisor. It is taken from the original tf repo. It ensures that all layers have a channel number that is divisible by divisor. It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa Args: value (int): The original channel number. divisor (int): The divisor to fully divide the channel number. min_value (int): The minimum value of the output channel. Default: None, means that the minimum value equal to the divisor. min_ratio (float): The minimum ratio of the rounded channel number to the original channel number. Default: 0.9. Returns: int: The modified output channel number. """ if min_value is None: min_value = divisor new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) # Make sure that round down does not go down by more than (1-min_ratio). if new_value < min_ratio * value: new_value += divisor return new_value ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/res_layer.py ================================================ from mmcv.cnn import build_conv_layer, build_norm_layer from torch import nn as nn class ResLayer(nn.Sequential): """ResLayer to build ResNet style backbone. Args: block (nn.Module): block used to build ResLayer. inplanes (int): inplanes of block. planes (int): planes of block. num_blocks (int): number of blocks. stride (int): stride of the first block. Default: 1 avg_down (bool): Use AvgPool instead of stride conv when downsampling in the bottleneck. 
Default: False conv_cfg (dict): dictionary to construct and config conv layer. Default: None norm_cfg (dict): dictionary to construct and config norm layer. Default: dict(type='BN') multi_grid (int | None): Multi grid dilation rates of last stage. Default: None contract_dilation (bool): Whether to contract the first dilation of each layer. Default: False """ def __init__(self, block, inplanes, planes, num_blocks, stride=1, dilation=1, avg_down=False, conv_cfg=None, norm_cfg=dict(type='BN'), multi_grid=None, contract_dilation=False, **kwargs): self.block = block downsample = None if stride != 1 or inplanes != planes * block.expansion: downsample = [] conv_stride = stride if avg_down: conv_stride = 1 downsample.append( nn.AvgPool2d( kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)) downsample.extend([ build_conv_layer( conv_cfg, inplanes, planes * block.expansion, kernel_size=1, stride=conv_stride, bias=False), build_norm_layer(norm_cfg, planes * block.expansion)[1] ]) downsample = nn.Sequential(*downsample) layers = [] if multi_grid is None: if dilation > 1 and contract_dilation: first_dilation = dilation // 2 else: first_dilation = dilation else: first_dilation = multi_grid[0] layers.append( block( inplanes=inplanes, planes=planes, stride=stride, dilation=first_dilation, downsample=downsample, conv_cfg=conv_cfg, norm_cfg=norm_cfg, **kwargs)) inplanes = planes * block.expansion for i in range(1, num_blocks): layers.append( block( inplanes=inplanes, planes=planes, stride=1, dilation=dilation if multi_grid is None else multi_grid[i], conv_cfg=conv_cfg, norm_cfg=norm_cfg, **kwargs)) super(ResLayer, self).__init__(*layers) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/se_layer.py ================================================ import mmcv import torch.nn as nn from mmcv.cnn import ConvModule from .make_divisible import make_divisible class SELayer(nn.Module): """Squeeze-and-Excitation Module. Args: channels (int): The input (and output) channels of the SE layer. ratio (int): Squeeze ratio in SELayer, the intermediate channel will be ``int(channels/ratio)``. Default: 16. conv_cfg (None or dict): Config dict for convolution layer. Default: None, which means using conv2d. act_cfg (dict or Sequence[dict]): Config dict for activation layer. If act_cfg is a dict, two activation layers will be configured by this dict. If act_cfg is a sequence of dicts, the first activation layer will be configured by the first dict and the second activation layer will be configured by the second dict. Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, divisor=6.0)). 
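# --- Worked example, not from the original source: the squeeze width used
# by SELayer below is make_divisible(channels // ratio, 8) from
# make_divisible.py above. For channels=64, ratio=16:
#   value     = 64 // 16 = 4
#   new_value = max(8, (4 + 4) // 8 * 8) = 8
#   8 >= 0.9 * 4, so no further correction is applied
# i.e. the SE bottleneck is widened from 4 to 8 channels to stay divisible
# by 8.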
""" def __init__(self, channels, ratio=16, conv_cfg=None, act_cfg=(dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, divisor=6.0))): super(SELayer, self).__init__() if isinstance(act_cfg, dict): act_cfg = (act_cfg, act_cfg) assert len(act_cfg) == 2 assert mmcv.is_tuple_of(act_cfg, dict) self.global_avgpool = nn.AdaptiveAvgPool2d(1) self.conv1 = ConvModule( in_channels=channels, out_channels=make_divisible(channels // ratio, 8), kernel_size=1, stride=1, conv_cfg=conv_cfg, act_cfg=act_cfg[0]) self.conv2 = ConvModule( in_channels=make_divisible(channels // ratio, 8), out_channels=channels, kernel_size=1, stride=1, conv_cfg=conv_cfg, act_cfg=act_cfg[1]) def forward(self, x): out = self.global_avgpool(x) out = self.conv1(out) out = self.conv2(out) return x * out ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/self_attention_block.py ================================================ import torch from mmcv.cnn import ConvModule, constant_init from torch import nn as nn from torch.nn import functional as F class SelfAttentionBlock(nn.Module): """General self-attention block/non-local block. Please refer to https://arxiv.org/abs/1706.03762 for details about key, query and value. Args: key_in_channels (int): Input channels of key feature. query_in_channels (int): Input channels of query feature. channels (int): Output channels of key/query transform. out_channels (int): Output channels. share_key_query (bool): Whether share projection weight between key and query projection. query_downsample (nn.Module): Query downsample module. key_downsample (nn.Module): Key downsample module. key_query_num_convs (int): Number of convs for key/query projection. value_num_convs (int): Number of convs for value projection. matmul_norm (bool): Whether normalize attention map with sqrt of channels with_out (bool): Whether use out projection. conv_cfg (dict|None): Config of conv layers. norm_cfg (dict|None): Config of norm layers. act_cfg (dict|None): Config of activation layers. 
""" def __init__(self, key_in_channels, query_in_channels, channels, out_channels, share_key_query, query_downsample, key_downsample, key_query_num_convs, value_out_num_convs, key_query_norm, value_out_norm, matmul_norm, with_out, conv_cfg, norm_cfg, act_cfg): super(SelfAttentionBlock, self).__init__() if share_key_query: assert key_in_channels == query_in_channels self.key_in_channels = key_in_channels self.query_in_channels = query_in_channels self.out_channels = out_channels self.channels = channels self.share_key_query = share_key_query self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg self.key_project = self.build_project( key_in_channels, channels, num_convs=key_query_num_convs, use_conv_module=key_query_norm, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) if share_key_query: self.query_project = self.key_project else: self.query_project = self.build_project( query_in_channels, channels, num_convs=key_query_num_convs, use_conv_module=key_query_norm, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) self.value_project = self.build_project( key_in_channels, channels if with_out else out_channels, num_convs=value_out_num_convs, use_conv_module=value_out_norm, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) if with_out: self.out_project = self.build_project( channels, out_channels, num_convs=value_out_num_convs, use_conv_module=value_out_norm, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) else: self.out_project = None self.query_downsample = query_downsample self.key_downsample = key_downsample self.matmul_norm = matmul_norm self.init_weights() def init_weights(self): """Initialize weight of later layer.""" if self.out_project is not None: if not isinstance(self.out_project, ConvModule): constant_init(self.out_project, 0) def build_project(self, in_channels, channels, num_convs, use_conv_module, conv_cfg, norm_cfg, act_cfg): """Build projection layer for key/query/value/out.""" if use_conv_module: convs = [ ConvModule( in_channels, channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) ] for _ in range(num_convs - 1): convs.append( ConvModule( channels, channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)) else: convs = [nn.Conv2d(in_channels, channels, 1)] for _ in range(num_convs - 1): convs.append(nn.Conv2d(channels, channels, 1)) if len(convs) > 1: convs = nn.Sequential(*convs) else: convs = convs[0] return convs def forward(self, query_feats, key_feats): """Forward function.""" batch_size = query_feats.size(0) query = self.query_project(query_feats) if self.query_downsample is not None: query = self.query_downsample(query) query = query.reshape(*query.shape[:2], -1) query = query.permute(0, 2, 1).contiguous() key = self.key_project(key_feats) value = self.value_project(key_feats) if self.key_downsample is not None: key = self.key_downsample(key) value = self.key_downsample(value) key = key.reshape(*key.shape[:2], -1) value = value.reshape(*value.shape[:2], -1) value = value.permute(0, 2, 1).contiguous() sim_map = torch.matmul(query, key) if self.matmul_norm: sim_map = (self.channels**-.5) * sim_map sim_map = F.softmax(sim_map, dim=-1) context = torch.matmul(sim_map, value) context = context.permute(0, 2, 1).contiguous() context = context.reshape(batch_size, -1, *query_feats.shape[2:]) if self.out_project is not None: context = self.out_project(context) return context ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/models/utils/up_conv_block.py 
================================================ import torch import torch.nn as nn from mmcv.cnn import ConvModule, build_upsample_layer class UpConvBlock(nn.Module): """Upsample convolution block in decoder for UNet. This upsample convolution block consists of one upsample module followed by one convolution block. The upsample module expands the high-level low-resolution feature map and the convolution block fuses the upsampled high-level low-resolution feature map and the low-level high-resolution feature map from encoder. Args: conv_block (nn.Sequential): Sequential of convolutional layers. in_channels (int): Number of input channels of the high-level low-resolution feature map from decoder. skip_channels (int): Number of input channels of the low-level high-resolution feature map from encoder. out_channels (int): Number of output channels. num_convs (int): Number of convolutional layers in the conv_block. Default: 2. stride (int): Stride of convolutional layer in conv_block. Default: 1. dilation (int): Dilation rate of convolutional layer in conv_block. Default: 1. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. conv_cfg (dict | None): Config dict for convolution layer. Default: None. norm_cfg (dict | None): Config dict for normalization layer. Default: dict(type='BN'). act_cfg (dict | None): Config dict for activation layer in ConvModule. Default: dict(type='ReLU'). upsample_cfg (dict): The upsample config of the upsample module in decoder. Default: dict(type='InterpConv'). If the size of the high-level feature map is the same as that of the skip feature map (the low-level feature map from the encoder), the high-level feature map does not need to be upsampled and upsample_cfg should be None. dcn (bool): Use deformable convolution in convolutional layer or not. Default: None. plugins (dict): plugins for convolutional layers. Default: None. """ def __init__(self, conv_block, in_channels, skip_channels, out_channels, num_convs=2, stride=1, dilation=1, with_cp=False, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), upsample_cfg=dict(type='InterpConv'), dcn=None, plugins=None): super(UpConvBlock, self).__init__() assert dcn is None, 'Not implemented yet.' assert plugins is None, 'Not implemented yet.'
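# Editor's note: a hedged shape walk-through of the block assembled below, assuming an
# upsample factor of 2 (the actual factor is determined by upsample_cfg, not fixed here).
# Given a decoder feature x of shape (N, in_channels, H, W) and an encoder skip feature of
# shape (N, skip_channels, 2H, 2W), forward() maps x to (N, skip_channels, 2H, 2W) via
# self.upsample, concatenates it with the skip feature into (N, 2 * skip_channels, 2H, 2W),
# and self.conv_block (built with in_channels=2 * skip_channels) reduces the result to
# (N, out_channels, 2H, 2W).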
self.conv_block = conv_block( in_channels=2 * skip_channels, out_channels=out_channels, num_convs=num_convs, stride=stride, dilation=dilation, with_cp=with_cp, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, dcn=None, plugins=None) if upsample_cfg is not None: self.upsample = build_upsample_layer( cfg=upsample_cfg, in_channels=in_channels, out_channels=skip_channels, with_cp=with_cp, norm_cfg=norm_cfg, act_cfg=act_cfg) else: self.upsample = ConvModule( in_channels, skip_channels, kernel_size=1, stride=1, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) def forward(self, skip, x): """Forward function.""" x = self.upsample(x) out = torch.cat([skip, x], dim=1) out = self.conv_block(out) return out ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/ops/__init__.py ================================================ from .encoding import Encoding from .wrappers import Upsample, resize __all__ = ['Upsample', 'resize', 'Encoding'] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/ops/encoding.py ================================================ import torch from torch import nn as nn from torch.nn import functional as F class Encoding(nn.Module): """Encoding Layer: a learnable residual encoder. Input is of shape (batch_size, channels, height, width). Output is of shape (batch_size, num_codes, channels). Args: channels: dimension of the features or feature channels num_codes: number of code words """ def __init__(self, channels, num_codes): super(Encoding, self).__init__() # init codewords and smoothing factor self.channels, self.num_codes = channels, num_codes std = 1. / ((num_codes * channels)**0.5) # [num_codes, channels] self.codewords = nn.Parameter( torch.empty(num_codes, channels, dtype=torch.float).uniform_(-std, std), requires_grad=True) # [num_codes] self.scale = nn.Parameter( torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), requires_grad=True) @staticmethod def scaled_l2(x, codewords, scale): num_codes, channels = codewords.size() batch_size = x.size(0) reshaped_scale = scale.view((1, 1, num_codes)) expanded_x = x.unsqueeze(2).expand( (batch_size, x.size(1), num_codes, channels)) reshaped_codewords = codewords.view((1, 1, num_codes, channels)) scaled_l2_norm = reshaped_scale * ( expanded_x - reshaped_codewords).pow(2).sum(dim=3) return scaled_l2_norm @staticmethod def aggregate(assigment_weights, x, codewords): num_codes, channels = codewords.size() reshaped_codewords = codewords.view((1, 1, num_codes, channels)) batch_size = x.size(0) expanded_x = x.unsqueeze(2).expand( (batch_size, x.size(1), num_codes, channels)) encoded_feat = (assigment_weights.unsqueeze(3) * (expanded_x - reshaped_codewords)).sum(dim=1) return encoded_feat def forward(self, x): assert x.dim() == 4 and x.size(1) == self.channels # [batch_size, channels, height, width] batch_size = x.size(0) # [batch_size, height x width, channels] x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() # assignment_weights: [batch_size, channels, num_codes] assigment_weights = F.softmax( self.scaled_l2(x, self.codewords, self.scale), dim=2) # aggregate encoded_feat = self.aggregate(assigment_weights, x, self.codewords) return encoded_feat def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ f'x{self.channels})' return repr_str ================================================ FILE: 
downstream_tasks/semantic_segmentation/mmseg/ops/wrappers.py ================================================ import warnings import torch import torch.nn as nn import torch.nn.functional as F def resize(input, size=None, scale_factor=None, mode='nearest', align_corners=None, warning=True): if warning: if size is not None and align_corners: input_h, input_w = tuple(int(x) for x in input.shape[2:]) output_h, output_w = tuple(int(x) for x in size) if output_h > input_h or output_w > input_w: if ((output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) and (output_h - 1) % (input_h - 1) and (output_w - 1) % (input_w - 1)): warnings.warn( f'When align_corners={align_corners}, ' 'the output would be more aligned if ' f'input size {(input_h, input_w)} is `x+1` and ' f'out size {(output_h, output_w)} is `nx+1`') if isinstance(size, torch.Size): size = tuple(int(x) for x in size) return F.interpolate(input, size, scale_factor, mode, align_corners) class Upsample(nn.Module): def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None): super(Upsample, self).__init__() self.size = size if isinstance(scale_factor, tuple): self.scale_factor = tuple(float(factor) for factor in scale_factor) else: self.scale_factor = float(scale_factor) if scale_factor else None self.mode = mode self.align_corners = align_corners def forward(self, x): if not self.size: size = [int(t * self.scale_factor) for t in x.shape[-2:]] else: size = self.size return resize(x, size, None, self.mode, self.align_corners) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/utils/__init__.py ================================================ from .collect_env import collect_env from .logger import get_root_logger __all__ = ['get_root_logger', 'collect_env'] ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/utils/collect_env.py ================================================ from mmcv.utils import collect_env as collect_base_env from mmcv.utils import get_git_hash import mmseg def collect_env(): """Collect the information of the running environments.""" env_info = collect_base_env() env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' return env_info if __name__ == '__main__': for name, val in collect_env().items(): print('{}: {}'.format(name, val)) ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/utils/logger.py ================================================ import logging from mmcv.utils import get_logger def get_root_logger(log_file=None, log_level=logging.INFO): """Get the root logger. The logger will be initialized if it has not been initialized. By default a StreamHandler will be added. If `log_file` is specified, a FileHandler will also be added. The name of the root logger is the top-level package name, e.g., "mmseg". Args: log_file (str | None): The log filename. If specified, a FileHandler will be added to the root logger. log_level (int): The root logger level. Note that only the process of rank 0 is affected, while other processes will set the level to "Error" and be silent most of the time. Returns: logging.Logger: The root logger. """ logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) return logger ================================================ FILE: downstream_tasks/semantic_segmentation/mmseg/version.py ================================================ # Copyright (c) Open-MMLab. All rights reserved.
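# Editor's note: hedged examples of parse_version_info (defined below); the inputs are
# illustrative version strings, not versions shipped by this repository.
# parse_version_info('0.11.0') -> (0, 11, 0)
# parse_version_info('1.0rc1') -> (1, 0, 'rc1')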
__version__ = '0.11.0' def parse_version_info(version_str): version_info = [] for x in version_str.split('.'): if x.isdigit(): version_info.append(int(x)) elif x.find('rc') != -1: patch_version = x.split('rc') version_info.append(int(patch_version[0])) version_info.append(f'rc{patch_version[1]}') return tuple(version_info) version_info = parse_version_info(__version__) ================================================ FILE: downstream_tasks/semantic_segmentation/tools/dist_test.sh ================================================ #!/usr/bin/env bash CONFIG=$1 CHECKPOINT=$2 GPUS=$3 PORT=7956 OMP_NUM_THREADS=1 python -m torch.distributed.launch \ --nproc_per_node=$GPUS \ --master_port=$PORT \ tools/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} ================================================ FILE: downstream_tasks/semantic_segmentation/tools/dist_train.sh ================================================ #!/usr/bin/env bash CONFIG=$1 GPUS=$2 PORT=7956 OMP_NUM_THREADS=1 python -m torch.distributed.launch \ --nproc_per_node=$GPUS \ --master_port=$PORT \ tools/train.py $CONFIG --launcher pytorch ${@:3} ================================================ FILE: downstream_tasks/semantic_segmentation/tools/test.py ================================================ import argparse import os import mmcv import torch from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import get_dist_info, init_dist, load_checkpoint from mmcv.utils import DictAction from mmseg.apis import multi_gpu_test, single_gpu_test from mmseg.datasets import build_dataloader, build_dataset from mmseg.models import build_segmentor from backbone import beit from backbone import cae def parse_args(): parser = argparse.ArgumentParser( description='mmseg test (and eval) a model') parser.add_argument('config', help='test config file path') parser.add_argument('checkpoint', help='checkpoint file') parser.add_argument( '--aug-test', action='store_true', help='Use Flip and Multi scale aug') parser.add_argument('--out', help='output result file in pickle format') parser.add_argument( '--format-only', action='store_true', help='Format the output results without performing evaluation.
It is ' 'useful when you want to format the result to a specific format and ' 'submit it to the test server') parser.add_argument( '--eval', type=str, nargs='+', help='evaluation metrics, which depend on the dataset, e.g., "mIoU"' ' for generic datasets, and "cityscapes" for Cityscapes') parser.add_argument('--show', action='store_true', help='show results') parser.add_argument( '--show-dir', help='directory where painted images will be saved') parser.add_argument( '--gpu-collect', action='store_true', help='whether to use gpu to collect results.') parser.add_argument( '--tmpdir', help='tmp directory used for collecting results from multiple ' 'workers, available when gpu_collect is not specified') parser.add_argument( '--options', nargs='+', action=DictAction, help='custom options') parser.add_argument( '--eval-options', nargs='+', action=DictAction, help='custom options for evaluation') parser.add_argument( '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) return args def main(): args = parse_args() assert args.out or args.eval or args.format_only or args.show \ or args.show_dir, \ ('Please specify at least one operation (save/eval/format/show the ' 'results) with the argument "--out", "--eval"' ', "--format-only", "--show" or "--show-dir"') if args.eval and args.format_only: raise ValueError('--eval and --format-only cannot both be specified') if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): raise ValueError('The output file must be a pkl file.') cfg = mmcv.Config.fromfile(args.config) if args.options is not None: cfg.merge_from_dict(args.options) # set cudnn_benchmark if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True if args.aug_test: # hard code index cfg.data.test.pipeline[1].img_ratios = [ 0.5, 0.75, 1.0, 1.25, 1.5, 1.75 ] cfg.data.test.pipeline[1].flip = True cfg.model.pretrained = None cfg.data.test.test_mode = True # init distributed env first, since logger depends on the dist info.
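# Editor's note: a hedged usage sketch; the config/checkpoint paths and GPU count are
# placeholders, but dist_test.sh above wraps exactly this launch pattern:
#   python -m torch.distributed.launch --nproc_per_node=8 --master_port=7956 \
#       tools/test.py configs/some_config.py work_dirs/ckpt.pth --launcher pytorch --eval mIoU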
if args.launcher == 'none': distributed = False else: distributed = True init_dist(args.launcher, **cfg.dist_params) # build the dataloader # TODO: support multiple images per gpu (only minor changes are needed) dataset = build_dataset(cfg.data.test) data_loader = build_dataloader( dataset, samples_per_gpu=1, workers_per_gpu=cfg.data.workers_per_gpu, dist=distributed, shuffle=False) # build the model and load checkpoint cfg.model.train_cfg = None model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')) checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') model.CLASSES = checkpoint['meta']['CLASSES'] model.PALETTE = checkpoint['meta']['PALETTE'] efficient_test = False if args.eval_options is not None: efficient_test = args.eval_options.get('efficient_test', False) if not distributed: model = MMDataParallel(model, device_ids=[0]) outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, efficient_test) else: model = MMDistributedDataParallel( model.cuda(), device_ids=[torch.cuda.current_device()], broadcast_buffers=False) outputs = multi_gpu_test(model, data_loader, args.tmpdir, args.gpu_collect, efficient_test) rank, _ = get_dist_info() if rank == 0: if args.out: print(f'\nwriting results to {args.out}') mmcv.dump(outputs, args.out) kwargs = {} if args.eval_options is None else args.eval_options if args.format_only: dataset.format_results(outputs, **kwargs) if args.eval: dataset.evaluate(outputs, args.eval, **kwargs) if __name__ == '__main__': main() ================================================ FILE: downstream_tasks/semantic_segmentation/tools/train.py ================================================ import argparse import copy import os import os.path as osp import time import mmcv import mmcv_custom import torch from mmcv.runner import init_dist from mmcv.utils import Config, DictAction, get_git_hash from mmseg import __version__ from mmseg.apis import set_random_seed from mmcv_custom import train_segmentor from mmseg.datasets import build_dataset from mmseg.models import build_segmentor from mmseg.utils import collect_env, get_root_logger from backbone import beit from backbone import mae from backbone import beit_fapn from backbone import cae def parse_args(): parser = argparse.ArgumentParser(description='Train a segmentor') parser.add_argument('config', help='train config file path') parser.add_argument('--work-dir', help='the dir to save logs and models') parser.add_argument( '--load-from', help='the checkpoint file to load weights from') parser.add_argument( '--resume-from', help='the checkpoint file to resume from') parser.add_argument( '--no-validate', action='store_true', help='whether not to evaluate the checkpoint during training') group_gpus = parser.add_mutually_exclusive_group() group_gpus.add_argument( '--gpus', type=int, help='number of gpus to use ' '(only applicable to non-distributed training)') group_gpus.add_argument( '--gpu-ids', type=int, nargs='+', help='ids of gpus to use ' '(only applicable to non-distributed training)') parser.add_argument('--seed', type=int, default=None, help='random seed') parser.add_argument( '--deterministic', action='store_true', help='whether to set deterministic options for CUDNN backend.') parser.add_argument( '--options', nargs='+', action=DictAction, help='custom options') parser.add_argument( '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not 
in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) return args def main(): args = parse_args() cfg = Config.fromfile(args.config) if args.options is not None: cfg.merge_from_dict(args.options) # set cudnn_benchmark if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True # work_dir is determined in this priority: CLI > segment in file > filename if args.work_dir is not None: # update configs according to CLI args if args.work_dir is not None cfg.work_dir = args.work_dir elif cfg.get('work_dir', None) is None: # use config filename as default work_dir if cfg.work_dir is None cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) if args.load_from is not None: cfg.load_from = args.load_from if args.resume_from is not None: cfg.resume_from = args.resume_from if args.gpu_ids is not None: cfg.gpu_ids = args.gpu_ids else: cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) # init distributed env first, since logger depends on the dist info. if args.launcher == 'none': distributed = False else: distributed = True init_dist(args.launcher, **cfg.dist_params) # create work_dir mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) # dump config cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) # init the logger before other steps timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) log_file = osp.join(cfg.work_dir, f'{timestamp}.log') logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) # init the meta dict to record some important information such as # environment info and seed, which will be logged meta = dict() # log env info env_info_dict = collect_env() env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) dash_line = '-' * 60 + '\n' logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line) meta['env_info'] = env_info # log some basic info logger.info(f'Distributed training: {distributed}') logger.info(f'Config:\n{cfg.pretty_text}') # set random seeds if args.seed is not None: logger.info(f'Set random seed to {args.seed}, deterministic: ' f'{args.deterministic}') set_random_seed(args.seed, deterministic=args.deterministic) cfg.seed = args.seed meta['seed'] = args.seed meta['exp_name'] = osp.basename(args.config) model = build_segmentor( cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg')) logger.info(model) datasets = [build_dataset(cfg.data.train)] if len(cfg.workflow) == 2: val_dataset = copy.deepcopy(cfg.data.val) val_dataset.pipeline = cfg.data.train.pipeline datasets.append(build_dataset(val_dataset)) if cfg.checkpoint_config is not None: # save mmseg version, config file content and class names in # checkpoints as meta data cfg.checkpoint_config.meta = dict( mmseg_version=f'{__version__}+{get_git_hash()[:7]}', config=cfg.pretty_text, CLASSES=datasets[0].CLASSES, PALETTE=datasets[0].PALETTE) # add an attribute for visualization convenience model.CLASSES = datasets[0].CLASSES train_segmentor( model, datasets, cfg, distributed=distributed, validate=(not args.no_validate), timestamp=timestamp, meta=meta) if __name__ == '__main__': main() ================================================ FILE: furnace/dataset_folder.py ================================================ from torchvision.datasets.vision import VisionDataset from PIL import Image import os import os.path import random from typing import Any, Callable, cast, Dict, List, Optional, Tuple def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: """Checks if a 
file is an allowed extension. Args: filename (string): path to a file extensions (tuple of strings): extensions to consider (lowercase) Returns: bool: True if the filename ends with one of given extensions """ return filename.lower().endswith(extensions) def is_image_file(filename: str) -> bool: """Checks if a file is an allowed image extension. Args: filename (string): path to a file Returns: bool: True if the filename ends with a known image extension """ return has_file_allowed_extension(filename, IMG_EXTENSIONS) def make_dataset( directory: str, class_to_idx: Dict[str, int], extensions: Optional[Tuple[str, ...]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, ) -> List[Tuple[str, int]]: instances = [] directory = os.path.expanduser(directory) both_none = extensions is None and is_valid_file is None both_something = extensions is not None and is_valid_file is not None if both_none or both_something: raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") if extensions is not None: def is_valid_file(x: str) -> bool: return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) is_valid_file = cast(Callable[[str], bool], is_valid_file) for target_class in sorted(class_to_idx.keys()): class_index = class_to_idx[target_class] target_dir = os.path.join(directory, target_class) if not os.path.isdir(target_dir): continue for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): for fname in sorted(fnames): path = os.path.join(root, fname) if is_valid_file(path): item = path, class_index instances.append(item) return instances class DatasetFolder(VisionDataset): """A generic data loader where the samples are arranged in this way: :: root/class_x/xxx.ext root/class_x/xxy.ext root/class_x/xxz.ext root/class_y/123.ext root/class_y/nsdf3.ext root/class_y/asd932_.ext Args: root (string): Root directory path. loader (callable): A function to load a sample given its path. extensions (tuple[string]): A list of allowed extensions. both extensions and is_valid_file should not be passed. transform (callable, optional): A function/transform that takes in a sample and returns a transformed version. E.g, ``transforms.RandomCrop`` for images. target_transform (callable, optional): A function/transform that takes in the target and transforms it. is_valid_file (callable, optional): A function that takes path of a file and check if the file is a valid file (used to check of corrupt files) both extensions and is_valid_file should not be passed. Attributes: classes (list): List of the class names sorted alphabetically. class_to_idx (dict): Dict with items (class_name, class_index). 
samples (list): List of (sample path, class_index) tuples targets (list): The class_index value for each image in the dataset """ def __init__( self, root: str, loader: Callable[[str], Any], extensions: Optional[Tuple[str, ...]] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, is_valid_file: Optional[Callable[[str], bool]] = None, ) -> None: super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) classes, class_to_idx = self._find_classes(self.root) samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) if len(samples) == 0: msg = "Found 0 files in subfolders of: {}\n".format(self.root) if extensions is not None: msg += "Supported extensions are: {}".format(",".join(extensions)) raise RuntimeError(msg) self.loader = loader self.extensions = extensions self.classes = classes self.class_to_idx = class_to_idx self.samples = samples self.targets = [s[1] for s in samples] def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: """ Finds the class folders in a dataset. Args: dir (string): Root directory path. Returns: tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. Ensures: No class is a subdirectory of another. """ classes = [d.name for d in os.scandir(dir) if d.is_dir()] classes.sort() class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} return classes, class_to_idx def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: index (int): Index Returns: tuple: (sample, target) where target is class_index of the target class. """ while True: try: path, target = self.samples[index] sample = self.loader(path) break except Exception as e: print(e) index = random.randint(0, len(self.samples) - 1) if self.transform is not None: sample = self.transform(sample) if self.target_transform is not None: target = self.target_transform(target) return sample, target def __len__(self) -> int: return len(self.samples) IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') def pil_loader(path: str) -> Image.Image: # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) with open(path, 'rb') as f: img = Image.open(f) return img.convert('RGB') # TODO: specify the return type def accimage_loader(path: str) -> Any: import accimage try: return accimage.Image(path) except IOError: # Potentially a decoding problem, fall back to PIL.Image return pil_loader(path) def default_loader(path: str) -> Any: from torchvision import get_image_backend if get_image_backend() == 'accimage': return accimage_loader(path) else: return pil_loader(path) class ImageFolder(DatasetFolder): """A generic data loader where the images are arranged in this way: :: root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png Args: root (string): Root directory path. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. loader (callable, optional): A function to load an image given its path. is_valid_file (callable, optional): A function that takes path of an Image file and check if the file is a valid file (used to check of corrupt files) Attributes: classes (list): List of the class names sorted alphabetically. 
class_to_idx (dict): Dict with items (class_name, class_index). imgs (list): List of (image path, class_index) tuples """ def __init__( self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, loader: Callable[[str], Any] = default_loader, is_valid_file: Optional[Callable[[str], bool]] = None, ): super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, transform=transform, target_transform=target_transform, is_valid_file=is_valid_file) self.imgs = self.samples ================================================ FILE: furnace/datasets.py ================================================ import os import torch from torchvision import datasets, transforms from timm.data.constants import \ IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from furnace.transforms import RandomResizedCropAndInterpolationWithTwoPic from timm.data import create_transform from dall_e.utils import map_pixels from furnace.masking_generator import MaskingGenerator, RandomMaskingGenerator from furnace.dataset_folder import ImageFolder def preprocess_vqgan(x): x = 2.*x - 1. return x class DataAugmentationForCAE(object): def __init__(self, args): imagenet_default_mean_and_std = args.imagenet_default_mean_and_std mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD if args.color_jitter > 0: self.common_transform = transforms.Compose([ transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter), transforms.RandomHorizontalFlip(p=0.5), RandomResizedCropAndInterpolationWithTwoPic( size=args.input_size, second_size=args.second_input_size, interpolation=args.train_interpolation, second_interpolation=args.second_interpolation, scale=(args.crop_min_size, args.crop_max_size), ), ]) else: self.common_transform = transforms.Compose([ transforms.RandomHorizontalFlip(p=0.5), RandomResizedCropAndInterpolationWithTwoPic( size=args.input_size, second_size=args.second_input_size, interpolation=args.train_interpolation, second_interpolation=args.second_interpolation, scale=(args.crop_min_size, args.crop_max_size), ), ]) self.patch_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize( mean=torch.tensor(mean), std=torch.tensor(std)) ]) if args.discrete_vae_type == "dall-e": self.visual_token_transform = transforms.Compose([ transforms.ToTensor(), map_pixels, ]) elif args.discrete_vae_type == "vqgan_gumbel_f8_8192": self.visual_token_transform = transforms.Compose([ transforms.ToTensor(), preprocess_vqgan, ]) elif args.discrete_vae_type == "customized": self.visual_token_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize( mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, ), ]) else: raise NotImplementedError() if args.mask_generator == 'block': self.masked_position_generator = MaskingGenerator( args.window_size, num_masking_patches=args.num_mask_patches, max_num_patches=args.max_mask_patches_per_block, min_num_patches=args.min_mask_patches_per_block, ) elif args.mask_generator == 'random': self.masked_position_generator = RandomMaskingGenerator( args.window_size, ratio_masking_patches=args.ratio_mask_patches ) def __call__(self, image): for_patches, for_visual_tokens = self.common_transform(image) return \ self.patch_transform(for_patches), self.visual_token_transform(for_visual_tokens), \ 
self.masked_position_generator() def __repr__(self): repr = "(DataAugmentationForCAE,\n" repr += " common_transform = %s,\n" % str(self.common_transform) repr += " patch_transform = %s,\n" % str(self.patch_transform) repr += " visual_tokens_transform = %s,\n" % str(self.visual_token_transform) repr += " Masked position generator = %s,\n" % str(self.masked_position_generator) repr += ")" return repr def build_cae_pretraining_dataset(args): transform = DataAugmentationForCAE(args) print("Data Aug = %s" % str(transform)) return ImageFolder(args.data_path, transform=transform) def build_dataset(is_train, args): transform = build_transform(is_train, args) print("Transform = ") if isinstance(transform, tuple): for trans in transform: print(" - - - - - - - - - - ") for t in trans.transforms: print(t) else: for t in transform.transforms: print(t) print("---------------------------") if args.data_set == 'CIFAR': dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) nb_classes = 100 elif args.data_set == 'IMNET': root = os.path.join(args.data_path, 'train' if is_train else 'val') dataset = datasets.ImageFolder(root, transform=transform) nb_classes = 1000 elif args.data_set == "image_folder": root = args.data_path if is_train else args.eval_data_path dataset = ImageFolder(root, transform=transform) nb_classes = args.nb_classes assert len(dataset.class_to_idx) == nb_classes else: raise NotImplementedError() assert nb_classes == args.nb_classes print("Number of the class = %d" % args.nb_classes) return dataset, nb_classes def build_transform(is_train, args): resize_im = args.input_size > 32 imagenet_default_mean_and_std = args.imagenet_default_mean_and_std mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD if is_train: # this should always dispatch to transforms_imagenet_train transform = create_transform( input_size=args.input_size, is_training=True, color_jitter=args.color_jitter, auto_augment=args.aa, interpolation=args.train_interpolation, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, mean=mean, std=std, ) if not resize_im: # replace RandomResizedCropAndInterpolation with # RandomCrop transform.transforms[0] = transforms.RandomCrop( args.input_size, padding=4) return transform t = [] if resize_im: if args.crop_pct is None: if args.input_size < 384: args.crop_pct = 224 / 256 else: args.crop_pct = 1.0 size = int(args.input_size / args.crop_pct) t.append( transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 
224 images ) t.append(transforms.CenterCrop(args.input_size)) t.append(transforms.ToTensor()) t.append(transforms.Normalize(mean, std)) return transforms.Compose(t) ================================================ FILE: furnace/engine_for_finetuning.py ================================================ import math import sys import time from typing import Iterable, Optional import torch from timm.data import Mixup from timm.utils import accuracy, ModelEma import furnace.utils as utils def train_class_batch(model, samples, target, criterion): outputs = model(samples) loss = criterion(outputs, target) return loss, outputs def get_loss_scale_for_deepspeed(model): optimizer = model.optimizer return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, num_training_steps_per_epoch=None, update_freq=None): model.train(True) metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 10 if loss_scaler is None: model.zero_grad() model.micro_steps = 0 else: optimizer.zero_grad() for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): step = data_iter_step // update_freq if step >= num_training_steps_per_epoch: continue it = start_steps + step # global training iteration # Update LR & WD only at the first micro-step of each accumulation cycle if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0: for i, param_group in enumerate(optimizer.param_groups): if lr_schedule_values is not None: if "lr_scale" in param_group: param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] else: param_group["lr"] = lr_schedule_values[it] if wd_schedule_values is not None and param_group["weight_decay"] > 0: param_group["weight_decay"] = wd_schedule_values[it] samples = samples.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) if mixup_fn is not None: samples, targets = mixup_fn(samples, targets) if loss_scaler is None: samples = samples.half() loss, output = train_class_batch( model, samples, targets, criterion) else: with torch.cuda.amp.autocast(): loss, output = train_class_batch( model, samples, targets, criterion) loss_value = loss.item() if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) sys.exit(1) if loss_scaler is None: loss /= update_freq model.backward(loss) model.step() if (data_iter_step + 1) % update_freq == 0: # model.zero_grad() # Deepspeed will call step() & model.zero_grad() automatically if model_ema is not None: model_ema.update(model) grad_norm = None loss_scale_value = get_loss_scale_for_deepspeed(model) else: # this attribute is added by timm on one optimizer (adahessian) is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order loss /= update_freq grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=is_second_order, update_grad=(data_iter_step + 1) % update_freq == 0) if (data_iter_step +
1) % update_freq == 0: optimizer.zero_grad() if model_ema is not None: model_ema.update(model) loss_scale_value = loss_scaler.state_dict()["scale"] torch.cuda.synchronize() if mixup_fn is None: class_acc = (output.max(-1)[-1] == targets).float().mean() else: class_acc = None metric_logger.update(loss=loss_value) metric_logger.update(class_acc=class_acc) metric_logger.update(loss_scale=loss_scale_value) min_lr = 10. max_lr = 0. for group in optimizer.param_groups: min_lr = min(min_lr, group["lr"]) max_lr = max(max_lr, group["lr"]) metric_logger.update(lr=max_lr) metric_logger.update(min_lr=min_lr) weight_decay_value = None for group in optimizer.param_groups: if group["weight_decay"] > 0: weight_decay_value = group["weight_decay"] metric_logger.update(weight_decay=weight_decay_value) metric_logger.update(grad_norm=grad_norm) if log_writer is not None: log_writer.update(loss=loss_value, head="loss") log_writer.update(class_acc=class_acc, head="loss") log_writer.update(loss_scale=loss_scale_value, head="opt") log_writer.update(lr=max_lr, head="opt") log_writer.update(min_lr=min_lr, head="opt") log_writer.update(weight_decay=weight_decay_value, head="opt") log_writer.update(grad_norm=grad_norm, head="opt") log_writer.set_step() # gather the stats from all processes metric_logger.synchronize_between_processes() now_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) print(now_time, "Averaged stats:", metric_logger) # print("Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} @torch.no_grad() def evaluate(data_loader, model, device): criterion = torch.nn.CrossEntropyLoss() metric_logger = utils.MetricLogger(delimiter=" ") header = 'Test:' # switch to evaluation mode model.eval() for batch in metric_logger.log_every(data_loader, 10, header): images = batch[0] target = batch[-1] images = images.to(device, non_blocking=True) target = target.to(device, non_blocking=True) # compute output with torch.cuda.amp.autocast(): output = model(images) loss = criterion(output, target) acc1, acc5 = accuracy(output, target, topk=(1, 5)) batch_size = images.shape[0] metric_logger.update(loss=loss.item()) metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) # gather the stats from all processes metric_logger.synchronize_between_processes() print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} ================================================ FILE: furnace/engine_for_pretraining.py ================================================ import math import sys import time from typing import Iterable import torch import torch.nn as nn import furnace.utils as utils import torch.nn.functional as F from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD def loss_selector(loss_type, pred, target): if loss_type == 'mse': return F.mse_loss(pred, target, reduction="mean") elif loss_type == 'kld': return F.kl_div(F.log_softmax(pred, dim=-1), F.softmax(target, dim=-1), reduction='mean') def train_one_epoch(model: torch.nn.Module, d_vae: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, log_writer=None, lr_scheduler=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 
args=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 10 for step, (batch, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): # assign learning rate & weight decay for each step it = start_steps + step # global training iteration if lr_schedule_values is not None or wd_schedule_values is not None: for i, param_group in enumerate(optimizer.param_groups): if lr_schedule_values is not None: param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] if wd_schedule_values is not None and param_group["weight_decay"] > 0: param_group["weight_decay"] = wd_schedule_values[it] samples, images, bool_masked_pos = batch images = images.to(device, non_blocking=True) samples = samples.to(device, non_blocking=True) bool_masked_pos = bool_masked_pos.to(device, non_blocking=True) with torch.no_grad(): input_ids = d_vae.get_codebook_indices(images).flatten(1) bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool) labels = input_ids[bool_masked_pos] with torch.cuda.amp.autocast(): outputs, latent, latent_target = model(samples, bool_masked_pos=bool_masked_pos, return_all_tokens=False) loss_main = nn.CrossEntropyLoss()(input=outputs.float(), target=labels) loss_align = args.align_loss_weight * loss_selector('mse', latent.float(), latent_target.detach().float()) loss = loss_main + loss_align loss_value = loss.item() loss_main_value = loss_main.item() loss_align_value = loss_align.item() if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) sys.exit(1) optimizer.zero_grad() # this attribute is added by timm on one optimizer (adahessian) is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=is_second_order) loss_scale_value = loss_scaler.state_dict()["scale"] torch.cuda.synchronize() mlm_acc = (outputs.max(-1)[1] == labels).float().mean().item() metric_logger.update(mlm_acc=mlm_acc) if log_writer is not None: log_writer.update(mlm_acc=mlm_acc, head="loss") metric_logger.update(loss=loss_value) metric_logger.update(loss_main=loss_main_value) metric_logger.update(loss_align=loss_align_value) metric_logger.update(loss_scale=loss_scale_value) min_lr = 10. max_lr = 0. 
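# Editor's note: a hedged summary of the objective computed above. The CAE pretraining
# loss combines a cross-entropy over the d-VAE token ids of the masked patches with a
# weighted latent-alignment MSE; with an assumed align_loss_weight of 2 (illustrative,
# the real value comes from args) and loss_main = 6.9, an alignment MSE of 0.05 gives
# loss = 6.9 + 2 * 0.05 = 7.0.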
for group in optimizer.param_groups: min_lr = min(min_lr, group["lr"]) max_lr = max(max_lr, group["lr"]) metric_logger.update(lr=max_lr) metric_logger.update(min_lr=min_lr) weight_decay_value = None for group in optimizer.param_groups: if group["weight_decay"] > 0: weight_decay_value = group["weight_decay"] metric_logger.update(weight_decay=weight_decay_value) metric_logger.update(grad_norm=grad_norm) if log_writer is not None: log_writer.update(loss=loss_value, head="loss") log_writer.update(loss=loss_main_value, head="loss_main") log_writer.update(loss=loss_align_value, head="loss_align") log_writer.update(loss_scale=loss_scale_value, head="opt") log_writer.update(lr=max_lr, head="opt") log_writer.update(min_lr=min_lr, head="opt") log_writer.update(weight_decay=weight_decay_value, head="opt") log_writer.update(grad_norm=grad_norm, head="opt") log_writer.set_step() if lr_scheduler is not None: lr_scheduler.step_update(start_steps + step) # gather the stats from all processes metric_logger.synchronize_between_processes() now_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) print(now_time, "Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} ================================================ FILE: furnace/masking_generator.py ================================================ import random import math import numpy as np class MaskingGenerator: def __init__( self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None, min_aspect=0.3, max_aspect=None): if not isinstance(input_size, tuple): input_size = (input_size, ) * 2 self.height, self.width = input_size self.num_patches = self.height * self.width self.num_masking_patches = num_masking_patches self.min_num_patches = min_num_patches self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches max_aspect = max_aspect or 1 / min_aspect self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) def __repr__(self): repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( self.height, self.width, self.min_num_patches, self.max_num_patches, self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1]) return repr_str def get_shape(self): return self.height, self.width def _mask(self, mask, max_mask_patches): delta = 0 for attempt in range(10): target_area = random.uniform(self.min_num_patches, max_mask_patches) aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) h = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio))) if w < self.width and h < self.height: top = random.randint(0, self.height - h) left = random.randint(0, self.width - w) num_masked = mask[top: top + h, left: left + w].sum() # Overlap if 0 < h * w - num_masked <= max_mask_patches: for i in range(top, top + h): for j in range(left, left + w): if mask[i, j] == 0: mask[i, j] = 1 delta += 1 if delta > 0: break return delta def __call__(self): mask = np.zeros(shape=self.get_shape(), dtype=int) mask_count = 0 while mask_count != self.num_masking_patches: max_mask_patches = self.num_masking_patches - mask_count max_mask_patches = min(max_mask_patches, self.max_num_patches) delta = self._mask(mask, max_mask_patches) mask_count += delta return mask class RandomMaskingGenerator: def __init__( self, input_size, ratio_masking_patches): if not isinstance(input_size, tuple): input_size = (input_size, ) * 2 self.height, self.width = input_size self.num_patches = self.height * self.width 
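# Editor's note: a hedged worked example with assumed values. For input_size=14 and
# ratio_masking_patches=0.5 (illustrative), num_patches = 14 * 14 = 196 and
# int(0.5 * 196) = 98 patches are masked; unlike the blockwise MaskingGenerator above,
# __call__ returns a flat, shuffled 0/1 vector of length 196 rather than a 2-D mask.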
self.num_masking_patches = int(ratio_masking_patches * self.num_patches) def __repr__(self): repr_str = "Mask: total patches {}, mask patches {}".format( self.num_patches, self.num_masking_patches ) return repr_str def __call__(self): mask = np.hstack([ np.zeros(self.num_patches - self.num_masking_patches), np.ones(self.num_masking_patches), ]) np.random.shuffle(mask) return mask ================================================ FILE: furnace/optim_factory.py ================================================ import torch from torch import optim as optim from timm.optim.adafactor import Adafactor from timm.optim.adahessian import Adahessian from timm.optim.adamp import AdamP from timm.optim.lookahead import Lookahead from timm.optim.nadam import Nadam from timm.optim.novograd import NovoGrad from timm.optim.nvnovograd import NvNovoGrad from timm.optim.radam import RAdam from timm.optim.rmsprop_tf import RMSpropTF from timm.optim.sgdp import SGDP import json try: from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD has_apex = True except ImportError: has_apex = False def get_num_layer_for_vit(var_name, num_max_layer): if var_name in ("cls_token", "mask_token", "pos_embed"): return 0 elif var_name.startswith("patch_embed"): return 0 elif var_name.startswith("rel_pos_bias"): return num_max_layer - 1 elif var_name.startswith("blocks"): layer_id = int(var_name.split('.')[1]) return layer_id + 1 else: return num_max_layer - 1 class LayerDecayValueAssigner(object): def __init__(self, values): self.values = values def get_scale(self, layer_id): return self.values[layer_id] def get_layer_id(self, var_name): return get_num_layer_for_vit(var_name, len(self.values)) def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): parameter_group_names = {} parameter_group_vars = {} for name, param in model.named_parameters(): if not param.requires_grad: continue # frozen weights if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: group_name = "no_decay" this_weight_decay = 0. else: group_name = "decay" this_weight_decay = weight_decay if get_num_layer is not None: layer_id = get_num_layer(name) group_name = "layer_%d_%s" % (layer_id, group_name) else: layer_id = None if group_name not in parameter_group_names: if get_layer_scale is not None: scale = get_layer_scale(layer_id) else: scale = 1. parameter_group_names[group_name] = { "weight_decay": this_weight_decay, "params": [], "lr_scale": scale } parameter_group_vars[group_name] = { "weight_decay": this_weight_decay, "params": [], "lr_scale": scale } parameter_group_vars[group_name]["params"].append(param) parameter_group_names[group_name]["params"].append(name) print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) return list(parameter_group_vars.values()) def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): opt_lower = args.opt.lower() weight_decay = args.weight_decay if weight_decay and filter_bias_and_bn: skip = {} if skip_list is not None: skip = skip_list elif hasattr(model, 'no_weight_decay'): skip = model.no_weight_decay() parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) weight_decay = 0.
else: parameters = model.parameters() if 'fused' in opt_lower: assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' opt_args = dict(lr=args.lr, weight_decay=weight_decay) if hasattr(args, 'opt_eps') and args.opt_eps is not None: opt_args['eps'] = args.opt_eps if hasattr(args, 'opt_betas') and args.opt_betas is not None: opt_args['betas'] = args.opt_betas opt_split = opt_lower.split('_') opt_lower = opt_split[-1] if opt_lower == 'sgd' or opt_lower == 'nesterov': opt_args.pop('eps', None) optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) elif opt_lower == 'momentum': opt_args.pop('eps', None) optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) elif opt_lower == 'adam': optimizer = optim.Adam(parameters, **opt_args) elif opt_lower == 'adamw': optimizer = optim.AdamW(parameters, **opt_args) elif opt_lower == 'nadam': optimizer = Nadam(parameters, **opt_args) elif opt_lower == 'radam': optimizer = RAdam(parameters, **opt_args) elif opt_lower == 'adamp': optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) elif opt_lower == 'sgdp': optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) elif opt_lower == 'adadelta': optimizer = optim.Adadelta(parameters, **opt_args) elif opt_lower == 'adafactor': if not args.lr: opt_args['lr'] = None optimizer = Adafactor(parameters, **opt_args) elif opt_lower == 'adahessian': optimizer = Adahessian(parameters, **opt_args) elif opt_lower == 'rmsprop': optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) elif opt_lower == 'rmsproptf': optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) elif opt_lower == 'novograd': optimizer = NovoGrad(parameters, **opt_args) elif opt_lower == 'nvnovograd': optimizer = NvNovoGrad(parameters, **opt_args) elif opt_lower == 'fusedsgd': opt_args.pop('eps', None) optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) elif opt_lower == 'fusedmomentum': opt_args.pop('eps', None) optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) elif opt_lower == 'fusedadam': optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) elif opt_lower == 'fusedadamw': optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) elif opt_lower == 'fusedlamb': optimizer = FusedLAMB(parameters, **opt_args) elif opt_lower == 'fusednovograd': opt_args.setdefault('betas', (0.95, 0.98)) optimizer = FusedNovoGrad(parameters, **opt_args) else: assert False and "Invalid optimizer" raise ValueError if len(opt_split) > 1: if opt_split[0] == 'lookahead': optimizer = Lookahead(optimizer) return optimizer ================================================ FILE: furnace/transforms.py ================================================ import torch import torchvision.transforms.functional as F from PIL import Image import warnings import math import random import numpy as np class ToNumpy: def __call__(self, pil_img): np_img = np.array(pil_img, dtype=np.uint8) if np_img.ndim < 3: np_img = np.expand_dims(np_img, axis=-1) np_img = np.rollaxis(np_img, 2) # HWC to CHW return np_img class ToTensor: def __init__(self, dtype=torch.float32): self.dtype = dtype def __call__(self, pil_img): np_img = np.array(pil_img, dtype=np.uint8) if np_img.ndim < 3: np_img = np.expand_dims(np_img, axis=-1) np_img = np.rollaxis(np_img, 2) # HWC to CHW return torch.from_numpy(np_img).to(dtype=self.dtype) 
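# Editor's note: a hedged usage sketch for RandomResizedCropAndInterpolationWithTwoPic
# (defined below). The sizes are assumptions standing in for args.input_size and
# args.second_input_size in the CAE pipeline; one random crop window is resized twice:
#   t = RandomResizedCropAndInterpolationWithTwoPic(size=224, second_size=112,
#                                                   interpolation='bicubic',
#                                                   second_interpolation='lanczos')
#   for_patches, for_visual_tokens = t(pil_image)  # same window, 224x224 and 112x112 PIL crops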
_pil_interpolation_to_str = { Image.NEAREST: 'PIL.Image.NEAREST', Image.BILINEAR: 'PIL.Image.BILINEAR', Image.BICUBIC: 'PIL.Image.BICUBIC', Image.LANCZOS: 'PIL.Image.LANCZOS', Image.HAMMING: 'PIL.Image.HAMMING', Image.BOX: 'PIL.Image.BOX', } def _pil_interp(method): if method == 'bicubic': return Image.BICUBIC elif method == 'lanczos': return Image.LANCZOS elif method == 'hamming': return Image.HAMMING else: # default bilinear, do we want to allow nearest? return Image.BILINEAR _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) class RandomResizedCropAndInterpolationWithTwoPic: """Crop the given PIL Image to random size and aspect ratio with random interpolation. A crop of random size (default: of 0.08 to 1.0) of the original size and a random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop is finally resized to given size. This is popularly used to train the Inception networks. Args: size: expected output size of each edge scale: range of size of the origin size cropped ratio: range of aspect ratio of the origin aspect ratio cropped interpolation: Default: PIL.Image.BILINEAR """ def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation='bilinear', second_interpolation='lanczos'): if isinstance(size, tuple): self.size = size else: self.size = (size, size) if second_size is not None: if isinstance(second_size, tuple): self.second_size = second_size else: self.second_size = (second_size, second_size) else: self.second_size = None if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): warnings.warn("range should be of kind (min, max)") if interpolation == 'random': self.interpolation = _RANDOM_INTERPOLATION else: self.interpolation = _pil_interp(interpolation) self.second_interpolation = _pil_interp(second_interpolation) self.scale = scale self.ratio = ratio @staticmethod def get_params(img, scale, ratio): """Get parameters for ``crop`` for a random sized crop. Args: img (PIL Image): Image to be cropped. scale (tuple): range of size of the origin size cropped ratio (tuple): range of aspect ratio of the origin aspect ratio cropped Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for a random sized crop. """ area = img.size[0] * img.size[1] for attempt in range(10): target_area = random.uniform(*scale) * area log_ratio = (math.log(ratio[0]), math.log(ratio[1])) aspect_ratio = math.exp(random.uniform(*log_ratio)) w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) if w <= img.size[0] and h <= img.size[1]: i = random.randint(0, img.size[1] - h) j = random.randint(0, img.size[0] - w) return i, j, h, w # Fallback to central crop in_ratio = img.size[0] / img.size[1] if in_ratio < min(ratio): w = img.size[0] h = int(round(w / min(ratio))) elif in_ratio > max(ratio): h = img.size[1] w = int(round(h * max(ratio))) else: # whole image w = img.size[0] h = img.size[1] i = (img.size[1] - h) // 2 j = (img.size[0] - w) // 2 return i, j, h, w def __call__(self, img): """ Args: img (PIL Image): Image to be cropped and resized. Returns: PIL Image: Randomly cropped and resized image. 
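            When second_size is set, a tuple of two PIL Images is returned:
            the same crop resized to size and to second_size (for example, a
            224x224 view for the encoder and a smaller view for a tokenizer;
            the concrete sizes depend on the caller). Both views share the
            identical crop window; only the target resolution and
            interpolation differ.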
""" i, j, h, w = self.get_params(img, self.scale, self.ratio) if isinstance(self.interpolation, (tuple, list)): interpolation = random.choice(self.interpolation) else: interpolation = self.interpolation if self.second_size is None: return F.resized_crop(img, i, j, h, w, self.size, interpolation) else: return F.resized_crop(img, i, j, h, w, self.size, interpolation), \ F.resized_crop(img, i, j, h, w, self.second_size, self.second_interpolation) def __repr__(self): if isinstance(self.interpolation, (tuple, list)): interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation]) else: interpolate_str = _pil_interpolation_to_str[self.interpolation] format_string = self.__class__.__name__ + '(size={0}'.format(self.size) format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) format_string += ', interpolation={0}'.format(interpolate_str) if self.second_size is not None: format_string += ', second_size={0}'.format(self.second_size) format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation]) format_string += ')' return format_string ================================================ FILE: furnace/utils.py ================================================ # -------------------------------------------------------- # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) # Github source: https://github.com/microsoft/unilm/tree/master/beit # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # By Hangbo Bao # Based on timm, DINO and DeiT code bases # https://github.com/rwightman/pytorch-image-models/tree/master/timm # https://github.com/facebookresearch/deit # https://github.com/facebookresearch/dino # --------------------------------------------------------' import io import os import math import time import json from collections import defaultdict, deque import datetime import numpy as np from timm.utils import get_state_dict from pathlib import Path import torch import torch.distributed as dist from torch._six import inf from models.modeling_discrete_vae import Dalle_VAE, DiscreteVAE, VGGAN from tensorboardX import SummaryWriter import torch.nn.functional as F from torch.nn.modules.batchnorm import _NormBase class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a window or the global series average. """ def __init__(self, window_size=20, fmt=None): if fmt is None: fmt = "{median:.4f} ({global_avg:.4f})" self.deque = deque(maxlen=window_size) self.total = 0.0 self.count = 0 self.fmt = fmt def update(self, value, n=1): self.deque.append(value) self.count += n self.total += value * n def synchronize_between_processes(self): """ Warning: does not synchronize the deque! 
""" if not is_dist_avail_and_initialized(): return t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') dist.barrier() dist.all_reduce(t) t = t.tolist() self.count = int(t[0]) self.total = t[1] @property def median(self): d = torch.tensor(list(self.deque)) return d.median().item() @property def avg(self): d = torch.tensor(list(self.deque), dtype=torch.float32) return d.mean().item() @property def global_avg(self): return self.total / self.count @property def max(self): return max(self.deque) @property def value(self): return self.deque[-1] def __str__(self): return self.fmt.format( median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value) class MetricLogger(object): def __init__(self, delimiter="\t"): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter def update(self, **kwargs): for k, v in kwargs.items(): if v is None: continue if isinstance(v, torch.Tensor): v = v.item() assert isinstance(v, (float, int)) self.meters[k].update(v) def __getattr__(self, attr): if attr in self.meters: return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): loss_str.append( "{}: {}".format(name, str(meter)) ) return self.delimiter.join(loss_str) def synchronize_between_processes(self): for meter in self.meters.values(): meter.synchronize_between_processes() def add_meter(self, name, meter): self.meters[name] = meter def log_every(self, iterable, print_freq, header=None): i = 0 if not header: header = '' start_time = time.time() end = time.time() iter_time = SmoothedValue(fmt='{avg:.4f}') data_time = SmoothedValue(fmt='{avg:.4f}') space_fmt = ':' + str(len(str(len(iterable)))) + 'd' log_msg = [ header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}' ] if torch.cuda.is_available(): log_msg.append('max mem: {memory:.0f}') log_msg = self.delimiter.join(log_msg) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) yield obj iter_time.update(time.time() - end) if i % print_freq == 0 or i == len(iterable) - 1: eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): print(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), memory=torch.cuda.max_memory_allocated() / MB)) else: print(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time))) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('{} Total time: {} ({:.4f} s / it)'.format( header, total_time_str, total_time / len(iterable))) class TensorboardLogger(object): def __init__(self, log_dir): self.writer = SummaryWriter(logdir=log_dir) self.step = 0 def set_step(self, step=None): if step is not None: self.step = step else: self.step += 1 def update(self, head='scalar', step=None, **kwargs): for k, v in kwargs.items(): if v is None: continue if isinstance(v, torch.Tensor): v = v.item() assert isinstance(v, (float, int)) self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) def flush(self): self.writer.flush() def _load_checkpoint_for_ema(model_ema, checkpoint): """ Workaround for ModelEma._load_checkpoint to accept an already-loaded 
object """ mem_file = io.BytesIO() torch.save(checkpoint, mem_file) mem_file.seek(0) model_ema._load_checkpoint(mem_file) def setup_for_distributed_each_gpu(rank): import builtins as __builtin__ builtin_print = __builtin__.print def print(*args, **kwargs): builtin_print('rank is: ', rank, end=' ') now = datetime.datetime.now().time() builtin_print('[{}] '.format(now), end='') # print with time stamp builtin_print(*args, **kwargs) __builtin__.print = print def setup_for_distributed(is_master): """ This function disables printing when not in master process """ import builtins as __builtin__ builtin_print = __builtin__.print def print(*args, **kwargs): force = kwargs.pop('force', False) if is_master or force: now = datetime.datetime.now().time() builtin_print('[{}] '.format(now), end='') # print with time stamp builtin_print(*args, **kwargs) __builtin__.print = print def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True def get_world_size(): if not is_dist_avail_and_initialized(): return 1 return dist.get_world_size() def get_rank(): if not is_dist_avail_and_initialized(): return 0 return dist.get_rank() def is_main_process(): return get_rank() == 0 def save_on_master(*args, **kwargs): if is_main_process(): torch.save(*args, **kwargs) def init_distributed_mode(args): if args.dist_on_itp: args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) os.environ['LOCAL_RANK'] = str(args.gpu) os.environ['RANK'] = str(args.rank) os.environ['WORLD_SIZE'] = str(args.world_size) # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ['WORLD_SIZE']) args.gpu = int(os.environ['LOCAL_RANK']) elif 'SLURM_PROCID' in os.environ: args.rank = int(os.environ['SLURM_PROCID']) args.gpu = args.rank % torch.cuda.device_count() else: print('Not using distributed mode') args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) args.dist_backend = 'nccl' print('| distributed init (rank {}): {}, gpu {}'.format( args.rank, args.dist_url, args.gpu), flush=True) torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) torch.distributed.barrier() if not args.enable_multi_print: setup_for_distributed(args.rank == 0) else: setup_for_distributed_each_gpu(args.rank) def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): missing_keys = [] unexpected_keys = [] error_msgs = [] # copy state_dict so _load_from_state_dict can modify it metadata = getattr(state_dict, '_metadata', None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata def load(module, prefix=''): local_metadata = {} if metadata is None else metadata.get( prefix[:-1], {}) module._load_from_state_dict( state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + '.') load(model, prefix=prefix) warn_missing_keys = [] ignore_missing_keys = [] for key in missing_keys: keep_flag = True for ignore_key in ignore_missing.split('|'): if ignore_key in key: keep_flag = False 
break if keep_flag: warn_missing_keys.append(key) else: ignore_missing_keys.append(key) missing_keys = warn_missing_keys if len(missing_keys) > 0: print("Weights of {} not initialized from pretrained model: {}".format( model.__class__.__name__, missing_keys)) if len(unexpected_keys) > 0: print("Weights from pretrained model not used in {}: {}".format( model.__class__.__name__, unexpected_keys)) if len(ignore_missing_keys) > 0: print("Ignored weights of {} not initialized from pretrained model: {}".format( model.__class__.__name__, ignore_missing_keys)) if len(error_msgs) > 0: print('\n'.join(error_msgs)) class NativeScalerWithGradNormCount: state_dict_key = "amp_scaler" def __init__(self): self._scaler = torch.cuda.amp.GradScaler() def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): self._scaler.scale(loss).backward(create_graph=create_graph) if update_grad: if clip_grad is not None: assert parameters is not None self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) else: self._scaler.unscale_(optimizer) norm = get_grad_norm_(parameters) self._scaler.step(optimizer) self._scaler.update() else: norm = None return norm def state_dict(self): return self._scaler.state_dict() def load_state_dict(self, state_dict): self._scaler.load_state_dict(state_dict) def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = [p for p in parameters if p.grad is not None] norm_type = float(norm_type) if len(parameters) == 0: return torch.tensor(0.) device = parameters[0].grad.device if norm_type == inf: total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) else: total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) return total_norm def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0, warmup_steps=-1): warmup_schedule = np.array([]) warmup_iters = warmup_epochs * niter_per_ep if warmup_steps > 0: warmup_iters = warmup_steps print("Set warmup steps = %d" % warmup_iters) if warmup_epochs > 0: warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) iters = np.arange(epochs * niter_per_ep - warmup_iters) schedule = np.array( [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) schedule = np.concatenate((warmup_schedule, schedule)) assert len(schedule) == epochs * niter_per_ep return schedule def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, exp_name=None): output_dir = Path(args.output_dir) epoch_name = str(epoch) if loss_scaler is not None: if exp_name is not None: checkpoint_paths = [output_dir / ('{}_checkpoint-{}.pth'.format(exp_name, epoch_name))] else: checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] for checkpoint_path in checkpoint_paths: to_save_state_dict = model_without_ddp.state_dict() # all_keys = list(state_dict.keys()) for key in list(to_save_state_dict.keys()): if key.startswith('teacher.'): to_save_state_dict.pop(key) to_save = { 'model': to_save_state_dict, 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'scaler': loss_scaler.state_dict(), 'args': args, } if model_ema is not None: to_save['model_ema'] = get_state_dict(model_ema) save_on_master(to_save, 
checkpoint_path) else: client_state = {'epoch': epoch} if model_ema is not None: client_state['model_ema'] = get_state_dict(model_ema) if exp_name is not None: model.save_checkpoint(save_dir=args.output_dir, tag="{}_checkpoint-{}".format(exp_name, epoch_name), client_state=client_state) else: model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): output_dir = Path(args.output_dir) if loss_scaler is not None: # torch.amp if args.auto_resume and len(args.resume) == 0: import glob all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) latest_ckpt = -1 for ckpt in all_checkpoints: t = ckpt.split('-')[-1].split('.')[0] if t.isdigit(): latest_ckpt = max(int(t), latest_ckpt) if latest_ckpt >= 0: args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) print("Auto resume checkpoint: %s" % args.resume) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') # handle ema model need_state_dict = model_without_ddp.state_dict() need_ema = False for key in need_state_dict.keys(): if 'teacher' in key: need_ema = True break checkpoint_model = checkpoint['model'] if need_ema: all_keys = list(checkpoint_model.keys()) all_keys = [key for key in all_keys if key.startswith('encoder.')] for key in all_keys: new_key = key.replace('encoder.','teacher.') checkpoint_model[new_key] = checkpoint_model[key] model_without_ddp.load_state_dict(checkpoint_model) print("Resume checkpoint %s" % args.resume) if 'optimizer' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) args.start_epoch = checkpoint['epoch'] + 1 if hasattr(args, 'model_ema') and args.model_ema: _load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) print("With optim & sched!") else: # deepspeed, only support '--auto_resume'. 
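        # DeepSpeed writes directory-style checkpoints named by tag (see the
        # model.save_checkpoint calls in save_model above), so resume scans
        # output_dir for the latest 'checkpoint-*' tag rather than for a
        # single .pth file.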
if args.auto_resume: import glob all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*')) latest_ckpt = -1 for ckpt in all_checkpoints: t = ckpt.split('-')[-1].split('.')[0] if t.isdigit(): latest_ckpt = max(int(t), latest_ckpt) if latest_ckpt >= 0: args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt) print("Auto resume checkpoint: %d" % latest_ckpt) _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt) args.start_epoch = client_states['epoch'] + 1 if model_ema is not None: if args.model_ema: _load_checkpoint_for_ema(model_ema, client_states['model_ema']) def create_d_vae(weight_path, d_vae_type, image_size, device, args=None): if d_vae_type == "dall-e": return get_dalle_vae(weight_path, image_size, device) if d_vae_type == "vqgan_gumbel_f8_8192": return get_vqgan_gumbel_f8_8192(weight_path, image_size, device) elif d_vae_type == "customized": return get_d_vae(weight_path, image_size, device, args) elif d_vae_type == "to_tensor": return None else: raise NotImplementedError() def get_vqgan_gumbel_f8_8192(weight_path, image_size, device): with torch.no_grad(): vqgan = VGGAN(image_size) vqgan.load_model(weight_path, device) return vqgan def get_dalle_vae(weight_path, image_size, device): vae = Dalle_VAE(image_size) vae.load_model(model_dir=weight_path, device=device) return vae def get_d_vae(weight_path, image_size, device, args): NUM_TOKENS = 8192 NUM_LAYERS = args.dvae_num_layers EMB_DIM = 512 HID_DIM = 256 state_dict = torch.load(weight_path, map_location="cpu")["model"] model = DiscreteVAE( image_size=image_size, num_layers=NUM_LAYERS, num_tokens=NUM_TOKENS, codebook_dim=EMB_DIM, hidden_dim=HID_DIM, ).to(device) model.load_state_dict(state_dict) return model def create_ds_config(args): args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json") with open(args.deepspeed_config, mode="w") as writer: ds_config = { "train_batch_size": args.batch_size * args.update_freq * get_world_size(), "train_micro_batch_size_per_gpu": args.batch_size, "steps_per_print": 1000, "optimizer": { "type": "Adam", "adam_w_mode": True, "params": { "lr": args.lr, "weight_decay": args.weight_decay, "bias_correction": True, "betas": [ 0.9, 0.999 ], "eps": 1e-8 } }, "fp16": { "enabled": True, "loss_scale": 0, "initial_scale_power": 7, "loss_scale_window": 128 } } writer.write(json.dumps(ds_config, indent=2)) class LP_BatchNorm(_NormBase): """ A variant used in linear probing. To freeze parameters (normalization operator specifically), model set to eval mode during linear probing. According to paper, an extra BN is used on the top of encoder to calibrate the feature magnitudes. In addition to self.training, we set another flag in this implement to control BN's behavior to train in eval mode. """ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): super(LP_BatchNorm, self).__init__( num_features, eps, momentum, affine, track_running_stats) def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: raise ValueError('expected 2D or 3D input (got {}D input)' .format(input.dim())) def forward(self, input, is_train): """ We use is_train instead of self.training. """ self._check_input_dim(input) # exponential_average_factor is set to self.momentum # (when it is available) only so that it gets updated # in ONNX graph when this node is exported to ONNX. 
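        # A minimal linear-probing sketch (illustrative; 'encoder', 'head'
        # and 'images' are hypothetical names):
        #   bn = LP_BatchNorm(embed_dim, affine=False)
        #   feats = bn(encoder(images), is_train=True)  # stats still update
        #                                               # while encoder.eval()
        #   logits = head(feats)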
if self.momentum is None: exponential_average_factor = 0.0 else: exponential_average_factor = self.momentum # if self.training and self.track_running_stats: if is_train and self.track_running_stats: if self.num_batches_tracked is not None: # type: ignore self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / float(self.num_batches_tracked) else: # use exponential moving average exponential_average_factor = self.momentum r""" Decide whether the mini-batch stats should be used for normalization rather than the buffers. Mini-batch stats are used in training mode, and in eval mode when buffers are None. """ if is_train: bn_training = True else: bn_training = (self.running_mean is None) and (self.running_var is None) r""" Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are used for normalization (i.e. in eval mode when buffers are not None). """ assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor) assert self.running_var is None or isinstance(self.running_var, torch.Tensor) return F.batch_norm( input, # If buffers are not to be tracked, ensure that they won't be updated self.running_mean if not is_train or self.track_running_stats else None, self.running_var if not is_train or self.track_running_stats else None, self.weight, self.bias, bn_training, exponential_average_factor, self.eps) ================================================ FILE: linear_util/crop.py ================================================ import math import torch from torchvision import transforms from torchvision.transforms import functional as F class RandomResizedCrop(transforms.RandomResizedCrop): """ RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. This may lead to results different with torchvision's version. 
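    Instead of retrying up to ten times until a crop fits, a single crop is
    sampled and then clamped to the image bounds, matching the TF/TPU
    behaviour.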
Following BYOL's TF code: https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 """ @staticmethod def get_params(img, scale, ratio): width, height = F._get_image_size(img) area = height * width target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() log_ratio = torch.log(torch.tensor(ratio)) aspect_ratio = torch.exp( torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) ).item() w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) w = min(w, width) h = min(h, height) i = torch.randint(0, height - h + 1, size=(1,)).item() j = torch.randint(0, width - w + 1, size=(1,)).item() return i, j, h, w ================================================ FILE: linear_util/datasets.py ================================================ import os import PIL from torchvision import datasets, transforms from timm.data import create_transform from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from furnace.masking_generator import MaskingGenerator, RandomMaskingGenerator class DataAugmentationMySelf(object): def __init__(self, args): self.patch_transform = transforms.Compose([ transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) if args.mask_generator == 'block': self.masked_position_generator = MaskingGenerator( args.window_size, num_masking_ratio=args.mask_ratio, max_num_patches=args.max_mask_patches_per_block, min_num_patches=args.min_mask_patches_per_block, ) elif args.mask_generator == 'random': self.masked_position_generator = RandomMaskingGenerator( args.window_size, ratio_masking_patches=args.mask_ratio ) def __call__(self, image): return self.patch_transform(image), self.masked_position_generator() def __repr__(self): repr = "(DataAugmentationMySelf,\n" repr += " patch_transform = %s,\n" % str(self.patch_transform) repr += " Masked position generator = %s,\n" % str(self.masked_position_generator) repr += ")" return repr def build_dataset(is_train, args): transform = build_transform(is_train, args) root = os.path.join(args.data_path, 'train' if is_train else 'val') dataset = datasets.ImageFolder(root, transform=transform) print(dataset) return dataset def build_dataset_finetune(is_train, args): transform = build_transform_finetune(is_train, args) root = os.path.join(args.data_path, 'train' if is_train else 'val') dataset = datasets.ImageFolder(root, transform=transform) print(dataset) return dataset def build_transform_finetune(is_train, args): mean = IMAGENET_DEFAULT_MEAN std = IMAGENET_DEFAULT_STD # train transform if is_train: # this should always dispatch to transforms_imagenet_train transform = create_transform( input_size=args.input_size, is_training=True, color_jitter=args.color_jitter, auto_augment=args.aa, interpolation='bicubic', re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, mean=mean, std=std, ) return transform # eval transform t = [] if args.input_size <= 224: crop_pct = 224 / 256 else: crop_pct = 1.0 size = int(args.input_size / crop_pct) t.append( transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 
224 images
    )
    t.append(transforms.CenterCrop(args.input_size))
    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)


def build_transform(is_train, args):
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD

    # train transform
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        # DataAugmentationMySelf takes only args and builds its own
        # patch_transform internally; the timm transform above is not
        # passed through.
        return DataAugmentationMySelf(args)

    # eval transform
    t = []
    if args.input_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)
    t.append(
        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(args.input_size))
    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)


================================================
FILE: linear_util/engine_finetune.py
================================================
import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy

import linear_util.misc as misc
import linear_util.lr_sched as lr_sched


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None, log_writer=None,
                    args=None):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss /= accum_iter
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """
            We use epoch_1000x as the x-axis in tensorboard. This calibrates
            different curves when batch size changes.
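            For example, data_iter_step 500 of len(data_loader) 1000 in
            epoch 3 is logged at x = int((500 / 1000 + 3) * 1000) = 3500,
            regardless of the per-GPU batch size.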
""" epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) log_writer.add_scalar('lr', max_lr, epoch_1000x) # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} @torch.no_grad() def evaluate(data_loader, model, device): criterion = torch.nn.CrossEntropyLoss() metric_logger = misc.MetricLogger(delimiter=" ") header = 'Test:' # switch to evaluation mode model.eval() for batch in metric_logger.log_every(data_loader, 10, header): images = batch[0] target = batch[-1] images = images.to(device, non_blocking=True) target = target.to(device, non_blocking=True) # compute output with torch.cuda.amp.autocast(): output = model(images) loss = criterion(output, target) acc1, acc5 = accuracy(output, target, topk=(1, 5)) batch_size = images.shape[0] metric_logger.update(loss=loss.item()) metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) # gather the stats from all processes metric_logger.synchronize_between_processes() print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} ================================================ FILE: linear_util/lars.py ================================================ import torch class LARS(torch.optim.Optimizer): """ LARS optimizer, no rate scaling or weight decay for parameters <= 1D. """ def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) super().__init__(params, defaults) @torch.no_grad() def step(self): for g in self.param_groups: for p in g['params']: dp = p.grad if dp is None: continue if p.ndim > 1: # if not normalization gamma/beta or bias dp = dp.add(p, alpha=g['weight_decay']) param_norm = torch.norm(p) update_norm = torch.norm(dp) one = torch.ones_like(param_norm) q = torch.where(param_norm > 0., torch.where(update_norm > 0, (g['trust_coefficient'] * param_norm / update_norm), one), one) dp = dp.mul(q) param_state = self.state[p] if 'mu' not in param_state: param_state['mu'] = torch.zeros_like(p) mu = param_state['mu'] mu.mul_(g['momentum']).add_(dp) p.add_(mu, alpha=-g['lr']) ================================================ FILE: linear_util/lr_decay.py ================================================ import json def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): param_group_names = {} param_groups = {} num_layers = len(model.blocks) + 1 layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) for n, p in model.named_parameters(): if not p.requires_grad: continue # no decay: all 1D parameters and model specific ones if p.ndim == 1 or n in no_weight_decay_list: g_decay = "no_decay" this_decay = 0. 
else: g_decay = "decay" this_decay = weight_decay layer_id = get_layer_id_for_vit(n, num_layers) group_name = "layer_%d_%s" % (layer_id, g_decay) if group_name not in param_group_names: this_scale = layer_scales[layer_id] param_group_names[group_name] = { "lr_scale": this_scale, "weight_decay": this_decay, "params": [], } param_groups[group_name] = { "lr_scale": this_scale, "weight_decay": this_decay, "params": [], } param_group_names[group_name]["params"].append(n) param_groups[group_name]["params"].append(p) # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) return list(param_groups.values()) def get_layer_id_for_vit(name, num_layers): if name in ['cls_token', 'pos_embed']: return 0 elif name.startswith('patch_embed'): return 0 elif name.startswith('blocks'): return int(name.split('.')[1]) + 1 else: return num_layers ================================================ FILE: linear_util/lr_sched.py ================================================ import math def adjust_learning_rate(optimizer, epoch, args): """Decay the learning rate with half-cycle cosine after warmup""" if epoch < args.warmup_epochs: lr = args.lr * epoch / args.warmup_epochs else: lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) for param_group in optimizer.param_groups: if "lr_scale" in param_group: param_group["lr"] = lr * param_group["lr_scale"] else: param_group["lr"] = lr return lr ================================================ FILE: linear_util/misc.py ================================================ import builtins import datetime import os import time from collections import defaultdict, deque from pathlib import Path import torch import torch.distributed as dist from torch._six import inf class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a window or the global series average. """ def __init__(self, window_size=20, fmt=None): if fmt is None: fmt = "{median:.4f} ({global_avg:.4f})" self.deque = deque(maxlen=window_size) self.total = 0.0 self.count = 0 self.fmt = fmt def update(self, value, n=1): self.deque.append(value) self.count += n self.total += value * n def synchronize_between_processes(self): """ Warning: does not synchronize the deque! 
""" if not is_dist_avail_and_initialized(): return t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') dist.barrier() dist.all_reduce(t) t = t.tolist() self.count = int(t[0]) self.total = t[1] @property def median(self): d = torch.tensor(list(self.deque)) return d.median().item() @property def avg(self): d = torch.tensor(list(self.deque), dtype=torch.float32) return d.mean().item() @property def global_avg(self): return self.total / self.count @property def max(self): return max(self.deque) @property def value(self): return self.deque[-1] def __str__(self): return self.fmt.format( median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value) class MetricLogger(object): def __init__(self, delimiter="\t"): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter def update(self, **kwargs): for k, v in kwargs.items(): if v is None: continue if isinstance(v, torch.Tensor): v = v.item() assert isinstance(v, (float, int)) self.meters[k].update(v) def __getattr__(self, attr): if attr in self.meters: return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): loss_str.append( "{}: {}".format(name, str(meter)) ) return self.delimiter.join(loss_str) def synchronize_between_processes(self): for meter in self.meters.values(): meter.synchronize_between_processes() def add_meter(self, name, meter): self.meters[name] = meter def log_every(self, iterable, print_freq, header=None): i = 0 if not header: header = '' start_time = time.time() end = time.time() iter_time = SmoothedValue(fmt='{avg:.4f}') data_time = SmoothedValue(fmt='{avg:.4f}') space_fmt = ':' + str(len(str(len(iterable)))) + 'd' log_msg = [ header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}' ] if torch.cuda.is_available(): log_msg.append('max mem: {memory:.0f}') log_msg = self.delimiter.join(log_msg) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) yield obj iter_time.update(time.time() - end) if i % print_freq == 0 or i == len(iterable) - 1: eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): print(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), memory=torch.cuda.max_memory_allocated() / MB)) else: print(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time))) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('{} Total time: {} ({:.4f} s / it)'.format( header, total_time_str, total_time / len(iterable))) def setup_for_distributed(is_master): """ This function disables printing when not in master process """ builtin_print = builtins.print def print(*args, **kwargs): force = kwargs.pop('force', False) force = force or (get_world_size() > 8) if is_master or force: now = datetime.datetime.now().time() builtin_print('[{}] '.format(now), end='') # print with time stamp builtin_print(*args, **kwargs) builtins.print = print def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True def get_world_size(): if not is_dist_avail_and_initialized(): return 1 return 
dist.get_world_size() def get_rank(): if not is_dist_avail_and_initialized(): return 0 return dist.get_rank() def is_main_process(): return get_rank() == 0 def save_on_master(*args, **kwargs): if is_main_process(): torch.save(*args, **kwargs) def init_distributed_mode(args): if args.dist_on_itp: args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) os.environ['LOCAL_RANK'] = str(args.gpu) os.environ['RANK'] = str(args.rank) os.environ['WORLD_SIZE'] = str(args.world_size) # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ['WORLD_SIZE']) args.gpu = int(os.environ['LOCAL_RANK']) elif 'SLURM_PROCID' in os.environ: args.rank = int(os.environ['SLURM_PROCID']) args.gpu = args.rank % torch.cuda.device_count() else: print('Not using distributed mode') setup_for_distributed(is_master=True) # hack args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) args.dist_backend = 'nccl' print('| distributed init (rank {}): {}, gpu {}'.format( args.rank, args.dist_url, args.gpu), flush=True) torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) torch.distributed.barrier() setup_for_distributed(args.rank == 0) class NativeScalerWithGradNormCount: state_dict_key = "amp_scaler" def __init__(self): self._scaler = torch.cuda.amp.GradScaler() def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): self._scaler.scale(loss).backward(create_graph=create_graph) if update_grad: if clip_grad is not None: assert parameters is not None self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) else: self._scaler.unscale_(optimizer) norm = get_grad_norm_(parameters) self._scaler.step(optimizer) self._scaler.update() else: norm = None return norm def state_dict(self): return self._scaler.state_dict() def load_state_dict(self, state_dict): self._scaler.load_state_dict(state_dict) def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = [p for p in parameters if p.grad is not None] norm_type = float(norm_type) if len(parameters) == 0: return torch.tensor(0.) 
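    # For norm_type p, the total norm aggregates per-tensor gradient norms as
    # (sum_i ||g_i||_p ** p) ** (1 / p); for inf it is the largest absolute
    # entry across all gradients.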
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm


def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, exp_name=None):
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        if exp_name is not None:
            checkpoint_paths = [output_dir / ('{}_checkpoint-{}.pth'.format(exp_name, epoch_name))]
        else:
            checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }
            save_on_master(to_save, checkpoint_path)
    else:
        client_state = {'epoch': epoch}
        if exp_name is not None:
            model.save_checkpoint(save_dir=args.output_dir, tag="{}_checkpoint-{}".format(exp_name, epoch_name), client_state=client_state)
        else:
            model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)


def load_model(args, model_without_ddp, optimizer, loss_scaler):
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        checkpoint_model = checkpoint['model']
        model_without_ddp.load_state_dict(checkpoint_model)
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")


def all_reduce_mean(x):
    world_size = get_world_size()
    if world_size > 1:
        x_reduce = torch.tensor(x).cuda()
        dist.all_reduce(x_reduce)
        x_reduce /= world_size
        return x_reduce.item()
    else:
        return x


================================================
FILE: linear_util/pos_embed.py
================================================
import numpy as np

import torch


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)  # explicit dtype; the bare np.float alias is deprecated
    omega /= embed_dim / 2.
    omega = 1.
/ 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb # -------------------------------------------------------- # Interpolate position embeddings for high-resolution # References: # DeiT: https://github.com/facebookresearch/deit # -------------------------------------------------------- def interpolate_pos_embed(model, checkpoint_model): if 'pos_embed' in checkpoint_model: pos_embed_checkpoint = checkpoint_model['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.patch_embed.num_patches num_extra_tokens = model.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) # height (== width) for the new position embedding new_size = int(num_patches ** 0.5) # class_token and dist_token are kept unchanged if orig_size != new_size: print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) checkpoint_model['pos_embed'] = new_pos_embed ================================================ FILE: models/modeling_cae.py ================================================ import math import time import torch import torch.nn as nn from functools import partial from models.modeling_finetune import _cfg, PatchEmbed from timm.models.registry import register_model from timm.models.layers import trunc_normal_ as __call_trunc_normal_ from models.modeling_cae_helper import * def trunc_normal_(tensor, mean=0., std=1.): __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) class VisionTransformerForMaskedImageModeling(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, vocab_size=8192, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, init_values=None, attn_head_dim=None, use_abs_pos_emb=True, init_std=0.02, args=None, **kwargs): super().__init__() self.encoder = VisionTransformerEncoder(img_size=img_size, patch_size=patch_size, in_chans=in_chans, vocab_size=vocab_size, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, norm_layer=norm_layer, init_values=init_values, attn_head_dim=attn_head_dim, use_abs_pos_emb=use_abs_pos_emb, init_std=init_std, args=args) # alignment constraint self.teacher = VisionTransformerEncoder(img_size=img_size, patch_size=patch_size, in_chans=in_chans, vocab_size=vocab_size, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, norm_layer=norm_layer, init_values=init_values, attn_head_dim=attn_head_dim, 
use_abs_pos_emb=use_abs_pos_emb, init_std=init_std, args=args) self.init_std = init_std self.args = args self.num_patches = self.encoder.patch_embed.num_patches self.pretext_neck = VisionTransformerNeck(patch_size=patch_size, num_classes=args.decoder_num_classes, embed_dim=args.decoder_embed_dim, depth=args.regressor_depth, num_heads=args.decoder_num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, norm_layer=norm_layer, init_values=args.decoder_layer_scale_init_value, num_patches=self.num_patches, init_std=init_std, args=args) # encoder to decoder projection, borrowed from mae. if args.decoder_embed_dim != embed_dim: self.encoder_to_decoder = nn.Linear(embed_dim, args.decoder_embed_dim, bias=True) self.encoder_to_decoder_norm = norm_layer(args.decoder_embed_dim) else: self.encoder_to_decoder = None self.mask_token = nn.Parameter(torch.zeros(1, 1, args.decoder_embed_dim)) trunc_normal_(self.mask_token, std=self.init_std) ### whether to use 'rescale' to init the weight, borrowed from beit. if not args.fix_init_weight: self.apply(self._init_weights) self._init_teacher() def _init_teacher(self): # init the weights of teacher with those of backbone for param_encoder, param_teacher in zip(self.encoder.parameters(), self.teacher.parameters()): param_teacher.detach() param_teacher.data.copy_(param_encoder.data) param_teacher.requires_grad = False def momentum_update(self, base_momentum=0): """Momentum update of the teacher network.""" for param_encoder, param_teacher in zip(self.encoder.parameters(), self.teacher.parameters()): param_teacher.data = param_teacher.data * base_momentum + \ param_encoder.data * (1. - base_momentum) def _init_weights(self, m): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) ''' Input shape: x: [bs, 3, 224, 224] bool_masked_pos: [bs, num_patch * num_patch] ''' def forward(self, x, bool_masked_pos, return_all_tokens=None): batch_size = x.size(0) ''' Encoder Output shape: [bs, num_visible + 1, C] ''' x_unmasked = self.encoder(x, bool_masked_pos=bool_masked_pos) # encoder to decoder projection if self.encoder_to_decoder is not None: x_unmasked = self.encoder_to_decoder(x_unmasked) x_unmasked = self.encoder_to_decoder_norm(x_unmasked) ''' Alignment constraint ''' with torch.no_grad(): latent_target = self.teacher(x, bool_masked_pos=(~bool_masked_pos)) latent_target = latent_target[:, 1:, :] # remove class token if self.encoder_to_decoder is not None: latent_target = self.encoder_to_decoder_norm(self.encoder_to_decoder(latent_target.detach())) self.momentum_update(self.args.base_momentum) ''' Latent contextual regressor and decoder ''' b, num_visible_plus1, dim = x_unmasked.shape # remove class token x_unmasked = x_unmasked[:, 1:, :] num_masked_patches = self.num_patches - (num_visible_plus1-1) # generate position embeddings. 
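        # The fixed 2D sin-cos table below covers cls + all patches and is
        # split by bool_masked_pos: pos_embed_masked pairs with the mask-token
        # queries and pos_embed_unmasked with the visible tokens, so the
        # latent regressor predicts masked representations from visible
        # patches plus their positions alone.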
pos_embed = self.encoder.build_2d_sincos_position_embedding(dim, use_cls_token=True).expand(batch_size, self.num_patches+1, dim).cuda(x_unmasked.device) # pos embed for masked patches pos_embed_masked = pos_embed[:,1:][bool_masked_pos].reshape(batch_size, -1, dim) # pos embed for unmasked patches pos_embed_unmasked = pos_embed[:,1:][~bool_masked_pos].reshape(batch_size, -1, dim) # masked embedding ''' x_masked = self.mask_token.expand(batch_size, num_masked_patches, -1) logits, latent_pred = self.pretext_neck(x_masked, x_unmasked, pos_embed_masked, pos_embed_unmasked, bool_masked_pos) logits = logits.view(-1, logits.shape[2]) return logits, latent_pred, latent_target @register_model def cae_small_patch16_224_8k_vocab(pretrained=False, **kwargs): model = VisionTransformerForMaskedImageModeling( patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs) model.default_cfg = _cfg() if pretrained: checkpoint = torch.load( kwargs["init_ckpt"], map_location="cpu" ) model.load_state_dict(checkpoint["model"]) return model @register_model def cae_base_patch16_224_8k_vocab(pretrained=False, **kwargs): model = VisionTransformerForMaskedImageModeling( patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs) model.default_cfg = _cfg() if pretrained: checkpoint = torch.load( kwargs["init_ckpt"], map_location="cpu" ) model.load_state_dict(checkpoint["model"]) return model @register_model def cae_large_patch16_224_8k_vocab(pretrained=False, **kwargs): model = VisionTransformerForMaskedImageModeling( patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs) model.default_cfg = _cfg() if pretrained: checkpoint = torch.load( kwargs["init_ckpt"], map_location="cpu" ) model.load_state_dict(checkpoint["model"]) return model ================================================ FILE: models/modeling_cae_helper.py ================================================ import math import time import torch import torch.nn as nn import torch.nn.functional as F from functools import partial from models.modeling_finetune import PatchEmbed, DropPath, Mlp from timm.models.registry import register_model from timm.models.layers import trunc_normal_ as __call_trunc_normal_ def trunc_normal_(tensor, mean=0., std=1.): __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None, attn_head_dim=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads if attn_head_dim is not None: head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) else: self.q_bias = None self.v_bias = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(all_head_dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, bool_masked_pos=None): B, N, C = x.shape qkv_bias = None if self.q_bias is not None: qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = 
qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale attn = (q @ k.transpose(-2, -1)) # (B, N_head, N, N) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x ''' Modified from Attention() ''' class CrossAttention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None, attn_head_dim=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads if attn_head_dim is not None: head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads self.scale = qk_scale or head_dim ** -0.5 self.q = nn.Linear(dim, all_head_dim, bias=False) self.k = nn.Linear(dim, all_head_dim, bias=False) self.v = nn.Linear(dim, all_head_dim, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) else: self.q_bias = None self.k_bias = None self.v_bias = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(all_head_dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, bool_masked_pos=None, k=None, v=None): B, N, C = x.shape N_k = k.shape[1] N_v = v.shape[1] q_bias, k_bias, v_bias = None, None, None if self.q_bias is not None: q_bias = self.q_bias k_bias = torch.zeros_like(self.v_bias, requires_grad=False) v_bias = self.v_bias q = F.linear(input=x, weight=self.q.weight, bias=q_bias) q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim) k = F.linear(input=k, weight=self.k.weight, bias=k_bias) k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) v = F.linear(input=v, weight=self.v.weight, bias=v_bias) v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) q = q * self.scale attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=None, attn_head_dim=None): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) self.drop_path = DropPath(drop_path) if drop_path > 0. 
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, bool_masked_pos=None):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x), bool_masked_pos))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), bool_masked_pos))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class RegressorBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1_q = norm_layer(dim)
        self.norm1_k = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.norm2_cross = norm_layer(dim)
        self.cross_attn = CrossAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_cross = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values > 0:
            self.gamma_1_cross = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2_cross = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1_cross = nn.Parameter(torch.ones((dim)), requires_grad=False)
            self.gamma_2_cross = nn.Parameter(torch.ones((dim)), requires_grad=False)

    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos):
        x = x_q + self.drop_path(self.gamma_1_cross * self.cross_attn(
            self.norm1_q(x_q + pos_q), bool_masked_pos,
            k=self.norm1_k(x_kv + pos_k), v=self.norm1_v(x_kv)))
        x = self.norm2_cross(x)
        x = x + self.drop_path(self.gamma_2_cross * self.mlp_cross(x))
        return x


'''
Encoder that extracts representations
'''
class VisionTransformerEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, vocab_size=8192, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=None, init_values=None, attn_head_dim=None,
                 use_abs_pos_emb=True, init_std=0.02, args=None, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        # generate class token and pos embed
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim, use_cls_token=True)

        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer, init_values=init_values, window_size=None,
                attn_head_dim=attn_head_dim,
            ) for i in range(depth)])

        self.norm = norm_layer(embed_dim)
        self.init_std = init_std

        # init the model
        trunc_normal_(self.cls_token, std=self.init_std)
        self.apply(self._init_weights)
        # rescale init function from beit
        # if it is not activated, it will be overwritten
        self.fix_init_weight()

    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000., use_cls_token=False):
        h, w = self.patch_embed.patch_shape
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]

        if not use_cls_token:
            pos_embed = nn.Parameter(pos_emb)
        else:
            pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
            pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        pos_embed.requires_grad = False
        return pos_embed

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_num_layers(self):
        return len(self.blocks)

    def forward_features(self, x, bool_masked_pos):
        x = self.patch_embed(x, bool_masked_pos=bool_masked_pos)
        batch_size, seq_len, dim = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        # unmasked embeddings
        x_unmasked = x[~bool_masked_pos].reshape(batch_size, -1, dim)
        x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1)

        if self.pos_embed is not None:
            pos_embed = self.pos_embed.expand(batch_size, self.num_patches+1, dim)
            pos_embed_unmasked = pos_embed[:, 1:][~bool_masked_pos].reshape(batch_size, -1, dim)
            pos_embed_unmasked = torch.cat((pos_embed[:, :1], pos_embed_unmasked), dim=1)
            x_unmasked = x_unmasked + pos_embed_unmasked

        x_unmasked = self.pos_drop(x_unmasked)

        for blk in self.blocks:
            x_unmasked = blk(x_unmasked, bool_masked_pos)

        x_unmasked = self.norm(x_unmasked)
        return x_unmasked

    def forward(self, x, bool_masked_pos, return_all_tokens=False):
        x = self.forward_features(x, bool_masked_pos=bool_masked_pos)
        return x
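# ---------------------------------------------------------------------------
# Editor's note: illustrative addition, not part of the original file. The
# encoder above only ever sees the visible patches; this demo replays the
# boolean-mask gather from forward_features with assumed sizes (196 patches,
# 98 of them masked, matching --num_mask_patches 98 in the scripts).
# ---------------------------------------------------------------------------
def _visible_patch_gather_demo():
    batch_size, num_patches, dim = 2, 196, 768
    x = torch.randn(batch_size, num_patches, dim)
    bool_masked_pos = torch.zeros(batch_size, num_patches, dtype=torch.bool)
    bool_masked_pos[:, :98] = True               # a block mask is used in practice
    x_unmasked = x[~bool_masked_pos].reshape(batch_size, -1, dim)
    assert x_unmasked.shape == (batch_size, 98, dim)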
'''
Latent context regressor + decoder that solves the pretext task.
'''
class VisionTransformerNeck(nn.Module):
    def __init__(self, patch_size=16, num_classes=8192, embed_dim=768, depth=6, num_heads=12, mlp_ratio=4.,
                 qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 norm_layer=None, init_values=None, num_patches=196, init_std=0.02, args=None,
                 patch_shape=(14, 14)):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.args = args

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # context regressor
        self.regressor_blocks = nn.ModuleList([
            RegressorBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer, init_values=init_values)
            for i in range(depth)])

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, args.decoder_depth)]
        self.decoder_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer, init_values=init_values)
            for i in range(args.decoder_depth)])

        self.norm = norm_layer(embed_dim)
        self.norm2 = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.init_std = init_std

        # init the model
        trunc_normal_(self.head.weight, std=self.init_std)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.regressor_blocks):
            rescale(layer.cross_attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp_cross.fc2.weight.data, layer_id + 1)
        for layer_id, layer in enumerate(self.decoder_blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward(self, x_masked, x_unmasked, pos_embed_masked, pos_embed_unmasked, bool_masked_pos):
        # latent contextual regressor
        for blk in self.regressor_blocks:
            x_masked = blk(x_masked, torch.cat([x_unmasked, x_masked], dim=1), pos_embed_masked,
                           torch.cat([pos_embed_unmasked, pos_embed_masked], dim=1), bool_masked_pos)
        x_masked = self.norm(x_masked)
        latent_pred = x_masked

        x_masked = x_masked + pos_embed_masked  # add pos embed, like encoder
        for blk in self.decoder_blocks:
            x_masked = blk(x_masked)
        x_masked = self.norm2(x_masked)

        logits = self.head(x_masked)

        return logits, latent_pred
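# ---------------------------------------------------------------------------
# Editor's note: illustrative addition, not part of the original file. A
# plausible wiring of the neck's two outputs into the CAE pretraining loss,
# assuming (consistent with the scripts' --align_loss_weight flag) a
# cross-entropy over the tokenizer's discrete targets plus an MSE alignment
# term between regressed and target latents. The actual loss lives in the
# pretraining engine, not in this file.
# ---------------------------------------------------------------------------
def _pretext_loss_sketch(logits, latent_pred, latent_target, target_ids, align_loss_weight=2.0):
    # logits: (B * num_masked, vocab_size); target_ids: discrete codes of the masked patches
    pred_loss = F.cross_entropy(logits, target_ids.reshape(-1))
    # latent_pred / latent_target: (B, num_masked, D)
    align_loss = F.mse_loss(latent_pred, latent_target)
    return pred_loss + align_loss_weight * align_loss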
================================================
FILE: models/modeling_discrete_vae.py
================================================
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on OpenAI DALL-E and lucidrains' DALLE-pytorch code bases
# https://github.com/openai/DALL-E
# https://github.com/lucidrains/DALLE-pytorch
# --------------------------------------------------------
from math import sqrt
import os
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange


def top_k(logits, thres=0.5):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner


class BasicVAE(nn.Module):
    def get_codebook_indices(self, images):
        raise NotImplementedError()

    def decode(self, img_seq):
        raise NotImplementedError()

    def get_codebook_probs(self, img_seq):
        raise NotImplementedError()

    def get_image_tokens_size(self):
        pass

    def get_image_size(self):
        pass


class ResBlock(nn.Module):
    def __init__(self, chan):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 1)
        )

    def forward(self, x):
        return self.net(x) + x


class DiscreteVAE(BasicVAE):
    def __init__(
        self,
        image_size = 256,
        num_tokens = 512,
        codebook_dim = 512,
        num_layers = 3,
        num_resnet_blocks = 2,
        hidden_dim = 64,
        channels = 3,
        smooth_l1_loss = False,
        temperature = 0.9,
        straight_through = False,
        kl_div_loss_weight = 0.,
    ):
        super().__init__()
        # assert log2(image_size).is_integer(), 'image size must be a power of 2'
        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
        has_resblocks = num_resnet_blocks > 0

        self.image_size = image_size
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.temperature = temperature
        self.straight_through = straight_through
        self.codebook = nn.Embedding(num_tokens, codebook_dim)

        hdim = hidden_dim

        enc_chans = [hidden_dim] * num_layers
        dec_chans = list(reversed(enc_chans))

        enc_chans = [channels, *enc_chans]

        dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
        dec_chans = [dec_init_chan, *dec_chans]

        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))

        enc_layers = []
        dec_layers = []

        for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
            enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride=2, padding=1), nn.ReLU()))
            dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride=2, padding=1), nn.ReLU()))

        for _ in range(num_resnet_blocks):
            dec_layers.insert(0, ResBlock(dec_chans[1]))
            enc_layers.append(ResBlock(enc_chans[-1]))

        if num_resnet_blocks > 0:
            dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))

        enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))
        dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))

        self.encoder = nn.Sequential(*enc_layers)
        self.decoder = nn.Sequential(*dec_layers)

        self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
        self.kl_div_loss_weight = kl_div_loss_weight

    def get_image_size(self):
        return self.image_size

    def get_image_tokens_size(self):
        return self.image_size // 8

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, images):
        logits = self.forward(images, return_logits=True)
        codebook_indices = logits.argmax(dim=1).flatten(1)
        return codebook_indices

    @torch.no_grad()
    @eval_decorator
    def get_codebook_probs(self, images, temp):
        logits = self.forward(images, return_logits=True)
        return nn.Softmax(dim=1)(logits / temp)

    def decode(self, img_seq):
        image_embeds = self.codebook(img_seq)
        b, n, d = image_embeds.shape
        h = w = int(sqrt(n))

        image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h=h, w=w)
        images = self.decoder(image_embeds)
        return images

    def forward(
        self,
        img,
        return_loss = False,
        return_recons = False,
        return_logits = False,
        temp = None
    ):
        device, num_tokens, image_size, kl_div_loss_weight = img.device, self.num_tokens, self.image_size, self.kl_div_loss_weight
        assert img.shape[-1] == image_size and img.shape[-2] == image_size, f'input must have the correct image size {image_size}'

        logits = self.encoder(img)

        if return_logits:
            return logits  # return logits for getting hard image indices for DALL-E training

        temp = default(temp, self.temperature)
        soft_one_hot = F.gumbel_softmax(logits.float(), tau=temp, dim=1, hard=self.straight_through)
        sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight).type_as(logits)
        out = self.decoder(sampled)

        if not return_loss:
            return out

        # reconstruction loss
        recon_loss = self.loss_fn(img, out)

        # kl divergence
        logits = rearrange(logits, 'b n h w -> b (h w) n')
        _C = logits.size(-1)
        avg_probs = F.softmax(logits.contiguous().view(-1, _C), dim=-1, dtype=torch.float32).mean(0)
        diversity_loss = torch.sum(avg_probs * torch.log(avg_probs + 1e-6), dim=-1).mean()

        if not return_recons:
            return recon_loss, diversity_loss

        return recon_loss, diversity_loss, out


from dall_e import load_model


class Dalle_VAE(BasicVAE):
    def __init__(self, image_size):
        super().__init__()
        self.encoder = None
        self.decoder = None
        self.image_size = image_size

    def load_model(self, model_dir, device):
        self.encoder = load_model(os.path.join(model_dir, "encoder.pkl"), device)
        self.decoder = load_model(os.path.join(model_dir, "decoder.pkl"), device)

    def decode(self, img_seq):
        bsz = img_seq.size()[0]
        img_seq = img_seq.view(bsz, self.image_size // 8, self.image_size // 8)
        z = F.one_hot(img_seq, num_classes=self.encoder.vocab_size).permute(0, 3, 1, 2).float()
        return self.decoder(z).float()

    def get_codebook_indices(self, images):
        z_logits = self.encoder(images)
        return torch.argmax(z_logits, axis=1)

    def get_codebook_probs(self, images):
        z_logits = self.encoder(images)
        return nn.Softmax(dim=1)(z_logits)

    def forward(self, img_seq_prob, no_process=False):
        if no_process:
            return self.decoder(img_seq_prob.float()).float()
        else:
            bsz, seq_len, num_class = img_seq_prob.size()
            z = img_seq_prob.view(bsz, self.image_size // 8, self.image_size // 8, self.encoder.vocab_size)
            return self.decoder(z.permute(0, 3, 1, 2).float()).float()


class VGGAN(BasicVAE):
    def __init__(self, image_size):
        super().__init__()
        self.encoder = None
        self.decoder = None
        self.image_size = image_size

    def load_model(self, weight_path, device):
        self.vqgan = torch.load(weight_path, map_location=device)

    def get_codebook_indices(self, images):
        _, _, [_, _, indices] = self.vqgan.encode(images)
        # indices: [b, h//8, w//8]
        return indices
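# ---------------------------------------------------------------------------
# Editor's note: illustrative addition, not part of the original file. With
# the default 3 stride-2 stages, DiscreteVAE turns a 256x256 image into a
# 32x32 grid of codebook indices -- the kind of discrete targets used for the
# masked patches during pretraining.
# ---------------------------------------------------------------------------
def _tokenizer_target_demo():
    dvae = DiscreteVAE(image_size=256, num_tokens=512, num_layers=3)
    img = torch.randn(1, 3, 256, 256)
    ids = dvae.get_codebook_indices(img)         # argmax over the encoder logits
    assert ids.shape == (1, 32 * 32)             # 256 / 2**3 = 32 tokens per side
    recon = dvae(img)                            # Gumbel-softmax sample, then decode
    assert recon.shape == img.shape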
================================================
FILE: models/modeling_finetune.py
================================================
import math
import numpy as np
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from furnace.utils import LP_BatchNorm

from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from timm.models.registry import register_model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
        **kwargs
    }


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commented out, following the original BERT implementation
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CrossAttention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, bool_masked_pos=None, k=None, v=None):
        B, N, C = x.shape
        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            v_bias = self.v_bias

        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
        q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)  # (B, N_head, N_q, dim)
        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
        k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
        v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, N_head, N_q, N_k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class AttentiveBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1_q = norm_layer(dim)
        self.norm1_k = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.norm2_cross = norm_layer(dim)
        self.cross_attn = CrossAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
        x_q = self.norm1_q(x_q + pos_q)
        x_k = self.norm1_k(x_kv + pos_k)
        x_v = self.norm1_v(x_kv)

        x = self.cross_attn(x_q, k=x_k, v=x_v)
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class RelativePositionBias(nn.Module):
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = \
            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

        # trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self):
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww


def get_sinusoid_encoding_table(n_position, d_hid, token=False):
    ''' Sinusoid position encoding table '''
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if token:
        sinusoid_table = np.concatenate([sinusoid_table, np.zeros([1, d_hid])], axis=0)

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
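# ---------------------------------------------------------------------------
# Editor's note: illustrative addition, not part of the original file. Shape
# check for the 1D sinusoid table above, assuming a 14x14 patch grid
# (196 positions) and 768 channels; even channels carry sin, odd channels cos.
# ---------------------------------------------------------------------------
def _sinusoid_table_demo():
    table = get_sinusoid_encoding_table(n_position=196, d_hid=768)
    assert table.shape == (1, 196, 768)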
class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
                 use_mean_pooling=True, init_scale=0.001, lin_probe=False, linear_type='standard', args=None):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.use_mean_pooling = use_mean_pooling

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        elif args.sin_pos_emb:
            # sine-cosine positional embeddings are used instead
            self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer, init_values=init_values,
                window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
            for i in range(depth)])
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)

        self.lin_probe = lin_probe
        self.linear_type = linear_type
        if lin_probe:
            if self.linear_type == 'standard':
                self.fc_norm = None
            elif self.linear_type == 'attentive':
                self.query_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
                self.attentive_blocks = nn.ModuleList([
                    AttentiveBlock(
                        dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                        qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=0,
                        norm_layer=norm_layer, init_values=0)
                    for i in range(1)])
                self.fc_norm = LP_BatchNorm(embed_dim, affine=False)
        else:
            if use_mean_pooling:
                self.fc_norm = norm_layer(embed_dim)
            else:
                self.fc_norm = None

        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if self.pos_embed is not None and use_abs_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.head.weight, std=.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

        self.head.weight.data.mul_(init_scale)
        self.head.bias.data.mul_(init_scale)

    def build_2d_sincos_position_embedding(self, embed_dim=768, temperature=10000.):
        h, w = self.patch_embed.patch_shape
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]

        # if self.use_mean_pooling:
        #     pos_embed = nn.Parameter(pos_emb)
        # else:
        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
        pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        pos_embed.requires_grad = False
        return pos_embed

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, is_train=True):
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            if self.use_abs_pos_emb:
                x = x + self.pos_embed.expand(batch_size, -1, -1).type_as(x).to(x.device).clone().detach()
            else:
                x = x + self.pos_embed.expand(batch_size, -1, -1).type_as(x).to(x.device).clone().detach()

        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            x = blk(x, rel_pos_bias=rel_pos_bias)

        x = self.norm(x)

        # linear probing or attentive probing
        if self.lin_probe:
            if self.linear_type == 'standard':
                return x[:, 0]
            else:
                query_tokens = self.query_token.expand(batch_size, -1, -1)
                for blk in self.attentive_blocks:
                    query_tokens = blk(query_tokens, x, 0, 0, bool_masked_pos=None, rel_pos_bias=None)
                return self.fc_norm(query_tokens[:, 0, :], is_train=is_train)
        else:
            # finetune
            if self.fc_norm is not None:
                # use mean pooling
                t = x[:, 1:, :]
                return self.fc_norm(t.mean(1))
            else:
                return x[:, 0]

    def forward(self, x, is_train=True):
        x = self.forward_features(x, is_train)
        x = self.head(x)
        return x
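# ---------------------------------------------------------------------------
# Editor's note: illustrative addition, not part of the original file. In
# attentive probing (lin_probe=True, linear_type='attentive'), a single
# learnable query cross-attends over the frozen encoder's output sequence;
# this demo runs one AttentiveBlock with assumed sizes (197 tokens, 768 dims).
# ---------------------------------------------------------------------------
def _attentive_probe_demo():
    dim = 768
    blk = AttentiveBlock(dim, num_heads=12, qkv_bias=True,
                         norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=0)
    tokens = torch.randn(2, 197, dim)            # frozen encoder output (cls + patches)
    query = torch.zeros(2, 1, dim)               # the expanded learnable query token
    pooled = blk(query, tokens, 0, 0, bool_masked_pos=None)
    assert pooled.shape == (2, 1, dim)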
@register_model
def cae_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


@register_model
def cae_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


@register_model
def cae_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


@register_model
def cae_large_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


@register_model
def cae_large_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


@register_model
def cae_large_patch16_512(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


================================================
FILE: requirements.txt
================================================
torch==1.7.1
torchvision==0.8.2
timm==0.3.2
Pillow
blobfile
mypy
numpy
pytest
requests
einops
tensorboardX
deepspeed==0.4.0
scipy
pytorch-lightning==1.0.8
omegaconf==2.0.0


================================================
FILE: scripts/cae_base_800e.sh
================================================
tmp_my_name=${0##*/}
my_name=${tmp_my_name%.*}

OUTPUT_DIR='./output/'$my_name
DATA_PATH=/path/to/imagenet1k/train
TOKENIZER_PATH=./tokenizer-weights

ADDRESS=ADDR_FOR_THIS_MACHINE
NNODES=4
RANK=RANK_FOR_THIS_MACHINE

# ============================ pretraining ============================
OMP_NUM_THREADS=1 python -m torch.distributed.launch \
  --nproc_per_node=8 \
  --nnodes=$NNODES \
  --node_rank=$RANK \
  --master_addr=$ADDRESS \
  --master_port=8899 \
  tools/run_pretraining.py \
  --data_path ${DATA_PATH} \
  --output_dir ${OUTPUT_DIR} \
  --model cae_base_patch16_224_8k_vocab --discrete_vae_weight_path ${TOKENIZER_PATH} \
  --batch_size 64 --lr 1.5e-3 --warmup_epochs 20 --epochs 800 \
  --clip_grad 3.0 --layer_scale_init_value 0.1 \
  --imagenet_default_mean_and_std \
  --color_jitter 0 \
  --drop_path 0.1 \
  --sincos_pos_emb \
  --mask_generator block \
  --num_mask_patches 98 \
  --decoder_layer_scale_init_value 0.1 \
  --no_auto_resume \
  --save_ckpt_freq 100 \
  --exp_name $my_name \
  --regressor_depth 4 \
  --decoder_depth 4 \
  --align_loss_weight 2

# ============================ linear probing ============================
DATA_PATH=/path/to/imagenet1k/
MODEL_PATH=/path/to/pretrained/model

OMP_NUM_THREADS=1 python -m torch.distributed.launch \
  --nproc_per_node=8 \
  --nnodes=$NNODES \
  --node_rank=$RANK \
  --master_addr=$ADDRESS \
  --master_port=8899 \
  tools/run_linear.py \
  --model cae_base_patch16_224 --data_path $DATA_PATH \
  --finetune $MODEL_PATH \
  --nb_classes 1000 \
  --batch_size 512 \
  --epochs 90 \
  --blr 0.1 \
  --weight_decay 0.0 \
  --dist_eval --data_path ${DATA_PATH} \
  --output_dir $OUTPUT_DIR \
  --log_dir $OUTPUT_DIR \
  --enable_linear_eval \
  --use_cls \
  --dist_eval \
  --save_freq 50 \
  --disable_rel_pos_bias \
  --linear_type standard \
  --exp_name $my_name

# ============================ attentive probing ============================
DATA_PATH=/path/to/imagenet1k/
MODEL_PATH=/path/to/pretrained/model

OMP_NUM_THREADS=1 python -m torch.distributed.launch \
  --nproc_per_node=8 \
  --nnodes=$NNODES \
  --node_rank=$RANK \
  --master_addr=$ADDRESS \
  --master_port=8899 \
  tools/run_attentive.py \
  --model cae_base_patch16_224 --data_path $DATA_PATH \
  --finetune $MODEL_PATH \
  --nb_classes 1000 --data_set IMNET --imagenet_default_mean_and_std \
  --output_dir $OUTPUT_DIR --batch_size 256 --lr 0.4 --update_freq 1 \
  --warmup_epochs 10 --epochs 90 \
  --weight_decay 0 --smoothing 0.0 --layer_decay 1.0 --drop_path 0.0 \
  --color_jitter 0.0 --mixup 0.0 --cutmix 0.0 --reprob 0.0 \
  --opt sgd --momentum 0.9 \
  --enable_linear_eval \
  --use_cls \
  --dist_eval \
  --no_auto_resume \
  --save_ckpt_freq 50 \
  --linear_type attentive \
  --exp_name $my_name


================================================
FILE: scripts/cae_base_finetune.sh
================================================
tmp_my_name=${0##*/}
my_name=${tmp_my_name%.*}

OUTPUT_DIR='./output/'$my_name
DATA_PATH=/path/to/imagenet1k/train
TOKENIZER_PATH=./tokenizer-weights

ADDRESS=ADDR_FOR_THIS_MACHINE
NNODES=4
RANK=RANK_FOR_THIS_MACHINE

MODEL_PATH=/path/to/pretrained/model

OMP_NUM_THREADS=1 python -m torch.distributed.launch \
  --nproc_per_node=8 \
  --nnodes=$NNODES \
  --node_rank=$RANK \
  --master_addr=$ADDRESS \
  --master_port=8899 \
  tools/run_class_finetuning.py \
  --model cae_base_patch16_224 --data_path $DATA_PATH \
  --finetune $MODEL_PATH \
  --nb_classes 1000 --data_set IMNET \
  --output_dir $OUTPUT_DIR \
  --batch_size 128 \
  --lr 8e-3 --update_freq 1 \
  --warmup_epochs 5 --epochs 100 --layer_decay 0.65 --drop_path 0.1 \
  --weight_decay 0.05 --mixup 0.8 --cutmix 1.0 \
  --sin_pos_emb \
  --dist_eval \
  --no_auto_resume \
  --exp_name $my_name \
  --imagenet_default_mean_and_std


================================================
FILE: scripts/cae_large_1600e.sh
================================================
tmp_my_name=${0##*/}
my_name=${tmp_my_name%.*}

OUTPUT_DIR='./output/'$my_name
DATA_PATH=/path/to/imagenet1k/train
TOKENIZER_PATH=./tokenizer-weights

ADDRESS=ADDR_FOR_THIS_MACHINE
NNODES=4
RANK=RANK_FOR_THIS_MACHINE

# ============================ pretraining ============================
OMP_NUM_THREADS=1 python -m torch.distributed.launch \
  --nproc_per_node=8 \
  --nnodes=$NNODES \
  --node_rank=$RANK \
  --master_addr=$ADDRESS \
  --master_port=8899 \
  tools/run_pretraining.py \
  --data_path ${DATA_PATH} \
  --output_dir ${OUTPUT_DIR} \
  --model cae_large_patch16_224_8k_vocab --discrete_vae_weight_path ${TOKENIZER_PATH} \
  --batch_size 64 --lr 1.5e-3 --warmup_epochs 40 --epochs 1600 \
  --clip_grad 3.0 --layer_scale_init_value 1e-5 \
  --imagenet_default_mean_and_std \
  --color_jitter 0 \
  --drop_path 0.1 \
  --sincos_pos_emb \
  --mask_generator block \
  --num_mask_patches 98 \
  --decoder_layer_scale_init_value 1e-5 \
  --no_auto_resume \
  --save_ckpt_freq 100 \
  --exp_name $my_name \
  --regressor_depth 4 \
  --decoder_depth 4 \
  --align_loss_weight 2 --decoder_embed_dim 1024 \
  --decoder_num_heads 16 \
  --fix_init_weight

# ============================ linear probing ============================
DATA_PATH=/path/to/imagenet1k/
MODEL_PATH=/path/to/pretrained/model

OMP_NUM_THREADS=1 python -m torch.distributed.launch \
  --nproc_per_node=8 \
  --nnodes=$NNODES \
  --node_rank=$RANK \
  --master_addr=$ADDRESS \
  --master_port=8899 \
  tools/run_linear.py \
  --model cae_large_patch16_224 --data_path $DATA_PATH \
  --finetune $MODEL_PATH \
  --nb_classes 1000 \
  --batch_size 512 \
  --epochs 90 \
  --blr 0.1 \
  --weight_decay 0.0 \
  --dist_eval --data_path ${DATA_PATH} \
  --output_dir $OUTPUT_DIR \
  --log_dir $OUTPUT_DIR \
  --enable_linear_eval \
  --use_cls \
  --dist_eval \
  --save_freq 90 \
  --disable_rel_pos_bias \
  --linear_type standard \
  --exp_name $my_name

# ============================ attentive probing ============================
DATA_PATH=/path/to/imagenet1k/
MODEL_PATH=/path/to/pretrained/model

OMP_NUM_THREADS=1 python -m torch.distributed.launch \
  --nproc_per_node=8 \
  --nnodes=$NNODES \
  --node_rank=$RANK \
  --master_addr=$ADDRESS \
  --master_port=8899 \
  tools/run_attentive.py \
  --model cae_large_patch16_224 --data_path $DATA_PATH \
  --finetune $MODEL_PATH \
  --nb_classes 1000 --data_set IMNET --imagenet_default_mean_and_std \
  --output_dir $OUTPUT_DIR --batch_size 256 --lr 0.4 --update_freq 1 \
  --warmup_epochs 10 --epochs 90 \
  --weight_decay 0 --smoothing 0.0 --layer_decay 1.0 --drop_path 0.0 \
  --color_jitter 0.0 --mixup 0.0 --cutmix 0.0 --reprob 0.0 \
  --opt sgd --momentum 0.9 \
  --enable_linear_eval \
  --use_cls \
  --dist_eval \
  --no_auto_resume \
  --save_ckpt_freq 50 \
  --linear_type attentive \
  --exp_name $my_name


================================================
FILE: scripts/cae_large_finetune.sh
================================================
tmp_my_name=${0##*/}
my_name=${tmp_my_name%.*}

OUTPUT_DIR='./output/'$my_name
DATA_PATH=/path/to/imagenet1k/train
TOKENIZER_PATH=./tokenizer-weights

ADDRESS=ADDR_FOR_THIS_MACHINE
NNODES=4
RANK=RANK_FOR_THIS_MACHINE

MODEL_PATH=/path/to/pretrained/model

OMP_NUM_THREADS=1 python -m torch.distributed.launch \
  --nproc_per_node=8 \
  --nnodes=$NNODES \
  --node_rank=$RANK \
  --master_addr=$ADDRESS \
  --master_port=8899 \
  tools/run_class_finetuning.py \
  --model cae_large_patch16_224 --data_path $DATA_PATH \
  --finetune $MODEL_PATH \
  --nb_classes 1000 --data_set IMNET \
  --output_dir $OUTPUT_DIR \
  --batch_size 64 \
  --lr 2e-3 --update_freq 2 \
  --warmup_epochs 5 --epochs 50 --layer_decay 0.75 --drop_path 0.2 \
  --weight_decay 0.05 --mixup 0.8 --cutmix 1.0 \
  --sin_pos_emb \
  --dist_eval \
  --no_auto_resume \
  --exp_name $my_name \
  --imagenet_default_mean_and_std


================================================
FILE: tokenizer-weights/README
================================================
-- tokenizers


================================================
FILE: tools/run_attentive.py
================================================
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import os

from pathlib import Path

from timm.data.mixup import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import ModelEma
from furnace.optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner

from furnace.datasets import build_dataset
from furnace.engine_for_finetuning import train_one_epoch, evaluate
from furnace.utils import NativeScalerWithGradNormCount as NativeScaler
import furnace.utils as utils
from scipy import interpolate
import models.modeling_finetune


def get_args():
    parser = argparse.ArgumentParser('fine-tuning and evaluation script for image classification', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=30, type=int)
    parser.add_argument('--update_freq', default=1, type=int)
    parser.add_argument('--save_ckpt_freq', default=5, type=int)

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--rel_pos_bias', action='store_true')
    parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')
    parser.set_defaults(rel_pos_bias=True)
    parser.add_argument('--abs_pos_emb', action='store_true')
    parser.set_defaults(abs_pos_emb=False)
    parser.add_argument('--sin_pos_emb', action='store_true')
    parser.set_defaults(sin_pos_emb=True)
    parser.add_argument('--disable_sin_pos_emb', action='store_false', dest='sin_pos_emb')
    parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
                        help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")
    parser.add_argument('--input_size', default=224, type=int,
                        help='images input size')
    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
                        help='Attention dropout rate (default: 0.)')
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)

    parser.add_argument('--model_ema', action='store_true', default=False)
    parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
    parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=None,
                        help="""Final value of the weight decay. We use a cosine schedule for WD
                        and using a larger decay by the end of training improves performance for ViTs.""")

    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--layer_decay', type=float, default=0.9)
    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
                        help='num of steps to warmup LR, will overload warmup_epochs if set > 0')

    # Augmentation parameters
    parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default='', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--train_interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    # Evaluation parameters
    parser.add_argument('--crop_pct', type=float, default=None)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0.')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0.')
    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup_prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup_mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # * Finetuning params
    parser.add_argument('--finetune', default='',
                        help='finetune from checkpoint')
    parser.add_argument('--model_key', default='model|module|state_dict', type=str)
    parser.add_argument('--model_prefix', default='', type=str)
    parser.add_argument('--init_scale', default=0.001, type=float)
    parser.add_argument('--use_mean_pooling', action='store_true')
    parser.set_defaults(use_mean_pooling=True)
    parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
    parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False)

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--eval_data_path', default=None, type=str,
                        help='dataset path for evaluation')
    parser.add_argument('--nb_classes', default=0, type=int,
                        help='number of the classification types')
    parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')

    parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'IMNET100', 'image_folder'],
                        type=str, help='ImageNet dataset path')
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default=None,
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=True)

    parser.add_argument('--save_ckpt', action='store_true')
    parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
    parser.set_defaults(save_ckpt=True)
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--dist_eval', action='store_true', default=False,
                        help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    parser.add_argument('--enable_deepspeed', action='store_true', default=False)
    parser.add_argument('--enable_linear_eval', action='store_true', default=False)
    parser.add_argument('--enable_multi_print', action='store_true', default=False,
                        help='allow each gpu to print something')
    parser.add_argument('--linear_type', default='standard', type=str,
                        help='standard, attentive')
    parser.add_argument('--exp_name', default='', type=str,
                        help='name of exp. it is helpful when saving the checkpoint')

    known_args, _ = parser.parse_known_args()

    if known_args.enable_deepspeed:
        try:
            import deepspeed
            from deepspeed import DeepSpeedConfig
            parser = deepspeed.add_config_arguments(parser)
            ds_init = deepspeed.initialize
        except:
            print("Please 'pip install deepspeed==0.4.0'")
            exit(0)
    else:
        ds_init = None

    return parser.parse_args(), ds_init


def main(args, ds_init):
    if not args.enable_linear_eval:
        args.aa = 'rand-m9-mstd0.5-inc1'

    utils.init_distributed_mode(args)

    if ds_init is not None:
        utils.create_ds_config(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    if args.disable_eval_during_finetuning:
        dataset_val = None
    else:
        dataset_val, _ = build_dataset(is_train=False, args=args)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    if dataset_val is not None:
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=int(1.5 * args.batch_size),
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False
        )
    else:
        data_loader_val = None

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        print("Mixup is activated!")
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        attn_drop_rate=args.attn_drop_rate,
        drop_block_rate=None,
        use_mean_pooling=args.use_mean_pooling,
        init_scale=args.init_scale,
        use_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
        lin_probe=args.enable_linear_eval,
        linear_type=args.linear_type,
        args=args,
    )

    if args.enable_linear_eval:
        linear_keyword = 'head'
        head_norm = 'fc_norm'
        parameters_requires_grad = []
        for name, param in model.named_parameters():
            param.requires_grad = False  # no grad by default
            if 'gamma' in name:
                param.requires_grad = False
            else:
                if ('query_token' in name) or ('attentive_blocks' in name) or \
                        (name in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]) or (head_norm in name):
                    parameters_requires_grad.append(name)
                    param.requires_grad = True

        print(f'parameters that need grad: ', parameters_requires_grad)
        getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)
        getattr(model, linear_keyword).bias.data.zero_()

    patch_size = model.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
    args.patch_size = patch_size

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.finetune, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')

        print("Load ckpt from %s" % args.finetune)
        checkpoint_model = None
        for model_key in args.model_key.split('|'):
            if model_key in checkpoint:
                checkpoint_model = checkpoint[model_key]
                print("Load state_dict by model_key = %s" % model_key)
                break
        if checkpoint_model is None:
            checkpoint_model = checkpoint
        state_dict = model.state_dict()

        original_all_keys = list(checkpoint_model.keys())
        print("##########origin keys:", len(original_all_keys), original_all_keys)
        # NOTE: remove all decoder keys
        all_keys = [key for key in original_all_keys if key.startswith('encoder.')]
        print("all keys:", all_keys)
        for key in all_keys:
            new_key = key.replace('encoder.', '')
            # print("new_key:", new_key)
            checkpoint_model[new_key] = checkpoint_model[key]
            checkpoint_model.pop(key)
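        # Editor's note (illustration, not original code): the loop above maps
        # pretraining checkpoint keys onto the fine-tuning model's namespace,
        # e.g. 'encoder.blocks.0.attn.qkv.weight' -> 'blocks.0.attn.qkv.weight'
        # and 'encoder.patch_embed.proj.weight' -> 'patch_embed.proj.weight',
        # so the CAE encoder weights load directly into VisionTransformer.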
print("all keys:", all_keys) for key in all_keys: new_key = key.replace('encoder.','') # print("new_key:", new_key) checkpoint_model[new_key] = checkpoint_model[key] checkpoint_model.pop(key) # handle moco-v3 checkpoints all_keys = [key for key in original_all_keys if key.startswith('module.base_encoder.')] print("all keys:", all_keys) for key in all_keys: new_key = key.replace('module.base_encoder.','') # print("new_key:", new_key) checkpoint_model[new_key] = checkpoint_model[key] checkpoint_model.pop(key) for key in list(checkpoint_model.keys()): if key.startswith('decoder.'): # print("key:", key) checkpoint_model.pop(key) if key.startswith('teacher.'): # print("key:", key) checkpoint_model.pop(key) for k in ['head.weight', 'head.bias']: if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: print(f"Removing key {k} from pretrained checkpoint") del checkpoint_model[k] if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model: print("Expand the shared relative position embedding to each transformer block. ") num_layers = model.get_num_layers() rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"] for i in range(num_layers): checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone() checkpoint_model.pop("rel_pos_bias.relative_position_bias_table") all_keys = list(checkpoint_model.keys()) for key in all_keys: if "relative_position_index" in key: checkpoint_model.pop(key) if "relative_position_bias_table" in key and args.rel_pos_bias: rel_pos_bias = checkpoint_model[key] src_num_pos, num_attn_heads = rel_pos_bias.size() dst_num_pos, _ = model.state_dict()[key].size() dst_patch_shape = model.patch_embed.patch_shape if dst_patch_shape[0] != dst_patch_shape[1]: raise NotImplementedError() num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) src_size = int((src_num_pos - num_extra_tokens) ** 0.5) dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) if src_size != dst_size: print("Position interpolate for %s from %dx%d to %dx%d" % ( key, src_size, src_size, dst_size, dst_size)) extra_tokens = rel_pos_bias[-num_extra_tokens:, :] rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] def geometric_progression(a, r, n): return a * (1.0 - r ** n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_size // 2) if gp > dst_size // 2: right = q else: left = q # if q > 1.090307: # q = 1.090307 dis = [] cur = 1 for i in range(src_size // 2): dis.append(cur) cur += q ** (i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis t = dst_size // 2.0 dx = np.arange(-t, t + 0.1, 1.0) dy = np.arange(-t, t + 0.1, 1.0) print("Original positions = %s" % str(x)) print("Target positions = %s" % str(dx)) all_rel_pos_bias = [] for i in range(num_attn_heads): z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() f = interpolate.interp2d(x, y, z, kind='cubic') all_rel_pos_bias.append( torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) checkpoint_model[key] = new_rel_pos_bias print("##############new keys:", len(checkpoint_model), checkpoint_model.keys()) #print("##############model:", model) # interpolate position embedding if 'pos_embed' in checkpoint_model and args.abs_pos_emb: pos_embed_checkpoint = 
checkpoint_model['pos_embed']
            embedding_size = pos_embed_checkpoint.shape[-1]
            num_patches = model.patch_embed.num_patches
            num_extra_tokens = model.pos_embed.shape[-2] - num_patches
            # height (== width) for the checkpoint position embedding
            orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
            # height (== width) for the new position embedding
            new_size = int(num_patches ** 0.5)
            # class_token and dist_token are kept unchanged
            if orig_size != new_size:
                print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
                # only the position tokens are interpolated
                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
                pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
                pos_tokens = torch.nn.functional.interpolate(
                    pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                checkpoint_model['pos_embed'] = new_pos_embed

        utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
        # model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')
        print("Using EMA with decay = %.8f" % args.model_ema_decay)

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params:', n_parameters)

    total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
    num_training_steps_per_epoch = len(dataset_train) // total_batch_size
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Update frequency = %d" % args.update_freq)
    print("Number of training examples = %d" % len(dataset_train))
    print("Number of training steps per epoch = %d" % num_training_steps_per_epoch)

    num_layers = model_without_ddp.get_num_layers()
    if args.layer_decay < 1.0:
        assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
    else:
        assigner = None

    if assigner is not None:
        print("Assigned values = %s" % str(assigner.values))

    skip_weight_decay_list = model.no_weight_decay()
    print("Skip weight decay list: ", skip_weight_decay_list)
    if args.disable_weight_decay_on_rel_pos_bias:
        for i in range(num_layers):
            skip_weight_decay_list.add("blocks.%d.attn.relative_position_bias_table" % i)

    if args.enable_deepspeed:
        loss_scaler = None
        optimizer_params = get_parameter_groups(
            model, args.weight_decay, skip_weight_decay_list,
            assigner.get_layer_id if assigner is not None else None,
            assigner.get_scale if assigner is not None else None)
        model, optimizer, _, _ = ds_init(
            args=args, model=model, model_parameters=optimizer_params,
            dist_init_required=not args.distributed,
        )
        print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
        assert model.gradient_accumulation_steps() == args.update_freq
    else:
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
            model_without_ddp = model.module

        optimizer = create_optimizer(
            args, model_without_ddp, skip_list=skip_weight_decay_list,
            get_num_layer=assigner.get_layer_id if
assigner is not None else None, get_layer_scale=assigner.get_scale if assigner is not None else None) loss_scaler = NativeScaler() print("Use step level LR scheduler!") lr_schedule_values = utils.cosine_scheduler( args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, ) if args.weight_decay_end is None: args.weight_decay_end = args.weight_decay wd_schedule_values = utils.cosine_scheduler( args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) if mixup_fn is not None: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif args.smoothing > 0.: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: criterion = torch.nn.CrossEntropyLoss() print("criterion = %s" % str(criterion)) utils.auto_load_model( args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) if args.eval: test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") exit(0) print(f"Start training for {args.epochs} epochs") start_time = time.time() max_accuracy = 0.0 for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) if log_writer is not None: log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch, lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq, ) if args.output_dir and args.save_ckpt: if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: utils.save_model( args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, exp_name=args.exp_name, model_ema=model_ema) if data_loader_val is not None: test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") if max_accuracy < test_stats["acc1"]: max_accuracy = test_stats["acc1"] if args.output_dir and args.save_ckpt: utils.save_model( args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) print(f'Max accuracy: {max_accuracy:.2f}%') if log_writer is not None: log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters} else: log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, # **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters} if args.output_dir and utils.is_main_process(): if log_writer is not None: log_writer.flush() with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: f.write(json.dumps(log_stats) + "\n") 
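# --- Illustrative aside (not part of the original script) ---------------------
# The relative-position-bias interpolation earlier in this file searches, by
# bisection, for a ratio q such that src_size // 2 geometrically growing steps
# cover roughly dst_size // 2 target positions; the stretched coordinates then
# feed scipy's interp2d. A self-contained sketch of that search (the `_demo_*`
# name and the default sizes are hypothetical):
def _demo_find_gp_ratio(src_size=14, dst_size=28):
    def geometric_progression(a, r, n):
        # sum of the series a + a*r + ... + a*r**(n-1)
        return a * (1.0 - r ** n) / (1.0 - r)
    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        if geometric_progression(1, q, src_size // 2) > dst_size // 2:
            right = q
        else:
            left = q
    return q  # roughly 1.23 for src_size=14, dst_size=28
# -------------------------------------------------------------------------------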
total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) if __name__ == '__main__': opts, ds_init = get_args() if opts.output_dir: Path(opts.output_dir).mkdir(parents=True, exist_ok=True) main(opts, ds_init) ================================================ FILE: tools/run_class_finetuning.py ================================================ import argparse import datetime import numpy as np import time import torch import torch.backends.cudnn as cudnn import json import os import shutil from pathlib import Path from timm.data.mixup import Mixup from timm.models import create_model from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from timm.utils import ModelEma from furnace.optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner from furnace.datasets import build_dataset from furnace.engine_for_finetuning import train_one_epoch, evaluate from furnace.utils import NativeScalerWithGradNormCount as NativeScaler import furnace.utils as utils from scipy import interpolate import models.modeling_finetune def get_args(): parser = argparse.ArgumentParser('fine-tuning and evaluation script for image classification', add_help=False) parser.add_argument('--batch_size', default=64, type=int) parser.add_argument('--epochs', default=30, type=int) parser.add_argument('--update_freq', default=1, type=int) parser.add_argument('--save_ckpt_freq', default=5, type=int) # Model parameters parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', help='Name of model to train') parser.add_argument('--rel_pos_bias', action='store_true') parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias') parser.set_defaults(rel_pos_bias=True) parser.add_argument('--abs_pos_emb', action='store_true') parser.set_defaults(abs_pos_emb=False) parser.add_argument('--sin_pos_emb', action='store_true') parser.set_defaults(sin_pos_emb=True) parser.add_argument('--disable_sin_pos_emb', action='store_false', dest='sin_pos_emb') parser.add_argument('--layer_scale_init_value', default=0.1, type=float, help="0.1 for base, 1e-5 for large. 
set 0 to disable layer scale") parser.add_argument('--input_size', default=224, type=int, help='images input size') parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)') parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT', help='Attention dropout rate (default: 0.)') parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', help='Drop path rate (default: 0.1)') parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False) parser.add_argument('--model_ema', action='store_true', default=False) parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='') # Optimizer parameters parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw"') parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: 1e-8)') parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: None, use opt default)') parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)') parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the weight decay. We use a cosine schedule for WD and using a larger decay by the end of training improves performance for ViTs.""") parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', help='learning rate (default: 5e-4)') parser.add_argument('--layer_decay', type=float, default=0.9) parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)') parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports') parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', help='num of steps to warmup LR, will overload warmup_epochs if set > 0') # Augmentation parameters parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') parser.add_argument('--aa', type=str, default='', metavar='NAME', help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m9-mstd0.5-inc1)'), parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') parser.add_argument('--train_interpolation', type=str, default='bicubic', help='Training interpolation (random, bilinear, bicubic default: "bicubic")') # Evaluation parameters parser.add_argument('--crop_pct', type=float, default=None) # * Random Erase params parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase prob (default: 0.25)') parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode (default: "pixel")') parser.add_argument('--recount', type=int, default=1, help='Random erase count (default: 1)') parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first (clean) augmentation split') # * Mixup params parser.add_argument('--mixup', type=float, default=0, help='mixup alpha, mixup enabled if > 0.') parser.add_argument('--cutmix', type=float, default=0, help='cutmix alpha, cutmix enabled if > 0.') parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') parser.add_argument('--mixup_prob', type=float, default=1.0, help='Probability of performing mixup or cutmix when either/both is enabled') parser.add_argument('--mixup_switch_prob', type=float, default=0.5, help='Probability of switching to cutmix when both mixup and cutmix enabled') parser.add_argument('--mixup_mode', type=str, default='batch', help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') # * Finetuning params parser.add_argument('--finetune', default='', help='finetune from checkpoint') parser.add_argument('--model_key', default='model|module|state_dict', type=str) parser.add_argument('--model_prefix', default='', type=str) parser.add_argument('--init_scale', default=0.001, type=float) parser.add_argument('--use_mean_pooling', action='store_true') parser.set_defaults(use_mean_pooling=True) parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling') parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False) # Dataset parameters parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, help='dataset path') parser.add_argument('--eval_data_path', default=None, type=str, help='dataset path for evaluation') parser.add_argument('--nb_classes', default=0, type=int, help='number of the classification types') parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true') parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'IMNET100', 'image_folder'], type=str, help='ImageNet dataset path') parser.add_argument('--output_dir', default='', help='path where to save, empty for no saving') parser.add_argument('--log_dir', default=None, help='path where to tensorboard log') parser.add_argument('--device', default='cuda', help='device to use for training / testing') parser.add_argument('--seed', default=0, type=int) parser.add_argument('--resume', default='', help='resume from checkpoint') parser.add_argument('--auto_resume', action='store_true') parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume') parser.set_defaults(auto_resume=True) parser.add_argument('--save_ckpt', action='store_true') parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt') parser.set_defaults(save_ckpt=True) 
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--enable_deepspeed', action='store_true', default=False)
    parser.add_argument('--enable_linear_eval', action='store_true', default=False)
    parser.add_argument('--enable_multi_print', action='store_true', default=False, help='allow each gpu to print something')
    parser.add_argument('--exp_name', default='', type=str, help='name of exp. it is helpful when saving the checkpoint')

    known_args, _ = parser.parse_known_args()

    if known_args.enable_deepspeed:
        try:
            import deepspeed
            from deepspeed import DeepSpeedConfig
            parser = deepspeed.add_config_arguments(parser)
            ds_init = deepspeed.initialize
        except:
            print("Please 'pip install deepspeed==0.4.0'")
            exit(0)
    else:
        ds_init = None

    return parser.parse_args(), ds_init


def main(args, ds_init):
    if not args.enable_linear_eval:
        args.aa = 'rand-m9-mstd0.5-inc1'
    utils.init_distributed_mode(args)

    if ds_init is not None:
        utils.create_ds_config(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    if args.disable_eval_during_finetuning:
        dataset_val = None
    else:
        dataset_val, _ = build_dataset(is_train=False, args=args)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number.
' 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'equal num of samples per-process.') sampler_val = torch.utils.data.DistributedSampler( dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) if global_rank == 0 and args.log_dir is not None: os.makedirs(args.log_dir, exist_ok=True) log_writer = utils.TensorboardLogger(log_dir=args.log_dir) else: log_writer = None data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, ) if dataset_val is not None: data_loader_val = torch.utils.data.DataLoader( dataset_val, sampler=sampler_val, batch_size=int(1.5 * args.batch_size), num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False ) else: data_loader_val = None mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: print("Mixup is activated!") mixup_fn = Mixup( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.nb_classes) model = create_model( args.model, pretrained=False, num_classes=args.nb_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, attn_drop_rate=args.attn_drop_rate, drop_block_rate=None, use_mean_pooling=args.use_mean_pooling, init_scale=args.init_scale, use_rel_pos_bias=args.rel_pos_bias, use_abs_pos_emb=args.abs_pos_emb, init_values=args.layer_scale_init_value, lin_probe=args.enable_linear_eval, args=args, ) if args.enable_linear_eval: # freeze all layers but the last fc linear_keyword = 'head' head_norm = 'fc_norm' requires_grad = [] for name, param in model.named_parameters(): if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword] and head_norm not in name: param.requires_grad = False else: requires_grad.append(name) print(f'require grad parameter: ', requires_grad) # init the fc layer getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01) getattr(model, linear_keyword).bias.data.zero_() patch_size = model.patch_embed.patch_size print("Patch size = %s" % str(patch_size)) args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1]) args.patch_size = patch_size if args.finetune: if args.finetune.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( args.finetune, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.finetune, map_location='cpu') print("Load ckpt from %s" % args.finetune) checkpoint_model = None for model_key in args.model_key.split('|'): if model_key in checkpoint: checkpoint_model = checkpoint[model_key] print("Load state_dict by model_key = %s" % model_key) break if checkpoint_model is None: checkpoint_model = checkpoint state_dict = model.state_dict() all_keys = list(checkpoint_model.keys()) print("##########origin keys:", len(all_keys), all_keys) # NOTE: remove all decoder keys all_keys = [key for key in all_keys if key.startswith('encoder.')] print("all keys:", all_keys) for key in all_keys: new_key = key.replace('encoder.','') # print("new_key:", new_key) checkpoint_model[new_key] = checkpoint_model[key] checkpoint_model.pop(key) for key in 
list(checkpoint_model.keys()): if key.startswith('decoder.'): # print("key:", key) checkpoint_model.pop(key) if key.startswith('teacher.'): # print("key:", key) checkpoint_model.pop(key) # NOTE: replace norm with fc_norm for key in list(checkpoint_model.keys()): # print("new key:", key) if key.startswith('norm.'): new_key = key.replace('norm.','fc_norm.') checkpoint_model[new_key] = checkpoint_model[key] checkpoint_model.pop(key) for k in ['head.weight', 'head.bias']: if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: print(f"Removing key {k} from pretrained checkpoint") del checkpoint_model[k] if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model: print("Expand the shared relative position embedding to each transformer block. ") num_layers = model.get_num_layers() rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"] for i in range(num_layers): checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone() checkpoint_model.pop("rel_pos_bias.relative_position_bias_table") all_keys = list(checkpoint_model.keys()) for key in all_keys: if "relative_position_index" in key: checkpoint_model.pop(key) if "relative_position_bias_table" in key and args.rel_pos_bias: rel_pos_bias = checkpoint_model[key] src_num_pos, num_attn_heads = rel_pos_bias.size() dst_num_pos, _ = model.state_dict()[key].size() dst_patch_shape = model.patch_embed.patch_shape if dst_patch_shape[0] != dst_patch_shape[1]: raise NotImplementedError() num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) src_size = int((src_num_pos - num_extra_tokens) ** 0.5) dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) if src_size != dst_size: print("Position interpolate for %s from %dx%d to %dx%d" % ( key, src_size, src_size, dst_size, dst_size)) extra_tokens = rel_pos_bias[-num_extra_tokens:, :] rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] def geometric_progression(a, r, n): return a * (1.0 - r ** n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_size // 2) if gp > dst_size // 2: right = q else: left = q # if q > 1.090307: # q = 1.090307 dis = [] cur = 1 for i in range(src_size // 2): dis.append(cur) cur += q ** (i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis t = dst_size // 2.0 dx = np.arange(-t, t + 0.1, 1.0) dy = np.arange(-t, t + 0.1, 1.0) print("Original positions = %s" % str(x)) print("Target positions = %s" % str(dx)) all_rel_pos_bias = [] for i in range(num_attn_heads): z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() f = interpolate.interp2d(x, y, z, kind='cubic') all_rel_pos_bias.append( torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) checkpoint_model[key] = new_rel_pos_bias print("##############new keys:", len(checkpoint_model), checkpoint_model.keys()) #print("##############model:", model) # interpolate position embedding if 'pos_embed' in checkpoint_model and args.abs_pos_emb: pos_embed_checkpoint = checkpoint_model['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.patch_embed.num_patches num_extra_tokens = model.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int((pos_embed_checkpoint.shape[-2] - 
num_extra_tokens) ** 0.5)
            # height (== width) for the new position embedding
            new_size = int(num_patches ** 0.5)
            # class_token and dist_token are kept unchanged
            if orig_size != new_size:
                print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
                # only the position tokens are interpolated
                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
                pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
                pos_tokens = torch.nn.functional.interpolate(
                    pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                checkpoint_model['pos_embed'] = new_pos_embed

        utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
        # model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')
        print("Using EMA with decay = %.8f" % args.model_ema_decay)

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params:', n_parameters)

    total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
    num_training_steps_per_epoch = len(dataset_train) // total_batch_size
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Update frequency = %d" % args.update_freq)
    print("Number of training examples = %d" % len(dataset_train))
    print("Number of training steps per epoch = %d" % num_training_steps_per_epoch)

    num_layers = model_without_ddp.get_num_layers()
    if args.layer_decay < 1.0:
        assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
    else:
        assigner = None

    if assigner is not None:
        print("Assigned values = %s" % str(assigner.values))

    skip_weight_decay_list = model.no_weight_decay()
    print("Skip weight decay list: ", skip_weight_decay_list)
    if args.disable_weight_decay_on_rel_pos_bias:
        for i in range(num_layers):
            skip_weight_decay_list.add("blocks.%d.attn.relative_position_bias_table" % i)

    if args.enable_deepspeed:
        loss_scaler = None
        optimizer_params = get_parameter_groups(
            model, args.weight_decay, skip_weight_decay_list,
            assigner.get_layer_id if assigner is not None else None,
            assigner.get_scale if assigner is not None else None)
        model, optimizer, _, _ = ds_init(
            args=args, model=model, model_parameters=optimizer_params,
            dist_init_required=not args.distributed,
        )
        print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
        assert model.gradient_accumulation_steps() == args.update_freq
    else:
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
            model_without_ddp = model.module

        optimizer = create_optimizer(
            args, model_without_ddp, skip_list=skip_weight_decay_list,
            get_num_layer=assigner.get_layer_id if assigner is not None else None,
            get_layer_scale=assigner.get_scale if assigner is not None else None)
        loss_scaler = NativeScaler()

    print("Use step level LR scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, ) if args.weight_decay_end is None: args.weight_decay_end = args.weight_decay wd_schedule_values = utils.cosine_scheduler( args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) if mixup_fn is not None: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif args.smoothing > 0.: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: criterion = torch.nn.CrossEntropyLoss() print("criterion = %s" % str(criterion)) utils.auto_load_model( args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) if args.eval: test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") exit(0) print(f"Start training for {args.epochs} epochs") start_time = time.time() max_accuracy = 0.0 for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) if log_writer is not None: log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch, lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq, ) if args.output_dir and args.save_ckpt: if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: utils.save_model( args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) if data_loader_val is not None: test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") if max_accuracy < test_stats["acc1"]: max_accuracy = test_stats["acc1"] if args.output_dir and args.save_ckpt: utils.save_model( args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) print(f'Max accuracy: {max_accuracy:.2f}%') if log_writer is not None: log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters} else: log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, # **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters} if args.output_dir and utils.is_main_process(): if log_writer is not None: log_writer.flush() with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) if __name__ == '__main__': opts, ds_init = get_args() if opts.output_dir: Path(opts.output_dir).mkdir(parents=True, exist_ok=True) main(opts, ds_init) 
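# --- Illustrative aside (not part of the original script) ---------------------
# The fine-tuning scripts above hand LayerDecayValueAssigner one LR multiplier
# per layer id: the patch/position embeddings get the smallest scale and the
# head gets 1.0, decaying geometrically toward the input. A sketch of the scale
# list built by `list(args.layer_decay ** (num_layers + 1 - i) for i in
# range(num_layers + 2))` (the `_demo_*` name is hypothetical):
def _demo_layer_decay_scales(layer_decay=0.9, num_layers=12):
    # index 0 -> embeddings, indices 1..num_layers -> blocks, last index -> head
    return [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]
# _demo_layer_decay_scales()[-1] == 1.0; earlier entries shrink by 0.9 per layer.
# -------------------------------------------------------------------------------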
================================================ FILE: tools/run_linear.py ================================================ import argparse import datetime import json import numpy as np import os import time from pathlib import Path import torch import torch.backends.cudnn as cudnn import torchvision.transforms as transforms import torchvision.datasets as datasets import timm assert timm.__version__ == "0.3.2" # version check from timm.models.layers import trunc_normal_ import linear_util.misc as misc from linear_util.pos_embed import interpolate_pos_embed from linear_util.misc import NativeScalerWithGradNormCount as NativeScaler from linear_util.lars import LARS from linear_util.crop import RandomResizedCrop import models.modeling_finetune as models_vit from linear_util.engine_finetune import train_one_epoch, evaluate def setup_for_distributed(rank): """ This function disables printing when not in master process """ import builtins as __builtin__ builtin_print = __builtin__.print def print(*args, **kwargs): if rank==0: builtin_print(*args, **kwargs) __builtin__.print = print def get_args_parser(): parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False) parser.add_argument('--batch_size', default=512, type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') parser.add_argument('--epochs', default=90, type=int) parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') # Model parameters parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', help='Name of model to train') # Optimizer parameters parser.add_argument('--weight_decay', type=float, default=0, help='weight decay (default: 0 for linear probe following MoCo v1)') parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)') parser.add_argument('--blr', type=float, default=0.1, metavar='LR', help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') parser.add_argument('--min_lr', type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0') parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', help='epochs to warmup LR') # * Finetuning params parser.add_argument('--finetune', default='', help='finetune from checkpoint') # Dataset parameters parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, help='dataset path') parser.add_argument('--nb_classes', default=1000, type=int, help='number of the classification types') parser.add_argument('--output_dir', default='./output_dir', help='path where to save, empty for no saving') parser.add_argument('--log_dir', default='./output_dir', help='path where to tensorboard log') parser.add_argument('--device', default='cuda', help='device to use for training / testing') parser.add_argument('--seed', default=0, type=int) parser.add_argument('--resume', default='', help='resume from checkpoint') parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch') parser.add_argument('--eval', action='store_true', help='Perform evaluation only') parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation (recommended during training for faster monitor') parser.add_argument('--num_workers', default=10, type=int) parser.add_argument('--pin_mem', action='store_true', 
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') parser.set_defaults(pin_mem=True) # distributed training parameters parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--dist_on_itp', action='store_true') parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)') parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT', help='Attention dropout rate (default: 0.)') parser.add_argument('--drop_path', type=float, default=0, metavar='PCT', help='Drop path rate (default: 0.1)') parser.add_argument('--init_scale', default=0.001, type=float) parser.add_argument('--use_mean_pooling', action='store_true') parser.set_defaults(use_mean_pooling=True) parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling') parser.add_argument('--rel_pos_bias', action='store_true') parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias') parser.set_defaults(rel_pos_bias=True) parser.add_argument('--abs_pos_emb', action='store_true') parser.set_defaults(abs_pos_emb=False) parser.add_argument('--sin_pos_emb', action='store_true') parser.set_defaults(sin_pos_emb=True) parser.add_argument('--disable_sin_pos_emb', action='store_false', dest='sin_pos_emb') parser.add_argument('--layer_scale_init_value', default=0.1, type=float, help="0.1 for base, 1e-5 for large. set 0 to disable layer scale") parser.add_argument('--enable_linear_eval', action='store_true', default=False) parser.add_argument('--exp_name', default='', type=str, help='name of exp. 
it is helpful when save the checkpoint') parser.add_argument('--save_freq', default=50, type=int, help='freq of saving models') parser.add_argument('--linear_type', default='standard', type=str, help='standard or attentive') return parser def main(args): misc.init_distributed_mode(args) setup_for_distributed(args.local_rank) print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) print("{}".format(args).replace(', ', ',\n')) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + misc.get_rank() torch.manual_seed(seed) np.random.seed(seed) cudnn.benchmark = True # linear probe: weak augmentation transform_train = transforms.Compose([ RandomResizedCrop(224, interpolation=3), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) transform_val = transforms.Compose([ transforms.Resize(256, interpolation=3), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val) print(dataset_train) print(dataset_val) if True: # args.distributed: num_tasks = misc.get_world_size() global_rank = misc.get_rank() sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True ) print("Sampler_train = %s" % str(sampler_train)) if args.dist_eval: if len(dataset_val) % num_tasks != 0: print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'equal num of samples per-process.') sampler_val = torch.utils.data.DistributedSampler( dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) if global_rank == 0 and args.log_dir is not None and not args.eval: os.makedirs(args.log_dir, exist_ok=True) # log_writer = SummaryWriter(log_dir=args.log_dir) log_writer = None else: log_writer = None data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, ) data_loader_val = torch.utils.data.DataLoader( dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False ) model = models_vit.__dict__[args.model]( num_classes=args.nb_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, attn_drop_rate=args.attn_drop_rate, use_mean_pooling=args.use_mean_pooling, init_scale=args.init_scale, use_rel_pos_bias=args.rel_pos_bias, use_abs_pos_emb=args.abs_pos_emb, init_values=args.layer_scale_init_value, lin_probe=args.enable_linear_eval, args=args, ) if args.finetune and not args.eval: checkpoint = torch.load(args.finetune, map_location='cpu') print("Load pre-trained checkpoint from: %s" % args.finetune) checkpoint_model = checkpoint['model'] state_dict = model.state_dict() for k in ['head.weight', 'head.bias']: if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: print(f"Removing 
key {k} from pretrained checkpoint") del checkpoint_model[k] for key in list(checkpoint_model.keys()): if 'encoder.' in key: new_key = key.replace('encoder.','') checkpoint_model[new_key] = checkpoint_model[key] checkpoint_model.pop(key) if 'teacher' in key or 'decoder' in key: checkpoint_model.pop(key) if args.rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model: print("Expand the shared relative position embedding to each transformer block. ") num_layers = model.get_num_layers() rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"] for i in range(num_layers): checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone() checkpoint_model.pop("rel_pos_bias.relative_position_bias_table") all_keys = list(checkpoint_model.keys()) for key in all_keys: if "relative_position_index" in key: checkpoint_model.pop(key) if "relative_position_bias_table" in key and args.rel_pos_bias: rel_pos_bias = checkpoint_model[key] src_num_pos, num_attn_heads = rel_pos_bias.size() dst_num_pos, _ = model.state_dict()[key].size() dst_patch_shape = model.patch_embed.patch_shape if dst_patch_shape[0] != dst_patch_shape[1]: raise NotImplementedError() num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) src_size = int((src_num_pos - num_extra_tokens) ** 0.5) dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) if src_size != dst_size: print("Position interpolate for %s from %dx%d to %dx%d" % ( key, src_size, src_size, dst_size, dst_size)) extra_tokens = rel_pos_bias[-num_extra_tokens:, :] rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] def geometric_progression(a, r, n): return a * (1.0 - r ** n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_size // 2) if gp > dst_size // 2: right = q else: left = q # if q > 1.090307: # q = 1.090307 dis = [] cur = 1 for i in range(src_size // 2): dis.append(cur) cur += q ** (i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis t = dst_size // 2.0 dx = np.arange(-t, t + 0.1, 1.0) dy = np.arange(-t, t + 0.1, 1.0) print("Original positions = %s" % str(x)) print("Target positions = %s" % str(dx)) all_rel_pos_bias = [] for i in range(num_attn_heads): z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() f = interpolate.interp2d(x, y, z, kind='cubic') all_rel_pos_bias.append( torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) checkpoint_model[key] = new_rel_pos_bias # interpolate position embedding interpolate_pos_embed(model, checkpoint_model) # load pre-trained model msg = model.load_state_dict(checkpoint_model, strict=False) print(msg) trunc_normal_(model.head.weight, std=0.01) model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head) requires_grad = [] for _, p in model.named_parameters(): p.requires_grad = False for nname, p in model.head.named_parameters(): p.requires_grad = True requires_grad.append(nname) print(f'require grad parameter: ', requires_grad) model.to(device) model_without_ddp = model n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print("Model = %s" % str(model_without_ddp)) print('number of params (M): %.2f' % (n_parameters / 1.e6)) eff_batch_size = args.batch_size * args.accum_iter * 
misc.get_world_size() if args.lr is None: # only base_lr is specified args.lr = args.blr * eff_batch_size / 256 print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) print("actual lr: %.2e" % args.lr) print("accumulate grad iterations: %d" % args.accum_iter) print("effective batch size: %d" % eff_batch_size) if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) model_without_ddp = model.module optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay) print(optimizer) loss_scaler = NativeScaler() criterion = torch.nn.CrossEntropyLoss() print("criterion = %s" % str(criterion)) misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) if args.eval: test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") exit(0) print(f"Start training for {args.epochs} epochs") start_time = time.time() max_accuracy = 0.0 for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, max_norm=None, log_writer=log_writer, args=args ) if args.output_dir and (epoch % args.save_freq == 0 or epoch + 1 == args.epochs): misc.save_model( args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch) test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%') if log_writer is not None: log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters} if args.output_dir and misc.is_main_process(): if log_writer is not None: log_writer.flush() with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) if __name__ == '__main__': args = get_args_parser() args = args.parse_args() if args.output_dir: Path(args.output_dir).mkdir(parents=True, exist_ok=True) main(args) ================================================ FILE: tools/run_pretraining.py ================================================ import argparse import datetime import numpy as np import time import torch import torch.backends.cudnn as cudnn import json import os import shutil from pathlib import Path from timm.models import create_model from furnace.optim_factory import create_optimizer from furnace.datasets import build_cae_pretraining_dataset from furnace.engine_for_pretraining import train_one_epoch from furnace.utils import NativeScalerWithGradNormCount as NativeScaler import furnace.utils as utils from models import modeling_cae import torch.distributed as dist def get_args(): parser = argparse.ArgumentParser('pre-training script', add_help=False) parser.add_argument('--batch_size', default=64, type=int) 
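# --- Illustrative aside (not part of the original script) ---------------------
# tools/run_linear.py above derives the absolute LR from a base LR defined for
# an effective batch size of 256 (the linear scaling rule). A self-contained
# sketch (the `_demo_*` name and example values are hypothetical):
def _demo_scale_lr(blr=0.1, batch_size=512, accum_iter=1, world_size=8):
    eff_batch_size = batch_size * accum_iter * world_size
    return blr * eff_batch_size / 256
# e.g. blr=0.1 with an effective batch of 4096 yields an absolute LR of 1.6.
# -------------------------------------------------------------------------------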
parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--save_ckpt_freq', default=50, type=int)
    parser.add_argument("--discrete_vae_weight_path", type=str)
    parser.add_argument("--discrete_vae_type", type=str, default="dall-e", help='[dall-e, vqgan_gumbel_f8_8192, customized]')
    parser.add_argument('--dvae_num_layers', default=3, type=int)

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', help='Name of model to train')
    parser.add_argument('--rel_pos_bias', action='store_true', default=False)
    parser.add_argument('--abs_pos_emb', action='store_true', default=False)
    parser.add_argument('--sincos_pos_emb', action='store_true', default=False)
    parser.add_argument('--layer_scale_init_value', default=0.1, type=float, help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")
    parser.add_argument('--input_size', default=224, type=int, help='images input size for backbone')
    parser.add_argument('--second_input_size', default=112, type=int, help='images input size for discrete vae')
    parser.add_argument('--drop_path', type=float, default=0, metavar='PCT', help='Drop path rate (default: 0)')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw")')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: 0.9, 0.98, use opt default)')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the weight decay. We use a cosine schedule for WD.
        (Set the same value as args.weight_decay to keep the weight decay unchanged)""")
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', help='learning rate (default: 5e-4)')
    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', help='num of steps to warmup LR, will overload warmup_epochs if set > 0')

    # Augmentation parameters
    parser.add_argument('--train_interpolation', type=str, default='bicubic', help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
    parser.add_argument('--second_interpolation', type=str, default='lanczos', help='Interpolation for discrete vae (random, bilinear, bicubic default: "lanczos")')

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, help='dataset path')
    parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')
    parser.add_argument('--output_dir', default='', help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default=None, help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=True)
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem', help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--exp_name', default='', type=str, help='it is used when saving the checkpoint')
    parser.add_argument('--enable_multi_print', action='store_true', default=False, help='allow each gpu to print something')

    ''' Data augmentation '''
    # crop size
    parser.add_argument('--crop_min_size', type=float, default=0.08, help='min size of crop')
    parser.add_argument('--crop_max_size', type=float, default=1.0, help='max size of crop')
    # color jitter
    parser.add_argument('--color_jitter', type=float, default=0, metavar='PCT', help='Color jitter factor (default: 0)')

    ''' Mask strategy '''
    parser.add_argument('--mask_generator', default='block', type=str, help='block or random')
    # 1. if using block masking, set num_mask_patches
    parser.add_argument('--num_mask_patches', default=98, type=int,
                        help='number of visual tokens/patches that need to be masked')
    parser.add_argument('--max_mask_patches_per_block', type=int, default=None)
    parser.add_argument('--min_mask_patches_per_block', type=int, default=16)
    # 2. if using random masking, set the mask ratio
    parser.add_argument('--ratio_mask_patches', default=None, type=float, help="mask ratio")

    ''' CAE hyper-parameters '''
    parser.add_argument('--regressor_depth', default=4, type=int,
                        help='depth of the regressor')
    parser.add_argument('--decoder_depth', default=4, type=int,
                        help='depth of the decoder')
    parser.add_argument('--decoder_embed_dim', default=768, type=int,
                        help='dimensionality of the decoder embeddings')
    parser.add_argument('--decoder_num_heads', default=12, type=int,
                        help='number of attention heads in the decoder')
    parser.add_argument('--decoder_num_classes', default=8192, type=int,
                        help='number of classes for the decoder')
    parser.add_argument('--decoder_layer_scale_init_value', default=0.1, type=float,
                        help='decoder layer scale init value')
    # alignment constraint
    parser.add_argument('--align_loss_weight', type=float, default=2,
                        help='loss weight for the alignment constraint')
    parser.add_argument('--base_momentum', type=float, default=0,
                        help='ema weight for the dual-path network')
    # init func, borrowed from BEiT
    parser.add_argument('--fix_init_weight', action='store_true', default=False,
                        help='if true, the fix_init_weight() func will be activated')

    return parser.parse_args()


def get_model(args):
    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
        args=args,
    )
    return model


def main(args):
    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    model = get_model(args)
    patch_size = model.encoder.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
    args.patch_size = patch_size

    # get dataset
    dataset_train = build_cae_pretraining_dataset(args)

    # prepare discrete vae (the tokenizer that provides the prediction targets)
    d_vae = utils.create_d_vae(
        weight_path=args.discrete_vae_weight_path, d_vae_type=args.discrete_vae_type,
        device=device, image_size=args.second_input_size, args=args)

    if True:  # args.distributed: (the distributed branch is hard-coded on in the original)
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        sampler_rank = global_rank
        num_training_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks

        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    model.to(device)
    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Model = %s" % str(model_without_ddp))
    print('number of params:', n_parameters)
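    # A hedged worked example of the bookkeeping below, with hypothetical numbers:
    # --batch_size 64 on 32 GPUs gives total_batch_size = 64 * 32 = 2048, and for
    # the ~1.28M-image ImageNet-1K train set, num_training_steps_per_epoch
    # = 1281167 // 64 // 32 = 625 optimizer steps per epoch.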
    total_batch_size = args.batch_size * utils.get_world_size()
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Number of training steps = %d" % num_training_steps_per_epoch)
    print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch))

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    print("Use step level LR & WD scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))

    utils.auto_load_model(
        args=args, model=model, model_without_ddp=model_without_ddp,
        optimizer=optimizer, loss_scaler=loss_scaler)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch)
        train_stats = train_one_epoch(
            model, d_vae, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, log_writer=log_writer,
            start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values,
            wd_schedule_values=wd_schedule_values,
            args=args,
        )
        if args.output_dir:
            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                utils.save_model(
                    args=args, model=model, model_without_ddp=model_without_ddp,
                    optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch,
                    exp_name=args.exp_name)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch, 'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    opts = get_args()
    if opts.output_dir:
        Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
    main(opts)
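# A hedged usage sketch, not taken from this repo's docs: since --dist_url
# defaults to 'env://', the script is typically launched through
# torch.distributed, e.g.
#
#   python -m torch.distributed.launch --nproc_per_node=8 <this_script>.py \
#       --model cae_base_patch16_224 \
#       --discrete_vae_type dall-e --discrete_vae_weight_path /path/to/tokenizer \
#       --data_path /path/to/imagenet/train \
#       --output_dir ./output --batch_size 64 --epochs 300 --warmup_epochs 10
#
# The script name, model name, and paths are illustrative placeholders; see the
# repo's scripts/ directory for the exact launch commands.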