Repository: LMMMEng/OverLoCK
Branch: main
Commit: 2c8ab3b29e3a
Files: 64
Total size: 334.9 KB
Directory structure:
gitextract__a0ogxjt/
├── LICENSE.md
├── README.md
├── detection/
│ ├── configs/
│ │ ├── _base_/
│ │ │ ├── datasets/
│ │ │ │ ├── coco_detection.py
│ │ │ │ └── coco_instance.py
│ │ │ ├── default_runtime.py
│ │ │ ├── models/
│ │ │ │ ├── cascade_mask_rcnn_r50_fpn.py
│ │ │ │ ├── cascade_mask_rcnn_r50_fpn_crowdhuman.py
│ │ │ │ ├── cascade_rcnn_r50_fpn.py
│ │ │ │ ├── fast_rcnn_r50_fpn.py
│ │ │ │ ├── faster_rcnn_r50_caffe_c4.py
│ │ │ │ ├── faster_rcnn_r50_caffe_dc5.py
│ │ │ │ ├── faster_rcnn_r50_fpn.py
│ │ │ │ ├── mask_rcnn_convnext_fpn.py
│ │ │ │ ├── mask_rcnn_r50_caffe_c4.py
│ │ │ │ ├── mask_rcnn_r50_fpn.py
│ │ │ │ ├── retinanet_r50_fpn.py
│ │ │ │ ├── rpn_r50_caffe_c4.py
│ │ │ │ ├── rpn_r50_fpn.py
│ │ │ │ └── ssd300.py
│ │ │ └── schedules/
│ │ │ ├── schedule_1x.py
│ │ │ └── schedule_3x.py
│ │ └── maskrcnn_overlock/
│ │ ├── mask_rcnn_overlock_b_in1k_fpn_1x_coco.py
│ │ ├── mask_rcnn_overlock_b_in1k_fpn_3x_coco.py
│ │ ├── mask_rcnn_overlock_s_in1k_fpn_1x_coco.py
│ │ ├── mask_rcnn_overlock_s_in1k_fpn_3x_coco.py
│ │ ├── mask_rcnn_overlock_t_in1k_fpn_1x_coco.py
│ │ └── mask_rcnn_overlock_t_in1k_fpn_3x_coco.py
│ ├── models/
│ │ ├── __init__.py
│ │ └── overlock.py
│ ├── readme.md
│ ├── scripts/
│ │ ├── dist_test.sh
│ │ └── dist_train.sh
│ ├── test.py
│ └── train.py
├── models/
│ ├── __init__.py
│ ├── contmix.py
│ └── overlock.py
├── scripts/
│ ├── train_b_model.sh
│ ├── train_s_model.sh
│ ├── train_t_model.sh
│ └── train_xt_model.sh
├── segmentation/
│ ├── configs/
│ │ ├── _base_/
│ │ │ ├── datasets/
│ │ │ │ └── ade20k.py
│ │ │ ├── default_runtime.py
│ │ │ ├── models/
│ │ │ │ ├── fpn_r50.py
│ │ │ │ ├── upernet_r50.py
│ │ │ │ └── upernet_transnext.py
│ │ │ └── schedules/
│ │ │ ├── schedule_160k.py
│ │ │ ├── schedule_20k.py
│ │ │ ├── schedule_40k.py
│ │ │ └── schedule_80k.py
│ │ └── overlock/
│ │ ├── upernet_overlock_base_ade20k_8xb2.py
│ │ ├── upernet_overlock_small_ade20k_8xb2.py
│ │ └── upernet_overlock_tiny_ade20k_8xb2.py
│ ├── mmseg_custom/
│ │ ├── __init__.py
│ │ └── align_resize.py
│ ├── models/
│ │ ├── __init__.py
│ │ └── overlock.py
│ ├── readme.md
│ ├── scripts/
│ │ ├── dist_test.sh
│ │ └── dist_train.sh
│ ├── test.py
│ └── train.py
├── train.py
└── validate.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE.md
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# [[CVPR 2025 Oral] OverLoCK: An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels](https://arxiv.org/abs/2502.20087)
This is an official PyTorch implementation of "[**OverLoCK: An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels**](https://arxiv.org/abs/2502.20087)".
# Introduction
Top-down attention plays a crucial role in the human vision system, wherein the brain initially obtains a rough overview of a scene to discover salient cues (i.e., overview first), followed by a more careful finer-grained examination (i.e., look closely next). However, modern ConvNets remain confined to a pyramid structure that successively downsamples the feature map for receptive field expansion, neglecting this crucial biomimetic principle. We present OverLoCK, the first pure ConvNet backbone architecture that explicitly incorporates a top-down attention mechanism. Unlike pyramid backbone networks, our design features a branched architecture with three synergistic sub-networks: 1) a Base-Net that encodes low/mid-level features; 2) a lightweight Overview-Net that generates dynamic top-down attention through coarse global context modeling (i.e., overview first); and 3) a robust Focus-Net that performs finer-grained perception guided by top-down attention (i.e., look closely next). To fully unleash the power of top-down attention, we further propose a novel context-mixing dynamic convolution (ContMix) that effectively models long-range dependencies while preserving inherent local inductive biases even when the input resolution increases, addressing critical limitations in existing convolutions. Our OverLoCK exhibits a notable performance improvement over existing methods.
# News
- **Dec. 25, 2025**: **To improve inference speed and reduce memory consumption**, we provide **reparameterized versions of the OverLoCK models with pre-trained weights**. These variants achieve **identical performance to their original counterparts on ImageNet-1K evaluation**. However, if you further fine-tune these reparameterized models, they may yield slightly lower accuracy compared to the original versions. Please choose the model variant during fine-tuning based on memory and accuracy requirements on your side ([More Details](https://github.com/LMMMEng/OverLoCK/blob/81dd7b216e7aa66ff5a95b07021f299dc2d4d14b/models/overlock.py#L941C13-L941C14)).
- **May. 16, 2025**: A **plug-and-play implementation of the [ContMix Block](models/contmix.py)** is now available.
# Image Classification
## 1. Requirements
We highly suggest using our provided dependencies to ensure reproducibility:
```
# Environments:
cuda==12.1
python==3.10
# Dependencies:
pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121
pip install natten==0.17.1+torch230cu121 -f https://shi-labs.com/natten/wheels/
pip install timm==0.6.12
pip install mmengine==0.2.0
```
>💡 To accelerate training and inference, we utilize the efficient large-kernel convolution proposed in [RepLKNet](https://github.com/DingXiaoH/RepLKNet-pytorch#use-our-efficient-large-kernel-convolution-with-pytorch). Please follow this [**guideline**](https://github.com/VITA-Group/SLaK#installation) to install the ``depthwise_conv2d_implicit_gemm`` function.
>
>💡 If you encounter network issues during the installation of ``natten``, please download this [**package**](https://github.com/LMMMEng/OverLoCK/releases/download/v1/natten-0.17.1+torch230cu121-cp310-cp310-linux_x86_64.whl) and install it locally.
## 2. Data Preparation
Prepare [ImageNet](https://image-net.org/) with the following folder structure, you can extract ImageNet by this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).
```
│imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
```
## 3. Main Results on ImageNet-1K with Pretrained Models
| Models | Input Size | FLOPs (G) | Params (M) | Top-1 (%) | Download |
|:-----------:|:----------:|:---------:|:----------:|:----------:|:----------:|
| OverLoCK-XT | 224x224 | 2.6 | 16 | 82.7 | [model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_xt_in1k_224.pth) |
| OverLoCK-T | 224x224 | 5.5 | 33 | 84.2 | [model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224.pth) |
| OverLoCK-S | 224x224 | 9.7 | 56 | 84.8 | [model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224.pth) |
| OverLoCK-B | 224x224 | 16.7 | 95 | 85.1 | [model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224.pth) |
## 4. Train
To train ```OverLoCK``` models on ImageNet-1K with 8 gpus (single node), run:
```
bash scripts/train_xt_model.sh # train OverLoCK-XT
bash scripts/train_t_model.sh # train OverLoCK-T
bash scripts/train_s_model.sh # train OverLoCK-S
bash scripts/train_b_model.sh # train OverLoCK-B
```
> 💡If you encounter NaN loss, please delete ``--native-amp`` to disable AMP training and resume the checkpoint before the NaN loss occurred.
>
> 💡If your **GPU memory** is insufficient during training, you can enable gradient checkpointing by adding the following arguments: ``--grad-checkpoint --ckpt-stg 4 0 0 0``. If you're still experiencing memory issues, you can increase these values, but be aware that this may slow down training speed.
## 5. Validation
To evaluate ```OverLoCK``` on ImageNet-1K, run:
```
MODEL=overlock_xt # overlock_{xt, t, s, b}
python3 validate.py \
/path/to/imagenet \
--model $MODEL -b 128 \
--pretrained # or --checkpoint /path/to/checkpoint
```
>💡 To accelerate inference speed, OverLoCK utilizes [**Structural Re-parameterization**](https://github.com/AILab-CVC/UniRepLKNet/tree/main). Please refer to [**here**](https://github.com/LMMMEng/OverLoCK/blob/540bf6ed9cca99eab78fc8ab935b71f2a4aa2a2c/models/overlock.py#L945) for a simple usage instruction.
# Citation
If you find this project useful for your research, please consider citing:
```
@inproceedings{lou2025overlock,
title={OverLoCK: An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels},
author={Lou, Meng and Yu, Yizhou},
booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
pages={128--138},
year={2025}
}
```
# Dense Predictions
[Object Detection](detection)
[Semantic Segmentation](segmentation)
# Acknowledgment
Our implementation is mainly based on the following codebases. We gratefully thank the authors for their wonderful works.
> [timm](https://github.com/rwightman/pytorch-image-models), [natten](https://github.com/SHI-Labs/NATTEN), [unireplknet](https://github.com/AILab-CVC/UniRepLKNet), [mmcv](https://github.com/open-mmlab/mmcv), [mmdet](https://github.com/open-mmlab/mmdetection), [mmseg](https://github.com/open-mmlab/mmsegmentation)
# Contact
If you have any questions, please feel free to [create issues❓](https://github.com/LMMMEng/OverLoCK/issues) or [contact me 📧](lmzmm.0921@gmail.com).
================================================
FILE: detection/configs/_base_/datasets/coco_detection.py
================================================
# dataset settings
dataset_type = 'CocoDataset'
data_root = '/grp01/cs_yzyu/dataset/coco2017/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox', classwise=True)
================================================
FILE: detection/configs/_base_/datasets/coco_instance.py
================================================
# dataset settings
dataset_type = 'CocoDataset'
data_root = '/grp01/cs_yzyu/dataset/coco2017/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'], classwise=True)
================================================
FILE: detection/configs/_base_/default_runtime.py
================================================
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
================================================
FILE: detection/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py
================================================
# model settings
model = dict(
type='CascadeRCNN',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
roi_head=dict(
type='CascadeRoIHead',
num_stages=3,
stage_loss_weights=[1, 0.5, 0.25],
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=[
dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
loss_weight=1.0)),
dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.05, 0.05, 0.1, 0.1]),
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
loss_weight=1.0)),
dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.033, 0.033, 0.067, 0.067]),
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
],
mask_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
mask_head=dict(
type='FCNMaskHead',
num_convs=4,
in_channels=256,
conv_out_channels=256,
num_classes=80,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=2000,
max_per_img=2000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=[
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.6,
neg_iou_thr=0.6,
min_pos_iou=0.6,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.7,
min_pos_iou=0.7,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False)
]),
test_cfg=dict(
rpn=dict(
nms_pre=1000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100,
mask_thr_binary=0.5)))
================================================
FILE: detection/configs/_base_/models/cascade_mask_rcnn_r50_fpn_crowdhuman.py
================================================
# model settings
model = dict(
type='CascadeRCNN',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
roi_head=dict(
type='CascadeRoIHead',
num_stages=3,
stage_loss_weights=[1, 0.5, 0.25],
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=[
dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
loss_weight=1.0)),
dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.05, 0.05, 0.1, 0.1]),
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
loss_weight=1.0)),
dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.033, 0.033, 0.067, 0.067]),
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
],),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=2000,
max_per_img=2000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=[
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.6,
neg_iou_thr=0.6,
min_pos_iou=0.6,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.7,
min_pos_iou=0.7,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False)
]),
test_cfg=dict(
rpn=dict(
nms_pre=1000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100,
mask_thr_binary=0.5)))
================================================
FILE: detection/configs/_base_/models/cascade_rcnn_r50_fpn.py
================================================
# model settings
model = dict(
type='CascadeRCNN',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
roi_head=dict(
type='CascadeRoIHead',
num_stages=3,
stage_loss_weights=[1, 0.5, 0.25],
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=[
dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
loss_weight=1.0)),
dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.05, 0.05, 0.1, 0.1]),
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
loss_weight=1.0)),
dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.033, 0.033, 0.067, 0.067]),
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
]),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=2000,
max_per_img=2000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=[
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.6,
neg_iou_thr=0.6,
min_pos_iou=0.6,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.7,
min_pos_iou=0.7,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False)
]),
test_cfg=dict(
rpn=dict(
nms_pre=1000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100)))
================================================
FILE: detection/configs/_base_/models/fast_rcnn_r50_fpn.py
================================================
# model settings
model = dict(
type='FastRCNN',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
roi_head=dict(
type='StandardRoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False)),
test_cfg=dict(
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100)))
================================================
FILE: detection/configs/_base_/models/faster_rcnn_r50_caffe_c4.py
================================================
# model settings
norm_cfg = dict(type='BN', requires_grad=False)
model = dict(
type='FasterRCNN',
backbone=dict(
type='ResNet',
depth=50,
num_stages=3,
strides=(1, 2, 2),
dilations=(1, 1, 1),
out_indices=(2, ),
frozen_stages=1,
norm_cfg=norm_cfg,
norm_eval=True,
style='caffe',
init_cfg=dict(
type='Pretrained',
checkpoint='open-mmlab://detectron2/resnet50_caffe')),
rpn_head=dict(
type='RPNHead',
in_channels=1024,
feat_channels=1024,
anchor_generator=dict(
type='AnchorGenerator',
scales=[2, 4, 8, 16, 32],
ratios=[0.5, 1.0, 2.0],
strides=[16]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(
type='StandardRoIHead',
shared_head=dict(
type='ResLayer',
depth=50,
stage=3,
stride=2,
dilation=1,
style='caffe',
norm_cfg=norm_cfg,
norm_eval=True),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
out_channels=1024,
featmap_strides=[16]),
bbox_head=dict(
type='BBoxHead',
with_avg_pool=True,
roi_feat_size=7,
in_channels=2048,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=12000,
max_per_img=2000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms_pre=6000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100)))
================================================
FILE: detection/configs/_base_/models/faster_rcnn_r50_caffe_dc5.py
================================================
# model settings
norm_cfg = dict(type='BN', requires_grad=False)
model = dict(
type='FasterRCNN',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
strides=(1, 2, 2, 1),
dilations=(1, 1, 1, 2),
out_indices=(3, ),
frozen_stages=1,
norm_cfg=norm_cfg,
norm_eval=True,
style='caffe',
init_cfg=dict(
type='Pretrained',
checkpoint='open-mmlab://detectron2/resnet50_caffe')),
rpn_head=dict(
type='RPNHead',
in_channels=2048,
feat_channels=2048,
anchor_generator=dict(
type='AnchorGenerator',
scales=[2, 4, 8, 16, 32],
ratios=[0.5, 1.0, 2.0],
strides=[16]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(
type='StandardRoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=2048,
featmap_strides=[16]),
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=2048,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=12000,
max_per_img=2000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms=dict(type='nms', iou_threshold=0.7),
nms_pre=6000,
max_per_img=1000,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100)))
================================================
FILE: detection/configs/_base_/models/faster_rcnn_r50_fpn.py
================================================
# model settings
model = dict(
type='FasterRCNN',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(
type='StandardRoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=-1,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=2000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms_pre=1000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100)
# soft-nms is also supported for rcnn testing
# e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
))
================================================
FILE: detection/configs/_base_/models/mask_rcnn_convnext_fpn.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# model settings
model = dict(
type='MaskRCNN',
pretrained=None,
backbone=dict(
type='ConvNeXt',
in_chans=3,
depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
drop_path_rate=0.2,
layer_scale_init_value=1e-6,
out_indices=[0, 1, 2, 3],
),
neck=dict(
type='FPN',
in_channels=[128, 256, 512, 1024],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(
type='StandardRoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
mask_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
mask_head=dict(
type='FCNMaskHead',
num_convs=4,
in_channels=256,
conv_out_channels=256,
num_classes=80,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=-1,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=2000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms_pre=1000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100,
mask_thr_binary=0.5)))
================================================
FILE: detection/configs/_base_/models/mask_rcnn_r50_caffe_c4.py
================================================
# model settings
norm_cfg = dict(type='BN', requires_grad=False)
model = dict(
type='MaskRCNN',
backbone=dict(
type='ResNet',
depth=50,
num_stages=3,
strides=(1, 2, 2),
dilations=(1, 1, 1),
out_indices=(2, ),
frozen_stages=1,
norm_cfg=norm_cfg,
norm_eval=True,
style='caffe',
init_cfg=dict(
type='Pretrained',
checkpoint='open-mmlab://detectron2/resnet50_caffe')),
rpn_head=dict(
type='RPNHead',
in_channels=1024,
feat_channels=1024,
anchor_generator=dict(
type='AnchorGenerator',
scales=[2, 4, 8, 16, 32],
ratios=[0.5, 1.0, 2.0],
strides=[16]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(
type='StandardRoIHead',
shared_head=dict(
type='ResLayer',
depth=50,
stage=3,
stride=2,
dilation=1,
style='caffe',
norm_cfg=norm_cfg,
norm_eval=True),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
out_channels=1024,
featmap_strides=[16]),
bbox_head=dict(
type='BBoxHead',
with_avg_pool=True,
roi_feat_size=7,
in_channels=2048,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
mask_roi_extractor=None,
mask_head=dict(
type='FCNMaskHead',
num_convs=0,
in_channels=2048,
conv_out_channels=256,
num_classes=80,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=12000,
max_per_img=2000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=14,
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms_pre=6000,
nms=dict(type='nms', iou_threshold=0.7),
max_per_img=1000,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100,
mask_thr_binary=0.5)))
================================================
FILE: detection/configs/_base_/models/mask_rcnn_r50_fpn.py
================================================
# model settings
model = dict(
type='MaskRCNN',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(
type='StandardRoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
mask_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
mask_head=dict(
type='FCNMaskHead',
num_convs=4,
in_channels=256,
conv_out_channels=256,
num_classes=80,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=-1,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=2000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms_pre=1000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100,
mask_thr_binary=0.5)))
================================================
FILE: detection/configs/_base_/models/retinanet_r50_fpn.py
================================================
# model settings
model = dict(
type='RetinaNet',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_input',
num_outs=5),
bbox_head=dict(
type='RetinaHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
================================================
FILE: detection/configs/_base_/models/rpn_r50_caffe_c4.py
================================================
# model settings
model = dict(
type='RPN',
backbone=dict(
type='ResNet',
depth=50,
num_stages=3,
strides=(1, 2, 2),
dilations=(1, 1, 1),
out_indices=(2, ),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='caffe',
init_cfg=dict(
type='Pretrained',
checkpoint='open-mmlab://detectron2/resnet50_caffe')),
neck=None,
rpn_head=dict(
type='RPNHead',
in_channels=1024,
feat_channels=1024,
anchor_generator=dict(
type='AnchorGenerator',
scales=[2, 4, 8, 16, 32],
ratios=[0.5, 1.0, 2.0],
strides=[16]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms_pre=12000,
max_per_img=2000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0)))
================================================
FILE: detection/configs/_base_/models/rpn_r50_fpn.py
================================================
# model settings
model = dict(
type='RPN',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms_pre=2000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0)))
================================================
FILE: detection/configs/_base_/models/ssd300.py
================================================
# model settings
input_size = 300
model = dict(
type='SingleStageDetector',
backbone=dict(
type='SSDVGG',
depth=16,
with_last_pool=False,
ceil_mode=True,
out_indices=(3, 4),
out_feature_indices=(22, 34),
init_cfg=dict(
type='Pretrained', checkpoint='open-mmlab://vgg16_caffe')),
neck=dict(
type='SSDNeck',
in_channels=(512, 1024),
out_channels=(512, 1024, 512, 256, 256, 256),
level_strides=(2, 2, 1, 1),
level_paddings=(1, 1, 0, 0),
l2_norm_scale=20),
bbox_head=dict(
type='SSDHead',
in_channels=(512, 1024, 512, 256, 256, 256),
num_classes=80,
anchor_generator=dict(
type='SSDAnchorGenerator',
scale_major=False,
input_size=input_size,
basesize_ratio_range=(0.15, 0.9),
strides=[8, 16, 32, 64, 100, 300],
ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2])),
# model training and testing settings
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.,
ignore_iof_thr=-1,
gt_max_assign_all=False),
smoothl1_beta=1.,
allowed_border=-1,
pos_weight=-1,
neg_pos_ratio=3,
debug=False),
test_cfg=dict(
nms_pre=1000,
nms=dict(type='nms', iou_threshold=0.45),
min_bbox_size=0,
score_thr=0.02,
max_per_img=200))
cudnn_benchmark = True
================================================
FILE: detection/configs/_base_/schedules/schedule_1x.py
================================================
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
================================================
FILE: detection/configs/_base_/schedules/schedule_3x.py
================================================
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)
================================================
FILE: detection/configs/maskrcnn_overlock/mask_rcnn_overlock_b_in1k_fpn_1x_coco.py
================================================
_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_1x.py',
'../_base_/default_runtime.py'
]
dims = [80, 160, 528, 720]
model = dict(
backbone=dict(
_delete_=True,
type='overlock_b',
pretrained=True,
drop_path_rate=0.6
),
neck=dict(
type='FPN',
in_channels=dims,
out_channels=256,
num_outs=5))
###########################################################################################################
# https://github.com/Sense-X/UniFormer/blob/main/object_detection/exp/mask_rcnn_1x_hybrid_small/config.py
# We follow uniformer's optimizer and lr schedule
optimizer = dict(_delete_=True, type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
evaluation = dict(save_best='auto')
checkpoint_config = dict(interval=1, max_keep_ckpts=1, save_last=True)
================================================
FILE: detection/configs/maskrcnn_overlock/mask_rcnn_overlock_b_in1k_fpn_3x_coco.py
================================================
_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_3x.py',
'../_base_/default_runtime.py'
]
dims = [80, 160, 528, 720]
model = dict(
backbone=dict(
_delete_=True,
type='overlock_b',
pretrained=True,
drop_path_rate=0.6
),
neck=dict(
type='FPN',
in_channels=dims,
out_channels=256,
num_outs=5))
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='AutoAugment',
policies=[
[
dict(type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
multiscale_mode='value',
keep_ratio=True)
],
[
dict(type='Resize',
img_scale=[(400, 1333), (500, 1333), (600, 1333)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333),
(576, 1333), (608, 1333), (640, 1333),
(672, 1333), (704, 1333), (736, 1333),
(768, 1333), (800, 1333)],
multiscale_mode='value',
override=True,
keep_ratio=True)
]
]),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
# We use 8 GPUs to train this model so that the total batch size was 16
data = dict(samples_per_gpu=2, train=dict(pipeline=train_pipeline))
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
evaluation = dict(save_best='auto')
checkpoint_config = dict(interval=1, max_keep_ckpts=1, save_last=True)
# # AMP (faster but may meet nan loss) ->
# fp16 = dict(loss_scale='dynamic')
================================================
FILE: detection/configs/maskrcnn_overlock/mask_rcnn_overlock_s_in1k_fpn_1x_coco.py
================================================
_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_1x.py',
'../_base_/default_runtime.py'
]
dims = [64, 128, 448, 640]
model = dict(
backbone=dict(
_delete_=True,
type='overlock_s',
pretrained=True,
drop_path_rate=0.4
),
neck=dict(
type='FPN',
in_channels=dims,
out_channels=256,
num_outs=5))
###########################################################################################################
# https://github.com/Sense-X/UniFormer/blob/main/object_detection/exp/mask_rcnn_1x_hybrid_small/config.py
# We follow uniformer's optimizer and lr schedule
optimizer = dict(_delete_=True, type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
evaluation = dict(save_best='auto')
checkpoint_config = dict(interval=1, max_keep_ckpts=1, save_last=True)
================================================
FILE: detection/configs/maskrcnn_overlock/mask_rcnn_overlock_s_in1k_fpn_3x_coco.py
================================================
_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_3x.py',
'../_base_/default_runtime.py'
]
dims = [64, 128, 448, 640]
model = dict(
backbone=dict(
_delete_=True,
type='overlock_s',
pretrained=True,
drop_path_rate=0.4
),
neck=dict(
type='FPN',
in_channels=dims,
out_channels=256,
num_outs=5))
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='AutoAugment',
policies=[
[
dict(type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
multiscale_mode='value',
keep_ratio=True)
],
[
dict(type='Resize',
img_scale=[(400, 1333), (500, 1333), (600, 1333)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333),
(576, 1333), (608, 1333), (640, 1333),
(672, 1333), (704, 1333), (736, 1333),
(768, 1333), (800, 1333)],
multiscale_mode='value',
override=True,
keep_ratio=True)
]
]),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
# We use 8 GPUs to train this model so that the total batch size was 16
data = dict(samples_per_gpu=2, train=dict(pipeline=train_pipeline))
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
evaluation = dict(save_best='auto')
checkpoint_config = dict(interval=1, max_keep_ckpts=1, save_last=True)
# # AMP (faster but may meet nan loss) ->
# fp16 = dict(loss_scale='dynamic')
================================================
FILE: detection/configs/maskrcnn_overlock/mask_rcnn_overlock_t_in1k_fpn_1x_coco.py
================================================
_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_1x.py',
'../_base_/default_runtime.py'
]
dims = [64, 128, 384, 640]
model = dict(
backbone=dict(
_delete_=True,
type='overlock_t',
pretrained=True,
drop_path_rate=0.2
),
neck=dict(
type='FPN',
in_channels=dims,
out_channels=256,
num_outs=5))
###########################################################################################################
# https://github.com/Sense-X/UniFormer/blob/main/object_detection/exp/mask_rcnn_1x_hybrid_small/config.py
# We follow uniformer's optimizer and lr schedule
optimizer = dict(_delete_=True, type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
evaluation = dict(save_best='auto')
checkpoint_config = dict(interval=1, max_keep_ckpts=1, save_last=True)
================================================
FILE: detection/configs/maskrcnn_overlock/mask_rcnn_overlock_t_in1k_fpn_3x_coco.py
================================================
_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_3x.py',
'../_base_/default_runtime.py'
]
dims = [64, 128, 384, 640]
model = dict(
backbone=dict(
_delete_=True,
type='overlock_t',
pretrained=True,
drop_path_rate=0.2
),
neck=dict(
type='FPN',
in_channels=dims,
out_channels=256,
num_outs=5))
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='AutoAugment',
policies=[
[
dict(type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
multiscale_mode='value',
keep_ratio=True)
],
[
dict(type='Resize',
img_scale=[(400, 1333), (500, 1333), (600, 1333)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333),
(576, 1333), (608, 1333), (640, 1333),
(672, 1333), (704, 1333), (736, 1333),
(768, 1333), (800, 1333)],
multiscale_mode='value',
override=True,
keep_ratio=True)
]
]),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
# We use 8 GPUs to train this model so that the total batch size was 16
data = dict(samples_per_gpu=2, train=dict(pipeline=train_pipeline))
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
evaluation = dict(save_best='auto')
checkpoint_config = dict(interval=1, max_keep_ckpts=1, save_last=True)
# # AMP (faster but may meet nan loss) ->
# fp16 = dict(loss_scale='dynamic')
================================================
FILE: detection/models/__init__.py
================================================
from .overlock import *
================================================
FILE: detection/models/overlock.py
================================================
'''
This is an official implementation of OverLoCK model proposed in the paper:
https://arxiv.org/abs/2502.20087
'''
import torch
import timm
import torch.distributed
import torch.nn.functional as F
from torch import nn
from einops import rearrange, einsum
from natten.functional import na2d_av
from torch.utils.checkpoint import checkpoint
from timm.models.layers import DropPath, to_2tuple
from timm.models.registry import register_model
from mmdet.models.builder import MODELS
from mmdet.utils import get_root_logger
try:
from mmcv.runner import load_checkpoint
except:
from mmengine.runner import load_checkpoint
def get_conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
attempt_use_lk_impl=True):
kernel_size = to_2tuple(kernel_size)
if padding is None:
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
else:
padding = to_2tuple(padding)
need_large_impl = kernel_size[0] == kernel_size[1] and kernel_size[0] > 5 and padding == (kernel_size[0] // 2, kernel_size[1] // 2)
if attempt_use_lk_impl and need_large_impl:
print('---------------- trying to import iGEMM implementation for large-kernel conv')
try:
from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
print('---------------- found iGEMM implementation ')
except:
DepthWiseConv2dImplicitGEMM = None
print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')
if DepthWiseConv2dImplicitGEMM is not None and need_large_impl and in_channels == out_channels \
and out_channels == groups and stride == 1 and dilation == 1:
print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')
return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
def get_bn(dim, use_sync_bn=False):
if use_sync_bn:
return nn.SyncBatchNorm(dim)
else:
return nn.BatchNorm2d(dim)
def fuse_bn(conv, bn):
conv_bias = 0 if conv.bias is None else conv.bias
std = (bn.running_var + bn.eps).sqrt()
return conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1), bn.bias + (conv_bias - bn.running_mean) * bn.weight / std
def convert_dilated_to_nondilated(kernel, dilate_rate):
identity_kernel = torch.ones((1, 1, 1, 1)).to(kernel.device)
if kernel.size(1) == 1:
# This is a DW kernel
dilated = F.conv_transpose2d(kernel, identity_kernel, stride=dilate_rate)
return dilated
else:
# This is a dense or group-wise (but not DW) kernel
slices = []
for i in range(kernel.size(1)):
dilated = F.conv_transpose2d(kernel[:,i:i+1,:,:], identity_kernel, stride=dilate_rate)
slices.append(dilated)
return torch.cat(slices, dim=1)
def merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):
large_k = large_kernel.size(2)
dilated_k = dilated_kernel.size(2)
equivalent_kernel_size = dilated_r * (dilated_k - 1) + 1
equivalent_kernel = convert_dilated_to_nondilated(dilated_kernel, dilated_r)
rows_to_pad = large_k // 2 - equivalent_kernel_size // 2
merged_kernel = large_kernel + F.pad(equivalent_kernel, [rows_to_pad] * 4)
return merged_kernel
def stem(in_chans=3, embed_dim=96):
return nn.Sequential(
nn.Conv2d(in_chans, embed_dim//2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(embed_dim//2),
nn.GELU(),
nn.Conv2d(embed_dim//2, embed_dim//2, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(embed_dim//2),
nn.GELU(),
nn.Conv2d(embed_dim//2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(embed_dim),
nn.GELU(),
nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(embed_dim)
)
def downsample(in_dim, out_dim):
return nn.Sequential(
nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(out_dim),
)
class SEModule(nn.Module):
def __init__(self, dim, red=8, inner_act=nn.GELU, out_act=nn.Sigmoid):
super().__init__()
inner_dim = max(16, dim // red)
self.proj = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(dim, inner_dim, kernel_size=1),
inner_act(),
nn.Conv2d(inner_dim, dim, kernel_size=1),
out_act(),
)
def forward(self, x):
x = x * self.proj(x)
return x
class LayerScale(nn.Module):
def __init__(self, dim, init_value=1e-5):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim, 1, 1, 1)*init_value,
requires_grad=True)
self.bias = nn.Parameter(torch.zeros(dim), requires_grad=True)
def forward(self, x):
x = F.conv2d(x, weight=self.weight, bias=self.bias, groups=x.shape[1])
return x
class LayerNorm2d(nn.LayerNorm):
def __init__(self, dim):
super().__init__(normalized_shape=dim, eps=1e-6)
def forward(self, x):
x = rearrange(x, 'b c h w -> b h w c')
x = super().forward(x)
x = rearrange(x, 'b h w c -> b c h w')
return x.contiguous()
class GRN(nn.Module):
""" GRN (Global Response Normalization) layer
Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808)
This implementation is more efficient than the original (https://github.com/facebookresearch/ConvNeXt-V2)
We assume the inputs to this layer are (N, C, H, W)
"""
def __init__(self, dim, use_bias=True):
super().__init__()
self.use_bias = use_bias
self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
if self.use_bias:
self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(-1, -2), keepdim=True)
Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
if self.use_bias:
return (self.gamma * Nx + 1) * x + self.beta
else:
return (self.gamma * Nx + 1) * x
class DilatedReparamBlock(nn.Module):
"""
Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet)
We assume the inputs to this block are (N, C, H, W)
"""
def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True):
super().__init__()
self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,
padding=kernel_size//2, dilation=1, groups=channels, bias=deploy,
attempt_use_lk_impl=attempt_use_lk_impl)
self.attempt_use_lk_impl = attempt_use_lk_impl
# Default settings. We did not tune them carefully. Different settings may work better.
if kernel_size == 19:
self.kernel_sizes = [5, 7, 9, 9, 3, 3, 3]
self.dilates = [1, 1, 1, 2, 4, 5, 7]
elif kernel_size == 17:
self.kernel_sizes = [5, 7, 9, 3, 3, 3]
self.dilates = [1, 1, 2, 4, 5, 7]
elif kernel_size == 15:
self.kernel_sizes = [5, 7, 7, 3, 3, 3]
self.dilates = [1, 1, 2, 3, 5, 7]
elif kernel_size == 13:
self.kernel_sizes = [5, 7, 7, 3, 3, 3]
self.dilates = [1, 1, 2, 3, 4, 5]
elif kernel_size == 11:
self.kernel_sizes = [5, 7, 5, 3, 3, 3]
self.dilates = [1, 1, 2, 3, 4, 5]
elif kernel_size == 9:
self.kernel_sizes = [5, 7, 5, 3, 3]
self.dilates = [1, 1, 2, 3, 4]
elif kernel_size == 7:
self.kernel_sizes = [5, 3, 3, 3]
self.dilates = [1, 1, 2, 3]
elif kernel_size == 5:
self.kernel_sizes = [3, 3]
self.dilates = [1, 2]
else:
raise ValueError('Dilated Reparam Block requires kernel_size >= 5')
if not deploy:
self.origin_bn = get_bn(channels, use_sync_bn)
for k, r in zip(self.kernel_sizes, self.dilates):
self.__setattr__('dil_conv_k{}_{}'.format(k, r),
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,
padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,
bias=False))
self.__setattr__('dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn))
def forward(self, x):
if not hasattr(self, 'origin_bn'): # deploy mode
return self.lk_origin(x)
out = self.origin_bn(self.lk_origin(x))
for k, r in zip(self.kernel_sizes, self.dilates):
conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))
bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))
out = out + bn(conv(x))
return out
def merge_dilated_branches(self):
if hasattr(self, 'origin_bn'):
origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn)
for k, r in zip(self.kernel_sizes, self.dilates):
conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))
bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))
branch_k, branch_b = fuse_bn(conv, bn)
origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r)
origin_b += branch_b
merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1,
padding=origin_k.size(2)//2, dilation=1, groups=origin_k.size(0), bias=True,
attempt_use_lk_impl=self.attempt_use_lk_impl)
merged_conv.weight.data = origin_k
merged_conv.bias.data = origin_b
self.lk_origin = merged_conv
self.__delattr__('origin_bn')
for k, r in zip(self.kernel_sizes, self.dilates):
self.__delattr__('dil_conv_k{}_{}'.format(k, r))
self.__delattr__('dil_bn_k{}_{}'.format(k, r))
class CTXDownsample(nn.Module):
def __init__(self, dim, h_dim):
super().__init__()
self.x_proj = nn.Sequential(
nn.Conv2d(dim, h_dim, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(h_dim)
)
self.h_proj = nn.Sequential(
nn.Conv2d(h_dim//4, h_dim//4, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(h_dim//4)
)
def forward(self, x, ctx):
x = self.x_proj(x)
ctx = self.h_proj(ctx)
return (x, ctx)
class ResDWConv(nn.Conv2d):
'''
Depthwise convolution with residual connection
'''
def __init__(self, dim, kernel_size=3):
super().__init__(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=dim)
def forward(self, x):
x = x + super().forward(x)
return x
class RepConvBlock(nn.Module):
def __init__(self,
dim=64,
kernel_size=7,
mlp_ratio=4,
ls_init_value=None,
res_scale=False,
drop_path=0,
norm_layer=LayerNorm2d,
use_gemm=False,
deploy=False,
use_checkpoint=False):
super().__init__()
self.res_scale = res_scale
self.use_checkpoint = use_checkpoint
mlp_dim = int(dim*mlp_ratio)
self.dwconv = ResDWConv(dim, kernel_size=3)
self.proj = nn.Sequential(
norm_layer(dim),
DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),
nn.BatchNorm2d(dim),
SEModule(dim),
nn.Conv2d(dim, mlp_dim, kernel_size=1),
nn.GELU(),
ResDWConv(mlp_dim, kernel_size=3),
GRN(mlp_dim),
nn.Conv2d(mlp_dim, dim, kernel_size=1),
DropPath(drop_path) if drop_path > 0 else nn.Identity(),
)
self.ls = LayerScale(dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()
def forward_features(self, x):
x = self.dwconv(x)
if self.res_scale:
x = self.ls(x) + self.proj(x)
else:
drop_path = self.proj[-1]
x = x + drop_path(self.ls(self.proj[:-1](x)))
return x
def forward(self, x):
if self.use_checkpoint and x.requires_grad:
x = checkpoint(self.forward_features, x, use_reentrant=False)
else:
x = self.forward_features(x)
return x
class DynamicConvBlock(nn.Module):
def __init__(self,
dim=64,
ctx_dim=32,
kernel_size=7,
smk_size=5,
num_heads=2,
mlp_ratio=4,
ls_init_value=None,
res_scale=False,
drop_path=0,
norm_layer=LayerNorm2d,
is_first=False,
is_last=False,
use_gemm=False,
deploy=False,
use_checkpoint=False,
**kwargs):
super().__init__()
ctx_dim = ctx_dim // 4
out_dim = dim + ctx_dim
mlp_dim = int(dim*mlp_ratio)
self.kernel_size = kernel_size
self.res_scale = res_scale
self.use_gemm = use_gemm
self.smk_size = smk_size
self.num_heads = num_heads * 2
head_dim = dim // self.num_heads
self.scale = head_dim ** -0.5
self.is_first = is_first
self.is_last = is_last
self.use_checkpoint = use_checkpoint
if not is_first:
self.x_scale = LayerScale(ctx_dim, init_value=1)
self.h_scale = LayerScale(ctx_dim, init_value=1)
self.dwconv1 = ResDWConv(out_dim, kernel_size=3)
self.norm1 = norm_layer(out_dim)
self.fusion = nn.Sequential(
nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, groups=out_dim),
nn.BatchNorm2d(out_dim),
nn.GELU(),
nn.Conv2d(out_dim, dim, kernel_size=1),
GRN(dim),
)
self.weight_query = nn.Sequential(
nn.Conv2d(dim, dim//2, kernel_size=1, bias=False),
nn.BatchNorm2d(dim//2),
)
self.weight_key = nn.Sequential(
nn.AdaptiveAvgPool2d(7),
nn.Conv2d(ctx_dim, dim//2, kernel_size=1, bias=False),
nn.BatchNorm2d(dim//2),
)
self.weight_proj = nn.Conv2d(49, kernel_size**2 + smk_size**2, kernel_size=1)
self.dyconv_proj = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=1, bias=False),
nn.BatchNorm2d(dim),
)
self.lepe = nn.Sequential(
DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),
nn.BatchNorm2d(dim),
)
self.se_layer = SEModule(dim)
self.gate = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=1, bias=False),
nn.BatchNorm2d(dim),
nn.SiLU(),
)
self.proj = nn.Sequential(
nn.BatchNorm2d(dim),
nn.Conv2d(dim, out_dim, kernel_size=1),
)
self.dwconv2 = ResDWConv(out_dim, kernel_size=3)
self.norm2 = norm_layer(out_dim)
self.mlp = nn.Sequential(
nn.Conv2d(out_dim, mlp_dim, kernel_size=1),
nn.GELU(),
ResDWConv(mlp_dim, kernel_size=3),
GRN(mlp_dim),
nn.Conv2d(mlp_dim, out_dim, kernel_size=1),
)
self.ls1 = LayerScale(out_dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()
self.ls2 = LayerScale(out_dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
self.get_rpb()
def get_rpb(self):
self.rpb_size1 = 2 * self.smk_size - 1
self.rpb1 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size1, self.rpb_size1))
self.rpb_size2 = 2 * self.kernel_size - 1
self.rpb2 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size2, self.rpb_size2))
nn.init.zeros_(self.rpb1)
nn.init.zeros_(self.rpb2)
@torch.no_grad()
def generate_idx(self, kernel_size):
rpb_size = 2 * kernel_size - 1
idx_h = torch.arange(0, kernel_size)
idx_w = torch.arange(0, kernel_size)
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).view(-1)
return (idx_h, idx_w, idx_k)
def apply_rpb(self, attn, rpb, height, width, kernel_size, idx_h, idx_w, idx_k):
"""
RPB implementation directly borrowed from https://tinyurl.com/mrbub4t3
"""
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size//2] = height - (kernel_size-1)
num_repeat_w[kernel_size//2] = width - (kernel_size-1)
bias_hw = (idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2*kernel_size-1)) + idx_w.repeat_interleave(num_repeat_w)
bias_idx = bias_hw.unsqueeze(-1) + idx_k
bias_idx = bias_idx.reshape(-1, int(kernel_size**2))
bias_idx = torch.flip(bias_idx, [0])
rpb = torch.flatten(rpb, 1, 2)[:, bias_idx]
rpb = rpb.reshape(1, int(self.num_heads), int(height), int(width), int(kernel_size**2))
return attn + rpb
def _forward_inner(self, x, h_x, h_r):
input_resoltion = x.shape[2:]
B, C, H, W = x.shape
B, C_h, H_h, W_h = h_x.shape
if not self.is_first:
h_x = self.x_scale(h_x) + self.h_scale(h_r)
x_f = torch.cat([x, h_x], dim=1)
x_f = self.dwconv1(x_f)
identity = x_f
x_f = self.norm1(x_f)
x = self.fusion(x_f)
gate = self.gate(x)
lepe = self.lepe(x)
is_pad = False
if min(H, W) < self.kernel_size:
is_pad = True
if H < W:
size = (self.kernel_size, int(self.kernel_size / H * W))
else:
size = (int(self.kernel_size / W * H), self.kernel_size)
x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
x_f = F.interpolate(x_f, size=size, mode='bilinear', align_corners=False)
H, W = size
query, key = torch.split(x_f, split_size_or_sections=[C, C_h], dim=1)
query = self.weight_query(query) * self.scale
key = self.weight_key(key)
query = rearrange(query, 'b (g c) h w -> b g c (h w)', g=self.num_heads)
key = rearrange(key, 'b (g c) h w -> b g c (h w)', g=self.num_heads)
weight = einsum(query, key, 'b g c n, b g c l -> b g n l')
weight = rearrange(weight, 'b g n l -> b l g n').contiguous()
weight = self.weight_proj(weight)
weight = rearrange(weight, 'b l g (h w) -> b g h w l', h=H, w=W)
attn1, attn2 = torch.split(weight, split_size_or_sections=[self.smk_size**2, self.kernel_size**2], dim=-1)
rpb1_idx = self.generate_idx(self.smk_size)
rpb2_idx = self.generate_idx(self.kernel_size)
attn1 = self.apply_rpb(attn1, self.rpb1, H, W, self.smk_size, *rpb1_idx)
attn2 = self.apply_rpb(attn2, self.rpb2, H, W, self.kernel_size, *rpb2_idx)
attn1 = torch.softmax(attn1, dim=-1)
attn2 = torch.softmax(attn2, dim=-1)
value = rearrange(x, 'b (m g c) h w -> m b g h w c', m=2, g=self.num_heads)
x1 = na2d_av(attn1, value[0], kernel_size=self.smk_size)
x2 = na2d_av(attn2, value[1], kernel_size=self.kernel_size)
x = torch.cat([x1, x2], dim=1)
x = rearrange(x, 'b g h w c -> b (g c) h w', h=H, w=W)
if is_pad:
x = F.adaptive_avg_pool2d(x, input_resoltion)
x = self.dyconv_proj(x)
x = x + lepe
x = self.se_layer(x)
x = gate * x
x = self.proj(x)
if self.res_scale:
x = self.ls1(identity) + self.drop_path(x)
else:
x = identity + self.drop_path(self.ls1(x))
x = self.dwconv2(x)
if self.res_scale:
x = self.ls2(x) + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
if self.is_last:
return (x, None)
else:
l_x, h_x = torch.split(x, split_size_or_sections=[C, C_h], dim=1)
return (l_x, h_x)
def forward(self, x, h_x, h_r):
if self.use_checkpoint and x.requires_grad:
x = checkpoint(self._forward_inner, x, h_x, h_r, use_reentrant=False)
else:
x = self._forward_inner(x, h_x, h_r)
return x
class OverLoCK(nn.Module):
'''
An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels
https://arxiv.org/abs/2502.20087
'''
def __init__(self,
depth=[2, 2, 2, 2],
sub_depth=[4, 2],
in_chans=3,
embed_dim=[96, 192, 384, 768],
kernel_size=[7, 7, 7, 7],
mlp_ratio=[4, 4, 4, 4],
sub_mlp_ratio=[4, 4],
sub_num_heads=[4, 8],
ls_init_value=[None, None, 1, 1],
res_scale=True,
smk_size=5,
deploy=False,
use_gemm=True,
use_ds=True,
drop_rate=0,
drop_path_rate=0,
norm_layer=LayerNorm2d,
projection=1024,
num_classes=1000,
use_checkpoint=[0, 0, 0, 0],
):
super().__init__()
fusion_dim = embed_dim[-1] + embed_dim[-1]//4
# self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.patch_embed1 = stem(in_chans, embed_dim[0])
self.patch_embed2 = downsample(embed_dim[0], embed_dim[1])
self.patch_embed3 = downsample(embed_dim[1], embed_dim[2])
self.patch_embed4 = downsample(embed_dim[2], embed_dim[3])
self.high_level_proj = nn.Conv2d(embed_dim[-1], embed_dim[-1]//4, kernel_size=1)
self.patch_embedx = CTXDownsample(embed_dim[2], embed_dim[3])
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth) + sum(sub_depth))]
self.blocks1 = nn.ModuleList()
self.blocks2 = nn.ModuleList()
self.blocks3 = nn.ModuleList()
self.blocks4 = nn.ModuleList()
self.sub_blocks3 = nn.ModuleList()
self.sub_blocks4 = nn.ModuleList()
for i in range(depth[0]):
self.blocks1.append(
RepConvBlock(
dim=embed_dim[0],
kernel_size=kernel_size[0],
mlp_ratio=mlp_ratio[0],
ls_init_value=ls_init_value[0],
res_scale=res_scale,
drop_path=dpr[i],
norm_layer=norm_layer,
use_gemm=use_gemm,
deploy=deploy,
use_checkpoint=(i 0 else nn.Identity()
)
# Main Cls Head
self.head = nn.Sequential(
nn.Conv2d(fusion_dim, projection, kernel_size=1, bias=False),
nn.BatchNorm2d(projection),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(projection, num_classes, kernel_size=1) if num_classes > 0 else nn.Identity()
)
self.extra_norm = nn.ModuleList()
for idx in range(4):
dim = embed_dim[idx]
if idx >= 2:
dim = dim + embed_dim[-1]//4
self.extra_norm.append(norm_layer(dim))
self.extra_norm.append(norm_layer(embed_dim[-1]))
del self.aux_head
del self.head
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d)):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.BatchNorm1d)):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0)
def _convert_sync_batchnorm(self):
if torch.distributed.is_initialized():
self = nn.SyncBatchNorm.convert_sync_batchnorm(self)
def forward_pre_features(self, x):
outs = []
x = self.patch_embed1(x)
for blk in self.blocks1:
x = blk(x)
outs.append(self.extra_norm[0](x))
x = self.patch_embed2(x)
for blk in self.blocks2:
x = blk(x)
outs.append(self.extra_norm[1](x))
return outs
def forward_base_features(self, x):
x = self.patch_embed3(x)
for blk in self.blocks3:
x = blk(x)
ctx = self.patch_embed4(x)
for blk in self.blocks4:
ctx = blk(ctx)
return (x, ctx)
def forward_sub_features(self, x, ctx):
outs = []
ctx_cls = ctx
ctx_ori = self.high_level_proj(ctx)
ctx_up = F.interpolate(ctx_ori, size=x.shape[2:], mode='bilinear', align_corners=False)
for idx, blk in enumerate(self.sub_blocks3):
if idx == 0:
ctx = ctx_up
x, ctx = blk(x, ctx, ctx_up)
outs.append(self.extra_norm[2](torch.cat([x, ctx], dim=1)))
x, ctx = self.patch_embedx(x, ctx)
for idx, blk in enumerate(self.sub_blocks4):
x, ctx = blk(x, ctx, ctx_ori)
ctx = self.extra_norm[-1](ctx_cls)
x = self.extra_norm[3](x) + self.h_proj(ctx)
outs.append(x)
return outs
def forward_features(self, x):
x0, x1 = self.forward_pre_features(x)
x, ctx = self.forward_base_features(x1)
x2, x3 = self.forward_sub_features(x, ctx)
return (x0, x1, x2, x3)
def forward(self, x):
x = self.forward_features(x)
return x
@MODELS.register_module()
def overlock_xt(pretrained=False, pretrained_cfg=None, **kwargs):
model = OverLoCK(
depth=[2, 2, 3, 2],
sub_depth=[6, 2],
embed_dim=[56, 112, 256, 336],
kernel_size=[17, 15, 13, 7],
mlp_ratio=[4, 4, 4, 4],
sub_num_heads=[4, 6],
sub_mlp_ratio=[3, 3],
**kwargs
)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_xt_in1k_224.pth'
logger = get_root_logger()
load_checkpoint(model, pretrained, logger=logger)
model._convert_sync_batchnorm()
return model
@MODELS.register_module()
def overlock_t(pretrained=False, pretrained_cfg=None, **kwargs):
model = OverLoCK(
depth=[4, 4, 6, 2],
sub_depth=[12, 2],
embed_dim=[64, 128, 256, 512],
kernel_size=[17, 15, 13, 7],
mlp_ratio=[4, 4, 4, 4],
sub_num_heads=[4, 8],
sub_mlp_ratio=[3, 3],
**kwargs
)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224.pth'
logger = get_root_logger()
load_checkpoint(model, pretrained, logger=logger)
model._convert_sync_batchnorm()
return model
@MODELS.register_module()
def overlock_s(pretrained=False, pretrained_cfg=None, **kwargs):
model = OverLoCK(
depth=[6, 6, 8, 3],
sub_depth=[16, 3],
embed_dim=[64, 128, 320, 512],
kernel_size=[17, 15, 13, 7],
mlp_ratio=[4, 4, 4, 4],
sub_num_heads=[8, 16],
sub_mlp_ratio=[3, 3],
**kwargs
)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224.pth'
logger = get_root_logger()
load_checkpoint(model, pretrained, logger=logger)
model._convert_sync_batchnorm()
return model
@MODELS.register_module()
def overlock_b(pretrained=None, pretrained_cfg=None, **kwargs):
model = OverLoCK(
depth=[8, 8, 10, 4],
sub_depth=[20, 4],
embed_dim=[80, 160, 384, 576],
kernel_size=[17, 15, 13, 7],
mlp_ratio=[4, 4, 4, 4],
sub_num_heads=[6, 9],
sub_mlp_ratio=[3, 3],
**kwargs
)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224.pth'
logger = get_root_logger()
load_checkpoint(model, pretrained, logger=logger)
model._convert_sync_batchnorm()
return model
================================================
FILE: detection/readme.md
================================================
# Applying OverLoCK to Object Detection and Instance Segmentation
## 1. Requirements
```
pip install mmcv-full==1.7.2 --no-cache-dir
pip install mmdet==2.28.2 --no-cache-dir
```
💡 To enable torch>=2.1.0 to support mmcv 1.7.2, you need to make the following changes:
> 1️⃣ https://goo.su/XhU5vWr
> 2️⃣ https://goo.su/ogm4yO
## 2. Data Preparation
Prepare COCO 2017 according to the [guidelines](https://github.com/open-mmlab/mmdetection/blob/2.x/docs/en/1_exist_data_model.md).
## 3. Main Results on COCO using Mask R-CNN framework
| Backbone | Pretrain | Schedule | AP_b | AP_m | Config | Download |
|:-------------:|:-----------:|:--------:|--------|:-------:|:------:|:----------:|
| OverLoCK-T | [ImageNet-1K](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224.pth)| 1x | 48.3 |43.3 |[config](configs/maskrcnn_overlock/mask_rcnn_overlock_t_in1k_fpn_1x_coco.py) |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/maskrcnn1x_overlock_tiny_coco.pth) |
| | | 3x |49.6 |43.9 |[config](configs/maskrcnn_overlock/mask_rcnn_overlock_t_in1k_fpn_3x_coco.py) |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/maskrcnn3x_overlock_tiny_coco.pth) |
| OverLoCK-S | [ImageNet-1K](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224.pth)| 1x |49.4 |44.0 |[config](configs/maskrcnn_overlock/mask_rcnn_overlock_s_in1k_fpn_1x_coco.py) |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/maskrcnn1x_overlock_small_coco.pth) |
| | | 3x |51.0 |45.0 |[config](configs/maskrcnn_overlock/mask_rcnn_overlock_s_in1k_fpn_3x_coco.py) |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/maskrcnn3x_overlock_small_coco.pth) |
| OverLoCK-B | [ImageNet-1K](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224.pth) | 1x |49.9 |44.4 |[config](configs/maskrcnn_overlock/mask_rcnn_overlock_b_in1k_fpn_1x_coco.py) |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/maskrcnn1x_overlock_base_coco.pth) |
| | | 3x |51.4 |45.3 |[config](configs/maskrcnn_overlock/mask_rcnn_overlock_b_in1k_fpn_3x_coco.py) |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/maskrcnn3x_overlock_base_coco.pth) |
## 4. Train
To train ``OverLoCK-T + Mask R-CNN 1x`` model on COCO dataset with 8 GPUs (single node), run:
```
NUM_GPUS=8
CONFIG=configs/maskrcnn_overlock/mask_rcnn_overlock_t_in1k_fpn_1x_coco.py
bash scripts/dist_train.sh $CONFIG $NUM_GPUS
```
## 5. Validation
To evaluate ``OverLoCK-T + Mask R-CNN 1x`` model on COCO dataset, run:
```
NUM_GPUS=8
CKPT=path-to-checkpoint.pth
CONFIG=configs/maskrcnn_overlock/mask_rcnn_overlock_t_in1k_fpn_1x_coco.py
bash scripts/dist_test.sh $CONFIG $CKPT $NUM_GPUS --eval bbox segm
```
## Citation
If you find this project useful for your research, please consider citing:
```
@inproceedings{lou2025overlock,
title={OverLoCK: An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels},
author={Lou, Meng and Yu, Yizhou},
booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
pages={128--138},
year={2025}
}
```
================================================
FILE: detection/scripts/dist_test.sh
================================================
#!/usr/bin/env bash
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=$((RANDOM+10000))
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
torchrun --nproc_per_node=$GPUS --master_port=$PORT test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
================================================
FILE: detection/scripts/dist_train.sh
================================================
#!/usr/bin/env bash
CONFIG=$1
GPUS=$2
PORT=$((RANDOM+10000))
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
torchrun --nproc_per_node=$GPUS --master_port=$PORT train.py $CONFIG --launcher pytorch ${@:3}
================================================
FILE: detection/test.py
================================================
import argparse
import os
import os.path as osp
import time
import warnings
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)
from mmdet.apis import multi_gpu_test, single_gpu_test
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
from mmdet.models import build_detector
import models
def parse_args():
parser = argparse.ArgumentParser(
description='MMDet test (and eval) a model')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--work-dir',
help='the directory to save the file containing evaluation metrics')
parser.add_argument('--out', help='output result file in pickle format')
parser.add_argument(
'--fuse-conv-bn',
action='store_true',
help='Whether to fuse conv and bn, this will slightly increase'
'the inference speed')
parser.add_argument('--gpu-ids',
type=int,
nargs='+',
help='ids of gpus to use '
'(only applicable to non-distributed testing)')
parser.add_argument(
'--format-only',
action='store_true',
help='Format the output results without perform evaluation. It is'
'useful when you want to format the result to a specific format and '
'submit it to the test server')
parser.add_argument(
'--eval',
type=str,
nargs='+',
help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
parser.add_argument('--show', action='store_true', help='show results')
parser.add_argument('--show-dir',
help='directory where painted images will be saved')
parser.add_argument('--show-score-thr',
type=float,
default=0.3,
help='score threshold (default: 0.3)')
parser.add_argument('--gpu-collect',
action='store_true',
help='whether to use gpu to collect results.')
parser.add_argument(
'--tmpdir',
help='tmp directory used for collecting results from multiple '
'workers, available when gpu-collect is not specified')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--options',
nargs='+',
action=DictAction,
help='custom options for evaluation, the key-value pair in xxx=yyy '
'format will be kwargs for dataset.evaluate() function (deprecate), '
'change to --eval-options instead.')
parser.add_argument(
'--eval-options',
nargs='+',
action=DictAction,
help='custom options for evaluation, the key-value pair in xxx=yyy '
'format will be kwargs for dataset.evaluate() function')
parser.add_argument('--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
if args.options and args.eval_options:
raise ValueError(
'--options and --eval-options cannot be both '
'specified, --options is deprecated in favor of --eval-options')
if args.options:
warnings.warn('--options is deprecated in favor of --eval-options')
args.eval_options = args.options
return args
def main():
args = parse_args()
assert args.out or args.eval or args.format_only or args.show \
or args.show_dir, \
('Please specify at least one operation (save/eval/format/show the '
'results / save the results) with the argument "--out", "--eval"'
', "--format-only", "--show" or "--show-dir"')
if args.eval and args.format_only:
raise ValueError('--eval and --format_only cannot be both specified')
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
raise ValueError('The output file must be a pkl file.')
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
cfg.model.pretrained = None
if cfg.model.get('neck'):
if isinstance(cfg.model.neck, list):
for neck_cfg in cfg.model.neck:
if neck_cfg.get('rfp_backbone'):
if neck_cfg.rfp_backbone.get('pretrained'):
neck_cfg.rfp_backbone.pretrained = None
elif cfg.model.neck.get('rfp_backbone'):
if cfg.model.neck.rfp_backbone.get('pretrained'):
cfg.model.neck.rfp_backbone.pretrained = None
# in case the test dataset is concatenated
samples_per_gpu = 1
if isinstance(cfg.data.test, dict):
cfg.data.test.test_mode = True
samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
if samples_per_gpu > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
cfg.data.test.pipeline = replace_ImageToTensor(
cfg.data.test.pipeline)
elif isinstance(cfg.data.test, list):
for ds_cfg in cfg.data.test:
ds_cfg.test_mode = True
samples_per_gpu = max(
[ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
if samples_per_gpu > 1:
for ds_cfg in cfg.data.test:
ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids
else:
cfg.gpu_ids = range(1)
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
if len(cfg.gpu_ids) > 1:
warnings.warn(
f'We treat {cfg.gpu_ids} as gpu-ids, and reset to '
f'{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential error in '
'non-distribute testing time.')
cfg.gpu_ids = cfg.gpu_ids[0:1]
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
rank, _ = get_dist_info()
# allows not to create
if args.work_dir is not None and rank == 0:
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')
# build the dataloader
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(dataset,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
# build the model and load checkpoint
cfg.model.train_cfg = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
if args.fuse_conv_bn:
model = fuse_conv_bn(model)
# old versions did not save class info in checkpoints, this walkaround is
# for backward compatibility
if 'CLASSES' in checkpoint.get('meta', {}):
model.CLASSES = checkpoint['meta']['CLASSES']
else:
model.CLASSES = dataset.CLASSES
if not distributed:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
args.show_score_thr)
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
outputs = multi_gpu_test(model, data_loader, args.tmpdir,
args.gpu_collect)
rank, _ = get_dist_info()
if rank == 0:
if args.out:
print(f'\nwriting results to {args.out}')
mmcv.dump(outputs, args.out)
kwargs = {} if args.eval_options is None else args.eval_options
if args.format_only:
dataset.format_results(outputs, **kwargs)
if args.eval:
eval_kwargs = cfg.get('evaluation', {}).copy()
# hard-code way to remove EvalHook args
for key in [
'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
'rule', 'dynamic_intervals'
]:
eval_kwargs.pop(key, None)
eval_kwargs.update(dict(metric=args.eval, **kwargs))
metric = dataset.evaluate(outputs, **eval_kwargs)
print(metric)
metric_dict = dict(config=args.config, metric=metric)
if args.work_dir is not None and rank == 0:
mmcv.dump(metric_dict, json_file)
if __name__ == '__main__':
main()
================================================
FILE: detection/train.py
================================================
import argparse
import copy
import os
import os.path as osp
import time
import warnings
import mmcv
import torch
import torch.distributed as dist
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash
from mmdet import __version__
from mmdet.apis import init_random_seed, set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import (collect_env, get_device, get_root_logger,
replace_cfg_vals, setup_multi_processes,
update_data_root)
import models
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument('--resume-from',
help='the checkpoint file to resume from')
parser.add_argument('--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically')
parser.add_argument(
'--no-validate',
action='store_true',
help='whether not to evaluate the checkpoint during training')
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
type=int,
help='(Deprecated, please use --gpu-id) number of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='(Deprecated, please use --gpu-id) ids of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument('--gpu-id',
type=int,
default=0,
help='id of gpu to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--diff-seed',
action='store_true',
help='Whether or not set different seeds for different ranks')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file (deprecate), '
'change to --cfg-options instead.')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--drop-path',
default=-1,
type=float,
help='drop-path-rate of the backbone network')
parser.add_argument(
'--freeze-bn',
action='store_true',
default=False,
help='freeze the BN layer of the backbone model during training')
parser.add_argument(
'--amp',
action='store_true',
default=False,
help='mixed precision training')
parser.add_argument('--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local-rank', type=int, default=0)
parser.add_argument('--auto-scale-lr',
action='store_true',
help='enable automatically scaling LR.')
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
if args.options and args.cfg_options:
raise ValueError(
'--options and --cfg-options cannot be both '
'specified, --options is deprecated in favor of --cfg-options')
if args.options:
warnings.warn('--options is deprecated in favor of --cfg-options')
args.cfg_options = args.options
return args
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
# replace the ${key} with the value of cfg.key
cfg = replace_cfg_vals(cfg)
# update data root according to MMDET_DATASETS
update_data_root(cfg)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
if args.auto_scale_lr:
if 'auto_scale_lr' in cfg and \
'enable' in cfg.auto_scale_lr and \
'base_batch_size' in cfg.auto_scale_lr:
cfg.auto_scale_lr.enable = True
else:
warnings.warn('Can not find "auto_scale_lr" or '
'"auto_scale_lr.enable" or '
'"auto_scale_lr.base_batch_size" in your'
' configuration file. Please update all the '
'configuration files to mmdet >= 2.24.1.')
# set multi-process settings
setup_multi_processes(cfg)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.auto_resume = args.auto_resume
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '
'single GPU mode in non-distributed training. '
'Use `gpus=1` now.')
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids[0:1]
warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
'Because we only support single GPU mode in '
'non-distributed training. Use the first GPU '
'in `gpu_ids` now.')
if args.gpus is None and args.gpu_ids is None:
cfg.gpu_ids = [args.gpu_id]
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# re-set gpu_ids with distributed training mode
_, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)
if args.freeze_bn:
try:
cfg.model.backbone.freeze_bn = True
except:
logger.info('freeze_bn is not defined in the config file')
if args.drop_path >= 0:
try:
cfg.model.backbone.drop_path_rate = args.drop_path
except:
logger.info('drop_path is not defined in the config file')
if args.amp:
loss_scale = 'dynamic'
if cfg.get('fp16', None) is None:
cfg.fp16 = dict(loss_scale=loss_scale)
# warnings.warn('fp16 is not defined in the config file')
else:
# cfg.fp16.enabled = True
# cfg.fp16.loss_scale = loss_scale
warnings.warn('fp16 has been defined in the config file')
# cfg.optimizer_config.type = 'Fp16OptimizerHook'
# cfg.optimizer_config.loss_scale = loss_scale
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# dump config
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info
meta['config'] = cfg.pretty_text
# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')
cfg.device = get_device()
# set random seeds
seed = init_random_seed(args.seed, device=cfg.device)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)
model = build_detector(cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
# model.init_weights()
logger.info(model)
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
val_dataset.pipeline = cfg.data.train.pipeline
datasets.append(build_dataset(val_dataset))
if cfg.checkpoint_config is not None:
# save mmdet version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(mmdet_version=__version__ +
get_git_hash()[:7],
CLASSES=datasets[0].CLASSES)
# add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True
train_detector(model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
if __name__ == '__main__':
main()
================================================
FILE: models/__init__.py
================================================
from .overlock import overlock_xt, overlock_t, overlock_s, overlock_b
================================================
FILE: models/contmix.py
================================================
'''
This is a plug-and-play implementation of ContMix block in the paper:
https://arxiv.org/abs/2502.20087
'''
import warnings
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange, einsum
from timm.models.layers import DropPath, to_2tuple
from torch.utils.checkpoint import checkpoint
try:
from natten.functional import na2d_av
has_natten = True
except:
has_natten = False
warnings.warn("The efficiency may be reduced since 'natten' is not installed."
" It is recommended to install natten for better performance.")
def get_conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
attempt_use_lk_impl=True):
kernel_size = to_2tuple(kernel_size)
if padding is None:
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
else:
padding = to_2tuple(padding)
need_large_impl = kernel_size[0] == kernel_size[1] and kernel_size[0] > 5 and padding == (kernel_size[0] // 2, kernel_size[1] // 2)
if attempt_use_lk_impl and need_large_impl:
print('---------------- trying to import iGEMM implementation for large-kernel conv')
try:
from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
print('---------------- found iGEMM implementation ')
except:
DepthWiseConv2dImplicitGEMM = None
print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')
if DepthWiseConv2dImplicitGEMM is not None and need_large_impl and in_channels == out_channels \
and out_channels == groups and stride == 1 and dilation == 1:
print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')
return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
def get_bn(dim, use_sync_bn=False):
if use_sync_bn:
return nn.SyncBatchNorm(dim)
else:
return nn.BatchNorm2d(dim)
def fuse_bn(conv, bn):
conv_bias = 0 if conv.bias is None else conv.bias
std = (bn.running_var + bn.eps).sqrt()
return conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1), bn.bias + (conv_bias - bn.running_mean) * bn.weight / std
def convert_dilated_to_nondilated(kernel, dilate_rate):
identity_kernel = torch.ones((1, 1, 1, 1)).to(kernel.device)
if kernel.size(1) == 1:
# This is a DW kernel
dilated = F.conv_transpose2d(kernel, identity_kernel, stride=dilate_rate)
return dilated
else:
# This is a dense or group-wise (but not DW) kernel
slices = []
for i in range(kernel.size(1)):
dilated = F.conv_transpose2d(kernel[:,i:i+1,:,:], identity_kernel, stride=dilate_rate)
slices.append(dilated)
return torch.cat(slices, dim=1)
def merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):
large_k = large_kernel.size(2)
dilated_k = dilated_kernel.size(2)
equivalent_kernel_size = dilated_r * (dilated_k - 1) + 1
equivalent_kernel = convert_dilated_to_nondilated(dilated_kernel, dilated_r)
rows_to_pad = large_k // 2 - equivalent_kernel_size // 2
merged_kernel = large_kernel + F.pad(equivalent_kernel, [rows_to_pad] * 4)
return merged_kernel
class SEModule(nn.Module):
def __init__(self, dim, red=8, inner_act=nn.GELU, out_act=nn.Sigmoid):
super().__init__()
inner_dim = max(16, dim // red)
self.proj = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(dim, inner_dim, kernel_size=1),
inner_act(),
nn.Conv2d(inner_dim, dim, kernel_size=1),
out_act(),
)
def forward(self, x):
x = x * self.proj(x)
return x
class LayerScale(nn.Module):
def __init__(self, dim, init_value=1e-5):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim, 1, 1, 1)*init_value,
requires_grad=True)
self.bias = nn.Parameter(torch.zeros(dim), requires_grad=True)
def forward(self, x):
x = F.conv2d(x, weight=self.weight, bias=self.bias, groups=x.shape[1])
return x
class LayerNorm2d(nn.LayerNorm):
def __init__(self, dim):
super().__init__(normalized_shape=dim, eps=1e-6)
def forward(self, x):
x = rearrange(x, 'b c h w -> b h w c')
x = super().forward(x)
x = rearrange(x, 'b h w c -> b c h w')
return x.contiguous()
class GRN(nn.Module):
""" GRN (Global Response Normalization) layer
Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808)
This implementation is more efficient than the original (https://github.com/facebookresearch/ConvNeXt-V2)
We assume the inputs to this layer are (N, C, H, W)
"""
def __init__(self, dim, use_bias=True):
super().__init__()
self.use_bias = use_bias
self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
if self.use_bias:
self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(-1, -2), keepdim=True)
Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
if self.use_bias:
return (self.gamma * Nx + 1) * x + self.beta
else:
return (self.gamma * Nx + 1) * x
class DilatedReparamBlock(nn.Module):
"""
Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet)
We assume the inputs to this block are (N, C, H, W)
"""
def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True):
super().__init__()
self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,
padding=kernel_size//2, dilation=1, groups=channels, bias=deploy,
attempt_use_lk_impl=attempt_use_lk_impl)
self.attempt_use_lk_impl = attempt_use_lk_impl
# Default settings. We did not tune them carefully. Different settings may work better.
if kernel_size == 19:
self.kernel_sizes = [5, 7, 9, 9, 3, 3, 3]
self.dilates = [1, 1, 1, 2, 4, 5, 7]
elif kernel_size == 17:
self.kernel_sizes = [5, 7, 9, 3, 3, 3]
self.dilates = [1, 1, 2, 4, 5, 7]
elif kernel_size == 15:
self.kernel_sizes = [5, 7, 7, 3, 3, 3]
self.dilates = [1, 1, 2, 3, 5, 7]
elif kernel_size == 13:
self.kernel_sizes = [5, 7, 7, 3, 3, 3]
self.dilates = [1, 1, 2, 3, 4, 5]
elif kernel_size == 11:
self.kernel_sizes = [5, 7, 5, 3, 3, 3]
self.dilates = [1, 1, 2, 3, 4, 5]
elif kernel_size == 9:
self.kernel_sizes = [5, 7, 5, 3, 3]
self.dilates = [1, 1, 2, 3, 4]
elif kernel_size == 7:
self.kernel_sizes = [5, 3, 3, 3]
self.dilates = [1, 1, 2, 3]
elif kernel_size == 5:
self.kernel_sizes = [3, 3]
self.dilates = [1, 2]
else:
raise ValueError('Dilated Reparam Block requires kernel_size >= 5')
if not deploy:
self.origin_bn = get_bn(channels, use_sync_bn)
for k, r in zip(self.kernel_sizes, self.dilates):
self.__setattr__('dil_conv_k{}_{}'.format(k, r),
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,
padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,
bias=False))
self.__setattr__('dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn))
def forward(self, x):
if not hasattr(self, 'origin_bn'): # deploy mode
return self.lk_origin(x)
out = self.origin_bn(self.lk_origin(x))
for k, r in zip(self.kernel_sizes, self.dilates):
conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))
bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))
out = out + bn(conv(x))
return out
def merge_dilated_branches(self):
if hasattr(self, 'origin_bn'):
origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn)
for k, r in zip(self.kernel_sizes, self.dilates):
conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))
bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))
branch_k, branch_b = fuse_bn(conv, bn)
origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r)
origin_b += branch_b
merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1,
padding=origin_k.size(2)//2, dilation=1, groups=origin_k.size(0), bias=True,
attempt_use_lk_impl=self.attempt_use_lk_impl)
merged_conv.weight.data = origin_k
merged_conv.bias.data = origin_b
self.lk_origin = merged_conv
self.__delattr__('origin_bn')
for k, r in zip(self.kernel_sizes, self.dilates):
self.__delattr__('dil_conv_k{}_{}'.format(k, r))
self.__delattr__('dil_bn_k{}_{}'.format(k, r))
class ResDWConv(nn.Conv2d):
'''
Depthwise conv with residual connection
'''
def __init__(self, dim, kernel_size=3):
super().__init__(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=dim)
def forward(self, x):
x = x + super().forward(x)
return x
class ContMixBlock(nn.Module):
'''
A plug-and-play implementation of ContMix module with FFN layer
Paper: https://arxiv.org/abs/2502.20087
'''
def __init__(self,
dim=64,
kernel_size=7,
smk_size=5,
num_heads=2,
mlp_ratio=4,
res_scale=False,
ls_init_value=None,
drop_path=0,
norm_layer=LayerNorm2d,
use_gemm=False,
deploy=False,
use_checkpoint=False,
**kwargs):
super().__init__()
'''
Args:
kernel_size: kernel size of the main ContMix branch, default is 7
smk_size: kernel size of the secondary ContMix branch, default is 5
num_heads: number of dynamic kernel heads, default is 2
mlp_ratio: ratio of mlp hidden dim to embedding dim, default is 4
res_scale: whether to use residual layer scale, default is False
ls_init_value: layer scale init value, default is None
drop_path: drop path rate, default is 0
norm_layer: normalization layer, default is LayerNorm2d
use_gemm: whether to use iGEMM implementation for large kernel conv, default is False
deploy: whether to use deploy mode, default is False
use_checkpoint: whether to use grad checkpointing, default is False
**kwargs: other arguments
'''
mlp_dim = int(dim*mlp_ratio)
self.kernel_size = kernel_size
self.res_scale = res_scale
self.use_gemm = use_gemm
self.smk_size = smk_size
self.num_heads = num_heads * 2
head_dim = dim // self.num_heads
self.scale = head_dim ** -0.5
self.use_checkpoint = use_checkpoint
self.dwconv1 = ResDWConv(dim, kernel_size=3)
self.norm1 = norm_layer(dim)
self.weight_query = nn.Sequential(
nn.Conv2d(dim, dim//2, kernel_size=1, bias=False),
nn.BatchNorm2d(dim//2),
)
self.weight_key = nn.Sequential(
nn.AdaptiveAvgPool2d(7),
nn.Conv2d(dim, dim//2, kernel_size=1, bias=False),
nn.BatchNorm2d(dim//2),
)
self.weight_value = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=1, bias=False),
nn.BatchNorm2d(dim),
)
self.weight_proj = nn.Conv2d(49, kernel_size**2 + smk_size**2, kernel_size=1)
self.fusion_proj = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=1, bias=False),
nn.BatchNorm2d(dim),
)
self.lepe = nn.Sequential(
DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),
nn.BatchNorm2d(dim),
)
self.se_layer = SEModule(dim)
self.gate = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=1, bias=False),
nn.BatchNorm2d(dim),
nn.SiLU(),
)
self.proj = nn.Sequential(
nn.BatchNorm2d(dim),
nn.Conv2d(dim, dim, kernel_size=1),
)
self.dwconv2 = ResDWConv(dim, kernel_size=3)
self.norm2 = norm_layer(dim)
self.mlp = nn.Sequential(
nn.Conv2d(dim, mlp_dim, kernel_size=1),
nn.GELU(),
ResDWConv(mlp_dim, kernel_size=3),
GRN(mlp_dim),
nn.Conv2d(mlp_dim, dim, kernel_size=1),
)
self.ls1 = LayerScale(dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()
self.ls2 = LayerScale(dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
self.get_rpb()
def get_rpb(self):
self.rpb_size1 = 2 * self.smk_size - 1
self.rpb1 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size1, self.rpb_size1))
self.rpb_size2 = 2 * self.kernel_size - 1
self.rpb2 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size2, self.rpb_size2))
nn.init.trunc_normal_(self.rpb1, std=0.02)
nn.init.trunc_normal_(self.rpb2, std=0.02)
@torch.no_grad()
def generate_idx(self, kernel_size):
rpb_size = 2 * kernel_size - 1
idx_h = torch.arange(0, kernel_size)
idx_w = torch.arange(0, kernel_size)
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).view(-1)
return (idx_h, idx_w, idx_k)
def apply_rpb(self, attn, rpb, height, width, kernel_size, idx_h, idx_w, idx_k):
"""
RPB implementation directly borrowed from https://tinyurl.com/mrbub4t3
"""
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size//2] = height - (kernel_size-1)
num_repeat_w[kernel_size//2] = width - (kernel_size-1)
bias_hw = (idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2*kernel_size-1)) + idx_w.repeat_interleave(num_repeat_w)
bias_idx = bias_hw.unsqueeze(-1) + idx_k
bias_idx = bias_idx.reshape(-1, int(kernel_size**2))
bias_idx = torch.flip(bias_idx, [0])
rpb = torch.flatten(rpb, 1, 2)[:, bias_idx]
rpb = rpb.reshape(1, int(self.num_heads), int(height), int(width), int(kernel_size**2))
return attn + rpb
def reparm(self):
for m in self.modules():
if isinstance(m, DilatedReparamBlock):
m.merge_dilated_branches()
def _forward_inner(self, x):
input_resolution = x.shape[2:]
B, C, H, W = x.shape
x = self.dwconv1(x)
identity = x
x = self.norm1(x)
gate = self.gate(x)
lepe = self.lepe(x)
is_pad = False
if min(H, W) < self.kernel_size:
is_pad = True
if H < W:
size = (self.kernel_size, int(self.kernel_size / H * W))
else:
size = (int(self.kernel_size / W * H), self.kernel_size)
x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
H, W = size
query = self.weight_query(x) * self.scale
key = self.weight_key(x)
value = self.weight_value(x)
query = rearrange(query, 'b (g c) h w -> b g c (h w)', g=self.num_heads)
key = rearrange(key, 'b (g c) h w -> b g c (h w)', g=self.num_heads)
weight = einsum(query, key, 'b g c n, b g c l -> b g n l')
weight = rearrange(weight, 'b g n l -> b l g n').contiguous()
weight = self.weight_proj(weight)
weight = rearrange(weight, 'b l g (h w) -> b g h w l', h=H, w=W)
attn1, attn2 = torch.split(weight, split_size_or_sections=[self.smk_size**2, self.kernel_size**2], dim=-1)
rpb1_idx = self.generate_idx(self.smk_size)
rpb2_idx = self.generate_idx(self.kernel_size)
attn1 = self.apply_rpb(attn1, self.rpb1, H, W, self.smk_size, *rpb1_idx)
attn2 = self.apply_rpb(attn2, self.rpb2, H, W, self.kernel_size, *rpb2_idx)
attn1 = torch.softmax(attn1, dim=-1)
attn2 = torch.softmax(attn2, dim=-1)
value = rearrange(value, 'b (m g c) h w -> m b g h w c', m=2, g=self.num_heads)
if has_natten:
x1 = na2d_av(attn1, value[0], kernel_size=self.smk_size)
x2 = na2d_av(attn2, value[1], kernel_size=self.kernel_size)
else:
pad1 = self.smk_size // 2
pad2 = self.kernel_size // 2
H_o1 = H - 2 * pad1
W_o1 = W - 2 * pad1
H_o2 = H - 2 * pad2
W_o2 = W - 2 * pad2
v1 = rearrange(value[0], 'b g h w c -> b (g c) h w')
v2 = rearrange(value[1], 'b g h w c -> b (g c) h w')
v1 = F.unfold(v1, kernel_size=self.smk_size).reshape(B, -1, H_o1, W_o1)
v2 = F.unfold(v2, kernel_size=self.kernel_size).reshape(B, -1, H_o2, W_o2)
v1 = F.pad(v1, (pad1, pad1, pad1, pad1), mode='replicate')
v2 = F.pad(v2, (pad2, pad2, pad2, pad2), mode='replicate')
v1 = rearrange(v1, 'b (g c k) h w -> b g c h w k', g=self.num_heads, k=self.smk_size**2, h=H, w=W)
v2 = rearrange(v2, 'b (g c k) h w -> b g c h w k', g=self.num_heads, k=self.kernel_size**2, h=H, w=W)
x1 = einsum(attn1, v1, 'b g h w k, b g c h w k -> b g h w c')
x2 = einsum(attn2, v2, 'b g h w k, b g c h w k -> b g h w c')
x = torch.cat([x1, x2], dim=1)
x = rearrange(x, 'b g h w c -> b (g c) h w', h=H, w=W)
if is_pad:
x = F.adaptive_avg_pool2d(x, input_resolution)
x = self.fusion_proj(x)
x = x + lepe
x = self.se_layer(x)
x = gate * x
x = self.proj(x)
if self.res_scale:
x = self.ls1(identity) + self.drop_path(x)
else:
x = identity + self.drop_path(self.ls1(x))
x = self.dwconv2(x)
if self.res_scale:
x = self.ls2(x) + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
return x
def forward(self, x):
if self.use_checkpoint and x.requires_grad:
x = checkpoint(self._forward_inner, x, use_reentrant=False)
else:
x = self._forward_inner(x)
return x
if __name__ == '__main__':
from timm.utils import random_seed
random_seed(6)
x = torch.randn(1, 64, 32, 32).cuda()
model = ContMixBlock(dim=64,
num_heads=2,
kernel_size=13,
smk_size=5,
mlp_ratio=4,
res_scale=True,
ls_init_value=1,
drop_path=0,
norm_layer=LayerNorm2d,
use_gemm=True,
deploy=False,
use_checkpoint=False)
print(model)
model.cuda()
model.eval()
y = model(x)
print(y.shape)
# Reparametrize model, more details can be found at:
# https://github.com/AILab-CVC/UniRepLKNet/tree/main
model.reparm()
z = model(x)
# Showing difference between original and reparametrized model
print((z - y).abs().sum() / y.abs().sum())
================================================
FILE: models/overlock.py
================================================
'''
This is an official implementation of OverLoCK model proposed in the paper:
https://arxiv.org/abs/2502.20087
'''
import torch
import timm
import torch.distributed
import torch.nn.functional as F
from torch import nn
from einops import rearrange, einsum
from natten.functional import na2d_av
from mmengine.runner import load_checkpoint
from torch.utils.checkpoint import checkpoint
from timm.models.layers import DropPath, to_2tuple
from timm.models.registry import register_model
def get_conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
attempt_use_lk_impl=True):
kernel_size = to_2tuple(kernel_size)
if padding is None:
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
else:
padding = to_2tuple(padding)
need_large_impl = kernel_size[0] == kernel_size[1] and kernel_size[0] > 5 and padding == (kernel_size[0] // 2, kernel_size[1] // 2)
if attempt_use_lk_impl and need_large_impl:
print('---------------- trying to import iGEMM implementation for large-kernel conv')
try:
from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
print('---------------- found iGEMM implementation ')
except:
DepthWiseConv2dImplicitGEMM = None
print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')
if DepthWiseConv2dImplicitGEMM is not None and need_large_impl and in_channels == out_channels \
and out_channels == groups and stride == 1 and dilation == 1:
print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')
return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
def get_bn(dim, use_sync_bn=False):
if use_sync_bn:
return nn.SyncBatchNorm(dim)
else:
return nn.BatchNorm2d(dim)
def fuse_bn(conv, bn):
conv_bias = 0 if conv.bias is None else conv.bias
std = (bn.running_var + bn.eps).sqrt()
return conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1), bn.bias + (conv_bias - bn.running_mean) * bn.weight / std
def convert_dilated_to_nondilated(kernel, dilate_rate):
identity_kernel = torch.ones((1, 1, 1, 1)).to(kernel.device)
if kernel.size(1) == 1:
# This is a DW kernel
dilated = F.conv_transpose2d(kernel, identity_kernel, stride=dilate_rate)
return dilated
else:
# This is a dense or group-wise (but not DW) kernel
slices = []
for i in range(kernel.size(1)):
dilated = F.conv_transpose2d(kernel[:,i:i+1,:,:], identity_kernel, stride=dilate_rate)
slices.append(dilated)
return torch.cat(slices, dim=1)
def merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):
large_k = large_kernel.size(2)
dilated_k = dilated_kernel.size(2)
equivalent_kernel_size = dilated_r * (dilated_k - 1) + 1
equivalent_kernel = convert_dilated_to_nondilated(dilated_kernel, dilated_r)
rows_to_pad = large_k // 2 - equivalent_kernel_size // 2
merged_kernel = large_kernel + F.pad(equivalent_kernel, [rows_to_pad] * 4)
return merged_kernel
def stem(in_chans=3, embed_dim=96):
return nn.Sequential(
nn.Conv2d(in_chans, embed_dim//2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(embed_dim//2),
nn.GELU(),
nn.Conv2d(embed_dim//2, embed_dim//2, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(embed_dim//2),
nn.GELU(),
nn.Conv2d(embed_dim//2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(embed_dim),
nn.GELU(),
nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(embed_dim)
)
def downsample(in_dim, out_dim):
return nn.Sequential(
nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(out_dim),
)
class SEModule(nn.Module):
def __init__(self, dim, red=8, inner_act=nn.GELU, out_act=nn.Sigmoid):
super().__init__()
inner_dim = max(16, dim // red)
self.proj = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(dim, inner_dim, kernel_size=1),
inner_act(),
nn.Conv2d(inner_dim, dim, kernel_size=1),
out_act(),
)
def forward(self, x):
x = x * self.proj(x)
return x
class LayerScale(nn.Module):
def __init__(self, dim, init_value=1e-5):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim, 1, 1, 1)*init_value,
requires_grad=True)
self.bias = nn.Parameter(torch.zeros(dim), requires_grad=True)
def forward(self, x):
x = F.conv2d(x, weight=self.weight, bias=self.bias, groups=x.shape[1])
return x
class LayerNorm2d(nn.LayerNorm):
def __init__(self, dim):
super().__init__(normalized_shape=dim, eps=1e-6)
def forward(self, x):
x = rearrange(x, 'b c h w -> b h w c')
x = super().forward(x)
x = rearrange(x, 'b h w c -> b c h w')
return x.contiguous()
class GRN(nn.Module):
""" GRN (Global Response Normalization) layer
Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808)
This implementation is more efficient than the original (https://github.com/facebookresearch/ConvNeXt-V2)
We assume the inputs to this layer are (N, C, H, W)
"""
def __init__(self, dim, use_bias=True):
super().__init__()
self.use_bias = use_bias
self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
if self.use_bias:
self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(-1, -2), keepdim=True)
Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
if self.use_bias:
return (self.gamma * Nx + 1) * x + self.beta
else:
return (self.gamma * Nx + 1) * x
class DilatedReparamBlock(nn.Module):
"""
Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet)
We assume the inputs to this block are (N, C, H, W)
"""
def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True):
super().__init__()
self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,
padding=kernel_size//2, dilation=1, groups=channels, bias=deploy,
attempt_use_lk_impl=attempt_use_lk_impl)
self.attempt_use_lk_impl = attempt_use_lk_impl
# Default settings. We did not tune them carefully. Different settings may work better.
if kernel_size == 19:
self.kernel_sizes = [5, 7, 9, 9, 3, 3, 3]
self.dilates = [1, 1, 1, 2, 4, 5, 7]
elif kernel_size == 17:
self.kernel_sizes = [5, 7, 9, 3, 3, 3]
self.dilates = [1, 1, 2, 4, 5, 7]
elif kernel_size == 15:
self.kernel_sizes = [5, 7, 7, 3, 3, 3]
self.dilates = [1, 1, 2, 3, 5, 7]
elif kernel_size == 13:
self.kernel_sizes = [5, 7, 7, 3, 3, 3]
self.dilates = [1, 1, 2, 3, 4, 5]
elif kernel_size == 11:
self.kernel_sizes = [5, 7, 5, 3, 3, 3]
self.dilates = [1, 1, 2, 3, 4, 5]
elif kernel_size == 9:
self.kernel_sizes = [5, 7, 5, 3, 3]
self.dilates = [1, 1, 2, 3, 4]
elif kernel_size == 7:
self.kernel_sizes = [5, 3, 3, 3]
self.dilates = [1, 1, 2, 3]
elif kernel_size == 5:
self.kernel_sizes = [3, 3]
self.dilates = [1, 2]
else:
raise ValueError('Dilated Reparam Block requires kernel_size >= 5')
if not deploy:
self.origin_bn = get_bn(channels, use_sync_bn)
for k, r in zip(self.kernel_sizes, self.dilates):
self.__setattr__('dil_conv_k{}_{}'.format(k, r),
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,
padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,
bias=False))
self.__setattr__('dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn))
def forward(self, x):
if not hasattr(self, 'origin_bn'): # deploy mode
return self.lk_origin(x)
out = self.origin_bn(self.lk_origin(x))
for k, r in zip(self.kernel_sizes, self.dilates):
conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))
bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))
out = out + bn(conv(x))
return out
def merge_dilated_branches(self):
if hasattr(self, 'origin_bn'):
origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn)
for k, r in zip(self.kernel_sizes, self.dilates):
conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))
bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))
branch_k, branch_b = fuse_bn(conv, bn)
origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r)
origin_b += branch_b
merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1,
padding=origin_k.size(2)//2, dilation=1, groups=origin_k.size(0), bias=True,
attempt_use_lk_impl=self.attempt_use_lk_impl)
merged_conv.weight.data = origin_k
merged_conv.bias.data = origin_b
self.lk_origin = merged_conv
self.__delattr__('origin_bn')
for k, r in zip(self.kernel_sizes, self.dilates):
self.__delattr__('dil_conv_k{}_{}'.format(k, r))
self.__delattr__('dil_bn_k{}_{}'.format(k, r))
class CTXDownsample(nn.Module):
def __init__(self, dim, h_dim):
super().__init__()
self.x_proj = nn.Sequential(
nn.Conv2d(dim, h_dim, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(h_dim)
)
self.h_proj = nn.Sequential(
nn.Conv2d(h_dim//4, h_dim//4, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(h_dim//4)
)
def forward(self, x, ctx):
x = self.x_proj(x)
ctx = self.h_proj(ctx)
return (x, ctx)
class ResDWConv(nn.Conv2d):
'''
Depthwise convolution with residual connection
'''
def __init__(self, dim, kernel_size=3):
super().__init__(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=dim)
def forward(self, x):
x = x + super().forward(x)
return x
class RepConvBlock(nn.Module):
def __init__(self,
dim=64,
kernel_size=7,
mlp_ratio=4,
ls_init_value=None,
res_scale=False,
drop_path=0,
norm_layer=LayerNorm2d,
use_gemm=False,
deploy=False,
use_checkpoint=False):
super().__init__()
self.res_scale = res_scale
self.use_checkpoint = use_checkpoint
mlp_dim = int(dim*mlp_ratio)
self.dwconv = ResDWConv(dim, kernel_size=3)
self.proj = nn.Sequential(
norm_layer(dim),
DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),
nn.BatchNorm2d(dim),
SEModule(dim),
nn.Conv2d(dim, mlp_dim, kernel_size=1),
nn.GELU(),
ResDWConv(mlp_dim, kernel_size=3),
GRN(mlp_dim),
nn.Conv2d(mlp_dim, dim, kernel_size=1),
DropPath(drop_path) if drop_path > 0 else nn.Identity(),
)
self.ls = LayerScale(dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()
def forward_features(self, x):
x = self.dwconv(x)
if self.res_scale:
x = self.ls(x) + self.proj(x)
else:
drop_path = self.proj[-1]
x = x + drop_path(self.ls(self.proj[:-1](x)))
return x
def forward(self, x):
if self.use_checkpoint and x.requires_grad:
x = checkpoint(self.forward_features, x, use_reentrant=False)
else:
x = self.forward_features(x)
return x
class DynamicConvBlock(nn.Module):
def __init__(self,
dim=64,
ctx_dim=32,
kernel_size=7,
smk_size=5,
num_heads=2,
mlp_ratio=4,
ls_init_value=None,
res_scale=False,
drop_path=0,
norm_layer=LayerNorm2d,
is_first=False,
is_last=False,
use_gemm=False,
deploy=False,
use_checkpoint=False,
**kwargs):
super().__init__()
ctx_dim = ctx_dim // 4
out_dim = dim + ctx_dim
mlp_dim = int(dim*mlp_ratio)
self.kernel_size = kernel_size
self.res_scale = res_scale
self.use_gemm = use_gemm
self.smk_size = smk_size
self.num_heads = num_heads * 2
head_dim = dim // self.num_heads
self.scale = head_dim ** -0.5
self.is_first = is_first
self.is_last = is_last
self.use_checkpoint = use_checkpoint
if not is_first:
self.x_scale = LayerScale(ctx_dim, init_value=1)
self.h_scale = LayerScale(ctx_dim, init_value=1)
self.dwconv1 = ResDWConv(out_dim, kernel_size=3)
self.norm1 = norm_layer(out_dim)
self.fusion = nn.Sequential(
nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, groups=out_dim),
nn.BatchNorm2d(out_dim),
nn.GELU(),
nn.Conv2d(out_dim, dim, kernel_size=1),
GRN(dim),
)
self.weight_query = nn.Sequential(
nn.Conv2d(dim, dim//2, kernel_size=1, bias=False),
nn.BatchNorm2d(dim//2),
)
self.weight_key = nn.Sequential(
nn.AdaptiveAvgPool2d(7),
nn.Conv2d(ctx_dim, dim//2, kernel_size=1, bias=False),
nn.BatchNorm2d(dim//2),
)
self.weight_proj = nn.Conv2d(49, kernel_size**2 + smk_size**2, kernel_size=1)
self.dyconv_proj = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=1, bias=False),
nn.BatchNorm2d(dim),
)
self.lepe = nn.Sequential(
DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),
nn.BatchNorm2d(dim),
)
self.se_layer = SEModule(dim)
self.gate = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=1, bias=False),
nn.BatchNorm2d(dim),
nn.SiLU(),
)
self.proj = nn.Sequential(
nn.BatchNorm2d(dim),
nn.Conv2d(dim, out_dim, kernel_size=1),
)
self.dwconv2 = ResDWConv(out_dim, kernel_size=3)
self.norm2 = norm_layer(out_dim)
self.mlp = nn.Sequential(
nn.Conv2d(out_dim, mlp_dim, kernel_size=1),
nn.GELU(),
ResDWConv(mlp_dim, kernel_size=3),
GRN(mlp_dim),
nn.Conv2d(mlp_dim, out_dim, kernel_size=1),
)
self.ls1 = LayerScale(out_dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()
self.ls2 = LayerScale(out_dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
self.get_rpb()
def get_rpb(self):
self.rpb_size1 = 2 * self.smk_size - 1
self.rpb1 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size1, self.rpb_size1))
self.rpb_size2 = 2 * self.kernel_size - 1
self.rpb2 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size2, self.rpb_size2))
nn.init.zeros_(self.rpb1)
nn.init.zeros_(self.rpb2)
@torch.no_grad()
def generate_idx(self, kernel_size):
rpb_size = 2 * kernel_size - 1
idx_h = torch.arange(0, kernel_size)
idx_w = torch.arange(0, kernel_size)
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).view(-1)
return (idx_h, idx_w, idx_k)
def apply_rpb(self, attn, rpb, height, width, kernel_size, idx_h, idx_w, idx_k):
"""
RPB implementation directly borrowed from https://tinyurl.com/mrbub4t3
"""
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size//2] = height - (kernel_size-1)
num_repeat_w[kernel_size//2] = width - (kernel_size-1)
bias_hw = (idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2*kernel_size-1)) + idx_w.repeat_interleave(num_repeat_w)
bias_idx = bias_hw.unsqueeze(-1) + idx_k
bias_idx = bias_idx.reshape(-1, int(kernel_size**2))
bias_idx = torch.flip(bias_idx, [0])
rpb = torch.flatten(rpb, 1, 2)[:, bias_idx]
rpb = rpb.reshape(1, int(self.num_heads), int(height), int(width), int(kernel_size**2))
return attn + rpb
def _forward_inner(self, x, h_x, h_r):
input_resoltion = x.shape[2:]
B, C, H, W = x.shape
B, C_h, H_h, W_h = h_x.shape
if not self.is_first:
h_x = self.x_scale(h_x) + self.h_scale(h_r)
x_f = torch.cat([x, h_x], dim=1)
x_f = self.dwconv1(x_f)
identity = x_f
x_f = self.norm1(x_f)
x = self.fusion(x_f)
gate = self.gate(x)
lepe = self.lepe(x)
is_pad = False
if min(H, W) < self.kernel_size:
is_pad = True
if H < W:
size = (self.kernel_size, int(self.kernel_size / H * W))
else:
size = (int(self.kernel_size / W * H), self.kernel_size)
x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
x_f = F.interpolate(x_f, size=size, mode='bilinear', align_corners=False)
H, W = size
query, key = torch.split(x_f, split_size_or_sections=[C, C_h], dim=1)
query = self.weight_query(query) * self.scale
key = self.weight_key(key)
query = rearrange(query, 'b (g c) h w -> b g c (h w)', g=self.num_heads)
key = rearrange(key, 'b (g c) h w -> b g c (h w)', g=self.num_heads)
weight = einsum(query, key, 'b g c n, b g c l -> b g n l')
weight = rearrange(weight, 'b g n l -> b l g n').contiguous()
weight = self.weight_proj(weight)
weight = rearrange(weight, 'b l g (h w) -> b g h w l', h=H, w=W)
attn1, attn2 = torch.split(weight, split_size_or_sections=[self.smk_size**2, self.kernel_size**2], dim=-1)
rpb1_idx = self.generate_idx(self.smk_size)
rpb2_idx = self.generate_idx(self.kernel_size)
attn1 = self.apply_rpb(attn1, self.rpb1, H, W, self.smk_size, *rpb1_idx)
attn2 = self.apply_rpb(attn2, self.rpb2, H, W, self.kernel_size, *rpb2_idx)
attn1 = torch.softmax(attn1, dim=-1)
attn2 = torch.softmax(attn2, dim=-1)
value = rearrange(x, 'b (m g c) h w -> m b g h w c', m=2, g=self.num_heads)
x1 = na2d_av(attn1, value[0], kernel_size=self.smk_size)
x2 = na2d_av(attn2, value[1], kernel_size=self.kernel_size)
x = torch.cat([x1, x2], dim=1)
x = rearrange(x, 'b g h w c -> b (g c) h w', h=H, w=W)
if is_pad:
x = F.adaptive_avg_pool2d(x, input_resoltion)
x = self.dyconv_proj(x)
x = x + lepe
x = self.se_layer(x)
x = gate * x
x = self.proj(x)
if self.res_scale:
x = self.ls1(identity) + self.drop_path(x)
else:
x = identity + self.drop_path(self.ls1(x))
x = self.dwconv2(x)
if self.res_scale:
x = self.ls2(x) + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
if self.is_last:
return (x, None)
else:
l_x, h_x = torch.split(x, split_size_or_sections=[C, C_h], dim=1)
return (l_x, h_x)
def forward(self, x, h_x, h_r):
if self.use_checkpoint and x.requires_grad:
x = checkpoint(self._forward_inner, x, h_x, h_r, use_reentrant=False)
else:
x = self._forward_inner(x, h_x, h_r)
return x
class OverLoCK(nn.Module):
'''
An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels
https://arxiv.org/abs/2502.20087
'''
def __init__(self,
depth=[2, 2, 2, 2],
sub_depth=[4, 2],
in_chans=3,
embed_dim=[96, 192, 384, 768],
kernel_size=[7, 7, 7, 7],
mlp_ratio=[4, 4, 4, 4],
sub_mlp_ratio=[4, 4],
sub_num_heads=[4, 8],
ls_init_value=[None, None, 1, 1],
res_scale=True,
smk_size=5,
deploy=False,
use_gemm=True,
use_ds=True,
drop_rate=0,
drop_path_rate=0,
norm_layer=LayerNorm2d,
projection=1024,
num_classes=1000,
use_checkpoint=[0, 0, 0, 0],
):
super().__init__()
fusion_dim = embed_dim[-1] + embed_dim[-1]//4
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.patch_embed1 = stem(in_chans, embed_dim[0])
self.patch_embed2 = downsample(embed_dim[0], embed_dim[1])
self.patch_embed3 = downsample(embed_dim[1], embed_dim[2])
self.patch_embed4 = downsample(embed_dim[2], embed_dim[3])
self.high_level_proj = nn.Conv2d(embed_dim[-1], embed_dim[-1]//4, kernel_size=1)
self.patch_embedx = CTXDownsample(embed_dim[2], embed_dim[3])
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth) + sum(sub_depth))]
self.blocks1 = nn.ModuleList()
self.blocks2 = nn.ModuleList()
self.blocks3 = nn.ModuleList()
self.blocks4 = nn.ModuleList()
self.sub_blocks3 = nn.ModuleList()
self.sub_blocks4 = nn.ModuleList()
for i in range(depth[0]):
self.blocks1.append(
RepConvBlock(
dim=embed_dim[0],
kernel_size=kernel_size[0],
mlp_ratio=mlp_ratio[0],
ls_init_value=ls_init_value[0],
res_scale=res_scale,
drop_path=dpr[i],
norm_layer=norm_layer,
use_gemm=use_gemm,
deploy=deploy,
use_checkpoint=(i 0 else nn.Identity()
)
# Main Cls Head
self.head = nn.Sequential(
nn.Conv2d(fusion_dim, projection, kernel_size=1, bias=False),
nn.BatchNorm2d(projection),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(projection, num_classes, kernel_size=1) if num_classes > 0 else nn.Identity()
)
self.apply(self._init_weights)
if torch.distributed.is_initialized():
self = nn.SyncBatchNorm.convert_sync_batchnorm(self)
def _init_weights(self, m):
if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d)):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.BatchNorm1d)):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0)
def reparam(self):
for m in self.modules():
if isinstance(m, DilatedReparamBlock):
m.merge_dilated_branches()
def forward_pre_features(self, x):
x = self.patch_embed1(x)
for blk in self.blocks1:
x = blk(x)
x = self.patch_embed2(x)
for blk in self.blocks2:
x = blk(x)
return x
def forward_base_features(self, x):
x = self.patch_embed3(x)
for blk in self.blocks3:
x = blk(x)
ctx = self.patch_embed4(x)
for blk in self.blocks4:
ctx = blk(ctx)
return (x, ctx)
def forward_sub_features(self, x, ctx):
ctx_cls = ctx
ctx_ori = self.high_level_proj(ctx)
ctx_up = F.interpolate(ctx_ori, size=x.shape[2:], mode='bilinear', align_corners=False)
for idx, blk in enumerate(self.sub_blocks3):
if idx == 0:
ctx = ctx_up
x, ctx = blk(x, ctx, ctx_up)
x, ctx = self.patch_embedx(x, ctx)
for idx, blk in enumerate(self.sub_blocks4):
x, ctx = blk(x, ctx, ctx_ori)
return (x, ctx_cls)
def forward_features(self, x):
x = self.forward_pre_features(x)
x, ctx = self.forward_base_features(x)
x, ctx_cls = self.forward_sub_features(x, ctx)
return (x, ctx_cls)
def forward(self, x):
x, ctx = self.forward_features(x)
x = self.head(x).flatten(1)
if hasattr(self, 'aux_head') and self.training:
ctx = self.aux_head(ctx).flatten(1)
return dict(main=x, aux=ctx)
return x
def _cfg(url=None, **kwargs):
return {
'url': url,
'num_classes': 1000,
'input_size': (3, 224, 224),
'crop_pct': 0.9,
'interpolation': 'bicubic', # 'bilinear' or 'bicubic'
'mean': timm.data.IMAGENET_DEFAULT_MEAN,
'std': timm.data.IMAGENET_DEFAULT_STD,
'classifier': 'classifier',
**kwargs,
}
@register_model
def overlock_xt(pretrained=False, pretrained_cfg=None, **kwargs):
model = OverLoCK(
depth=[2, 2, 3, 2],
sub_depth=[6, 2],
embed_dim=[56, 112, 256, 336],
kernel_size=[17, 15, 13, 7],
mlp_ratio=[4, 4, 4, 4],
sub_num_heads=[4, 6],
sub_mlp_ratio=[3, 3],
**kwargs
)
model.default_cfg = _cfg(crop_pct=0.925)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_xt_in1k_224.pth'
load_checkpoint(model, pretrained)
return model
@register_model
def overlock_t(pretrained=False, pretrained_cfg=None, **kwargs):
model = OverLoCK(
depth=[4, 4, 6, 2],
sub_depth=[12, 2],
embed_dim=[64, 128, 256, 512],
kernel_size=[17, 15, 13, 7],
mlp_ratio=[4, 4, 4, 4],
sub_num_heads=[4, 8],
sub_mlp_ratio=[3, 3],
**kwargs
)
model.default_cfg = _cfg(crop_pct=0.95)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224.pth'
load_checkpoint(model, pretrained)
return model
@register_model
def overlock_s(pretrained=False, pretrained_cfg=None, **kwargs):
model = OverLoCK(
depth=[6, 6, 8, 3],
sub_depth=[16, 3],
embed_dim=[64, 128, 320, 512],
kernel_size=[17, 15, 13, 7],
mlp_ratio=[4, 4, 4, 4],
sub_num_heads=[8, 16],
sub_mlp_ratio=[3, 3],
**kwargs
)
model.default_cfg = _cfg(crop_pct=0.95)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224.pth'
load_checkpoint(model, pretrained)
return model
@register_model
def overlock_b(pretrained=None, pretrained_cfg=None, **kwargs):
model = OverLoCK(
depth=[8, 8, 10, 4],
sub_depth=[20, 4],
embed_dim=[80, 160, 384, 576],
kernel_size=[17, 15, 13, 7],
mlp_ratio=[4, 4, 4, 4],
sub_num_heads=[6, 9],
sub_mlp_ratio=[3, 3],
**kwargs
)
model.default_cfg = _cfg(crop_pct=0.975)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224.pth'
load_checkpoint(model, pretrained)
return model
'''
Reparameterized versions of OverLoCK models are given,
which offer improved efficiency in terms of inference speed and memory consumption.
Note: these variants may come at the cost of lower accuracy *during fine-tuning*,
when compared to their original counterparts.
More details about model reparameterization can be found at:
https://arxiv.org/abs/2311.15599
'''
@register_model
def overlock_xt_reparam(pretrained=False, pretrained_cfg=None, **kwargs):
model = overlock_xt(deploy=True)
model.default_cfg = _cfg(crop_pct=0.925)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_xt_in1k_224_reparam.pth'
load_checkpoint(model, pretrained)
return model
@register_model
def overlock_t_reparam(pretrained=False, pretrained_cfg=None, **kwargs):
model = overlock_t(deploy=True)
model.default_cfg = _cfg(crop_pct=0.95)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224_reparam.pth'
load_checkpoint(model, pretrained)
return model
@register_model
def overlock_s_reparam(pretrained=False, pretrained_cfg=None, **kwargs):
model = overlock_s(deploy=True)
model.default_cfg = _cfg(crop_pct=0.95)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224_reparam.pth'
return model
@register_model
def overlock_b_reparam(pretrained=False, pretrained_cfg=None, **kwargs):
model = overlock_b(deploy=True)
model.default_cfg = _cfg(crop_pct=0.975)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224_reparam.pth'
return model
if __name__ == '__main__':
device = torch.device('cuda')
model = overlock_xt(pretrained=True).to(device) # load pretrained weights
model.eval()
x = torch.randn(1, 3, 224, 224).to(device)
y = model(x)
print(y.shape)
# Reparametrized model, more details can be found at:
# https://github.com/AILab-CVC/UniRepLKNet/tree/main
model = overlock_xt_reparam(pretrained=True).to(device) # load pretrained weights
model.eval()
z = model(x)
print(z.shape)
# Showing difference between original and reparametrized model
print((y-z).mean())
================================================
FILE: scripts/train_b_model.sh
================================================
#!/usr/bin/env bash
python3 -m torch.distributed.launch \
--master_port=$((RANDOM+8888)) \
--nproc_per_node=8 \
train.py \
--data-dir /data/dataset/imagenet/ \
--batch-size 256 \
--model overlock_b \
--lr 1e-3 \
--auto-lr \
--drop-path 0.5 \
--epochs 300 \
--warmup-epochs 5 \
--workers 10 \
--model-ema \
--model-ema-decay 0.9999 \
--output output/overlock_b/ \
--native-amp \
--clip-grad 5
================================================
FILE: scripts/train_s_model.sh
================================================
#!/usr/bin/env bash
python3 -m torch.distributed.launch \
--master_port=$((RANDOM+8888)) \
--nproc_per_node=8 \
train.py \
--data-dir /data/dataset/imagenet/ \
--batch-size 256 \
--model overlock_s \
--lr 1e-3 \
--auto-lr \
--drop-path 0.4 \
--epochs 300 \
--warmup-epochs 5 \
--workers 10 \
--model-ema \
--model-ema-decay 0.9999 \
--output output/overlock_s/ \
--native-amp \
--clip-grad 5
================================================
FILE: scripts/train_t_model.sh
================================================
#!/usr/bin/env bash
python3 -m torch.distributed.launch \
--master_port=$((RANDOM+8888)) \
--nproc_per_node=8 \
train.py \
--data-dir /data/dataset/imagenet/ \
--batch-size 256 \
--model overlock_t \
--lr 1e-3 \
--auto-lr \
--drop-path 0.15 \
--epochs 300 \
--warmup-epochs 5 \
--workers 10 \
--model-ema \
--model-ema-decay 0.9999 \
--output output/overlock_t/ \
--native-amp \
--clip-grad 5
================================================
FILE: scripts/train_xt_model.sh
================================================
#!/usr/bin/env bash
python3 -m torch.distributed.launch \
--master_port=$((RANDOM+8888)) \
--nproc_per_node=8 \
train.py \
--data-dir /data/dataset/imagenet/ \
--batch-size 256 \
--model overlock_xt \
--lr 1e-3 \
--auto-lr \
--drop-path 0.1 \
--epochs 300 \
--warmup-epochs 5 \
--workers 10 \
--model-ema \
--model-ema-decay 0.9999 \
--output output/overlock_xt/ \
--native-amp \
--clip-grad 5
================================================
FILE: segmentation/configs/_base_/datasets/ade20k.py
================================================
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = '/grp01/cs_yzyu/dataset/ADEChallengeData2016/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
# dict(type='AlignResize', keep_ratio=True, size_divisor=32),
dict(type='Resize', keep_ratio=True),
dict(type='ResizeToMultiple', size_divisor=32, interpolation='bicubic'),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=50,
dataset=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline),
),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))
================================================
FILE: segmentation/configs/_base_/default_runtime.py
================================================
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook', by_epoch=False),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True
================================================
FILE: segmentation/configs/_base_/models/fpn_r50.py
================================================
# copied from mmsegmentaion official config
# https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/models/fpn_r50.py
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 1, 1),
strides=(1, 2, 2, 2),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=4),
decode_head=dict(
type='FPNHead',
in_channels=[256, 256, 256, 256],
in_index=[0, 1, 2, 3],
feature_strides=[4, 8, 16, 32],
channels=128,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
================================================
FILE: segmentation/configs/_base_/models/upernet_r50.py
================================================
# copied from mmsegmentation official config
# https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/models/upernet_r50.py
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 1, 1),
strides=(1, 2, 2, 2),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='UPerHead',
in_channels=[256, 512, 1024, 2048],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=512,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
================================================
FILE: segmentation/configs/_base_/models/upernet_transnext.py
================================================
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
decode_head=dict(
type='UPerHead',
in_channels=[96, 192, 384, 768],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=512,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=384,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
================================================
FILE: segmentation/configs/_base_/schedules/schedule_160k.py
================================================
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=160000)
checkpoint_config = dict(by_epoch=False, interval=16000)
evaluation = dict(interval=16000, metric='mIoU')
================================================
FILE: segmentation/configs/_base_/schedules/schedule_20k.py
================================================
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=20000)
checkpoint_config = dict(by_epoch=False, interval=2000)
evaluation = dict(interval=2000, metric='mIoU')
================================================
FILE: segmentation/configs/_base_/schedules/schedule_40k.py
================================================
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=40000)
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(interval=4000, metric='mIoU')
================================================
FILE: segmentation/configs/_base_/schedules/schedule_80k.py
================================================
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=80000)
checkpoint_config = dict(by_epoch=False, interval=8000)
evaluation = dict(interval=8000, metric='mIoU')
================================================
FILE: segmentation/configs/overlock/upernet_overlock_base_ade20k_8xb2.py
================================================
_base_ = [
'../_base_/models/upernet_r50.py',
'../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
model = dict(
pretrained=None,
backbone=dict(
_delete_=True,
type='overlock_b',
pretrained=True,
drop_path_rate=0.5,
),
decode_head=dict(
in_index=[0, 1, 2, 3],
in_channels=[80, 160, 528, 720],
num_classes=150,
),
auxiliary_head=dict(
in_index=2,
in_channels=528,
num_classes=150
),
)
# optimizer
optimizer = dict(_delete_=True, type='AdamW', lr=6e-5, betas=(0.9, 0.999), weight_decay=0.01,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
data = dict(samples_per_gpu=2) # as gpus = 8
checkpoint_config = dict(interval=8000, max_keep_ckpts=1)
evaluation = dict(interval=8000, save_best='mIoU')
# place holder for new verison mmseg compatiability
resume_from = None
device = 'cuda'
# # AMP (faster but may meet nan loss) ->
# optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
# fp16 = dict()
================================================
FILE: segmentation/configs/overlock/upernet_overlock_small_ade20k_8xb2.py
================================================
_base_ = [
'../_base_/models/upernet_r50.py',
'../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
model = dict(
pretrained=None,
backbone=dict(
_delete_=True,
type='overlock_s',
pretrained=True,
drop_path_rate=0.3,
),
decode_head=dict(
in_index=[0, 1, 2, 3],
in_channels=[64, 128, 448, 640],
num_classes=150,
),
auxiliary_head=dict(
in_index=2,
in_channels=448,
num_classes=150
),
)
# optimizer
optimizer = dict(_delete_=True, type='AdamW', lr=6e-5, betas=(0.9, 0.999), weight_decay=0.01,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
data = dict(samples_per_gpu=2) # as gpus = 8
checkpoint_config = dict(interval=8000, max_keep_ckpts=1)
evaluation = dict(interval=8000, save_best='mIoU')
# place holder for new verison mmseg compatiability
resume_from = None
device = 'cuda'
# # AMP (faster but may meet nan loss) ->
# optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
# fp16 = dict()
================================================
FILE: segmentation/configs/overlock/upernet_overlock_tiny_ade20k_8xb2.py
================================================
_base_ = [
'../_base_/models/upernet_r50.py',
'../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
model = dict(
pretrained=None,
backbone=dict(
_delete_=True,
type='overlock_t',
pretrained=True,
drop_path_rate=0.2
),
decode_head=dict(
in_index=[0, 1, 2, 3],
in_channels=[64, 128, 384, 640],
num_classes=150,
),
auxiliary_head=dict(
in_index=2,
in_channels=384,
num_classes=150
),
)
# optimizer
optimizer = dict(_delete_=True, type='AdamW', lr=6e-5, betas=(0.9, 0.999), weight_decay=0.01,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
data = dict(samples_per_gpu=2) # as gpus = 8
checkpoint_config = dict(interval=8000, max_keep_ckpts=1)
evaluation = dict(interval=8000, save_best='mIoU')
# place holder for new verison mmseg compatiability
resume_from = None
device = 'cuda'
# # AMP (faster but may meet nan loss) ->
# optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
# fp16 = dict()
================================================
FILE: segmentation/mmseg_custom/__init__.py
================================================
from .align_resize import AlignResize
================================================
FILE: segmentation/mmseg_custom/align_resize.py
================================================
import mmcv
import numpy as np
from mmseg.datasets.builder import PIPELINES
# from IPython import embed
# from numpy import random
# from mmcv.utils import deprecated_api_warning, is_tuple_of
@PIPELINES.register_module()
class AlignResize(object):
"""Resize images & seg. Align"""
def __init__(self,
img_scale=None,
multiscale_mode='range',
ratio_range=None,
keep_ratio=True,
size_divisor=32,
interpolation=None):
if img_scale is None:
self.img_scale = None
else:
if isinstance(img_scale, list):
self.img_scale = img_scale
else:
self.img_scale = [img_scale]
assert mmcv.is_list_of(self.img_scale, tuple)
if ratio_range is not None:
# mode 1: given img_scale=None and a range of image ratio
# mode 2: given a scale and a range of image ratio
assert self.img_scale is None or len(self.img_scale) == 1
else:
# mode 3 and 4: given multiple scales or a range of scales
assert multiscale_mode in ['value', 'range']
self.multiscale_mode = multiscale_mode
self.ratio_range = ratio_range
self.keep_ratio = keep_ratio
self.size_divisor = size_divisor
self.interpolation = interpolation
@staticmethod
def random_select(img_scales):
"""Randomly select an img_scale from given candidates.
Args:
img_scales (list[tuple]): Images scales for selection.
Returns:
(tuple, int): Returns a tuple ``(img_scale, scale_dix)``,
where ``img_scale`` is the selected image scale and
``scale_idx`` is the selected index in the given candidates.
"""
assert mmcv.is_list_of(img_scales, tuple)
scale_idx = np.random.randint(len(img_scales))
img_scale = img_scales[scale_idx]
return img_scale, scale_idx
@staticmethod
def random_sample(img_scales):
"""Randomly sample an img_scale when ``multiscale_mode=='range'``.
Args:
img_scales (list[tuple]): Images scale range for sampling.
There must be two tuples in img_scales, which specify the lower
and uper bound of image scales.
Returns:
(tuple, None): Returns a tuple ``(img_scale, None)``, where
``img_scale`` is sampled scale and None is just a placeholder
to be consistent with :func:`random_select`.
"""
assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
img_scale_long = [max(s) for s in img_scales]
img_scale_short = [min(s) for s in img_scales]
long_edge = np.random.randint(
min(img_scale_long),
max(img_scale_long) + 1)
short_edge = np.random.randint(
min(img_scale_short),
max(img_scale_short) + 1)
img_scale = (long_edge, short_edge)
return img_scale, None
@staticmethod
def random_sample_ratio(img_scale, ratio_range):
"""Randomly sample an img_scale when ``ratio_range`` is specified.
A ratio will be randomly sampled from the range specified by
``ratio_range``. Then it would be multiplied with ``img_scale`` to
generate sampled scale.
Args:
img_scale (tuple): Images scale base to multiply with ratio.
ratio_range (tuple[float]): The minimum and maximum ratio to scale
the ``img_scale``.
Returns:
(tuple, None): Returns a tuple ``(scale, None)``, where
``scale`` is sampled ratio multiplied with ``img_scale`` and
None is just a placeholder to be consistent with
:func:`random_select`.
"""
assert isinstance(img_scale, tuple) and len(img_scale) == 2
min_ratio, max_ratio = ratio_range
assert min_ratio <= max_ratio
ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
return scale, None
def _random_scale(self, results):
"""Randomly sample an img_scale according to ``ratio_range`` and
``multiscale_mode``.
If ``ratio_range`` is specified, a ratio will be sampled and be
multiplied with ``img_scale``.
If multiple scales are specified by ``img_scale``, a scale will be
sampled according to ``multiscale_mode``.
Otherwise, single scale will be used.
Args:
results (dict): Result dict from :obj:`dataset`.
Returns:
dict: Two new keys 'scale` and 'scale_idx` are added into
``results``, which would be used by subsequent pipelines.
"""
if self.ratio_range is not None:
if self.img_scale is None:
h, w = results['img'].shape[:2]
scale, scale_idx = self.random_sample_ratio((w, h),
self.ratio_range)
else:
scale, scale_idx = self.random_sample_ratio(
self.img_scale[0], self.ratio_range)
elif len(self.img_scale) == 1:
scale, scale_idx = self.img_scale[0], 0
elif self.multiscale_mode == 'range':
scale, scale_idx = self.random_sample(self.img_scale)
elif self.multiscale_mode == 'value':
scale, scale_idx = self.random_select(self.img_scale)
else:
raise NotImplementedError
results['scale'] = scale
results['scale_idx'] = scale_idx
def _align(self, img, size_divisor, interpolation=None):
align_h = int(np.ceil(img.shape[0] / size_divisor)) * size_divisor
align_w = int(np.ceil(img.shape[1] / size_divisor)) * size_divisor
if interpolation == None:
img = mmcv.imresize(img, (align_w, align_h))
else:
img = mmcv.imresize(img, (align_w, align_h), interpolation=interpolation)
return img
def _resize_img(self, results):
"""Resize images with ``results['scale']``."""
if self.keep_ratio:
img, scale_factor = mmcv.imrescale(
results['img'], results['scale'], return_scale=True)
#### align ####
img = self._align(img, self.size_divisor, interpolation='bicubic')
# the w_scale and h_scale has minor difference
# a real fix should be done in the mmcv.imrescale in the future
new_h, new_w = img.shape[:2]
h, w = results['img'].shape[:2]
w_scale = new_w / w
h_scale = new_h / h
else:
img, w_scale, h_scale = mmcv.imresize(
results['img'], results['scale'], return_scale=True)
h, w = img.shape[:2]
assert int(np.ceil(h / self.size_divisor)) * self.size_divisor == h and \
int(np.ceil(w / self.size_divisor)) * self.size_divisor == w, \
"img size not align. h:{} w:{}".format(h,w)
scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
dtype=np.float32)
results['img'] = img
results['img_shape'] = img.shape
results['pad_shape'] = img.shape # in case that there is no padding
results['scale_factor'] = scale_factor
results['keep_ratio'] = self.keep_ratio
def _resize_seg(self, results):
"""Resize semantic segmentation map with ``results['scale']``."""
for key in results.get('seg_fields', []):
if self.keep_ratio:
gt_seg = mmcv.imrescale(
results[key], results['scale'], interpolation='nearest')
gt_seg = self._align(gt_seg, self.size_divisor, interpolation='nearest')
else:
gt_seg = mmcv.imresize(
results[key], results['scale'], interpolation='nearest')
h, w = gt_seg.shape[:2]
assert int(np.ceil(h / self.size_divisor)) * self.size_divisor == h and \
int(np.ceil(w / self.size_divisor)) * self.size_divisor == w, \
"gt_seg size not align. h:{} w:{}".format(h, w)
results[key] = gt_seg
def __call__(self, results):
"""Call function to resize images, bounding boxes, masks, semantic
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
'keep_ratio' keys are added into result dict.
"""
if 'scale' not in results:
self._random_scale(results)
self._resize_img(results)
self._resize_seg(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'(img_scale={self.img_scale}, '
f'multiscale_mode={self.multiscale_mode}, '
f'ratio_range={self.ratio_range}, '
f'keep_ratio={self.keep_ratio})')
return repr_str
================================================
FILE: segmentation/models/__init__.py
================================================
from .overlock import *
================================================
FILE: segmentation/models/overlock.py
================================================
'''
This is an official implementation of OverLoCK model proposed in the paper:
https://arxiv.org/abs/2502.20087
'''
import torch
import timm
import torch.distributed
import torch.nn.functional as F
from torch import nn
from einops import rearrange, einsum
from natten.functional import na2d_av
from torch.utils.checkpoint import checkpoint
from timm.models.layers import DropPath, to_2tuple
from timm.models.registry import register_model
from mmseg.models.builder import MODELS
from mmseg.utils import get_root_logger
from mmcv.runner import load_checkpoint
def get_conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
attempt_use_lk_impl=True):
kernel_size = to_2tuple(kernel_size)
if padding is None:
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
else:
padding = to_2tuple(padding)
need_large_impl = kernel_size[0] == kernel_size[1] and kernel_size[0] > 5 and padding == (kernel_size[0] // 2, kernel_size[1] // 2)
if attempt_use_lk_impl and need_large_impl:
print('---------------- trying to import iGEMM implementation for large-kernel conv')
try:
from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
print('---------------- found iGEMM implementation ')
except:
DepthWiseConv2dImplicitGEMM = None
print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')
if DepthWiseConv2dImplicitGEMM is not None and need_large_impl and in_channels == out_channels \
and out_channels == groups and stride == 1 and dilation == 1:
print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')
return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
def get_bn(dim, use_sync_bn=False):
if use_sync_bn:
return nn.SyncBatchNorm(dim)
else:
return nn.BatchNorm2d(dim)
def fuse_bn(conv, bn):
conv_bias = 0 if conv.bias is None else conv.bias
std = (bn.running_var + bn.eps).sqrt()
return conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1), bn.bias + (conv_bias - bn.running_mean) * bn.weight / std
def convert_dilated_to_nondilated(kernel, dilate_rate):
identity_kernel = torch.ones((1, 1, 1, 1)).to(kernel.device)
if kernel.size(1) == 1:
# This is a DW kernel
dilated = F.conv_transpose2d(kernel, identity_kernel, stride=dilate_rate)
return dilated
else:
# This is a dense or group-wise (but not DW) kernel
slices = []
for i in range(kernel.size(1)):
dilated = F.conv_transpose2d(kernel[:,i:i+1,:,:], identity_kernel, stride=dilate_rate)
slices.append(dilated)
return torch.cat(slices, dim=1)
def merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):
large_k = large_kernel.size(2)
dilated_k = dilated_kernel.size(2)
equivalent_kernel_size = dilated_r * (dilated_k - 1) + 1
equivalent_kernel = convert_dilated_to_nondilated(dilated_kernel, dilated_r)
rows_to_pad = large_k // 2 - equivalent_kernel_size // 2
merged_kernel = large_kernel + F.pad(equivalent_kernel, [rows_to_pad] * 4)
return merged_kernel
def stem(in_chans=3, embed_dim=96):
return nn.Sequential(
nn.Conv2d(in_chans, embed_dim//2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(embed_dim//2),
nn.GELU(),
nn.Conv2d(embed_dim//2, embed_dim//2, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(embed_dim//2),
nn.GELU(),
nn.Conv2d(embed_dim//2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(embed_dim),
nn.GELU(),
nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(embed_dim)
)
def downsample(in_dim, out_dim):
return nn.Sequential(
nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(out_dim),
)
class SEModule(nn.Module):
def __init__(self, dim, red=8, inner_act=nn.GELU, out_act=nn.Sigmoid):
super().__init__()
inner_dim = max(16, dim // red)
self.proj = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(dim, inner_dim, kernel_size=1),
inner_act(),
nn.Conv2d(inner_dim, dim, kernel_size=1),
out_act(),
)
def forward(self, x):
x = x * self.proj(x)
return x
class LayerScale(nn.Module):
def __init__(self, dim, init_value=1e-5):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim, 1, 1, 1)*init_value,
requires_grad=True)
self.bias = nn.Parameter(torch.zeros(dim), requires_grad=True)
def forward(self, x):
x = F.conv2d(x, weight=self.weight, bias=self.bias, groups=x.shape[1])
return x
class LayerNorm2d(nn.LayerNorm):
def __init__(self, dim):
super().__init__(normalized_shape=dim, eps=1e-6)
def forward(self, x):
x = rearrange(x, 'b c h w -> b h w c')
x = super().forward(x)
x = rearrange(x, 'b h w c -> b c h w')
return x.contiguous()
class GRN(nn.Module):
""" GRN (Global Response Normalization) layer
Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808)
This implementation is more efficient than the original (https://github.com/facebookresearch/ConvNeXt-V2)
We assume the inputs to this layer are (N, C, H, W)
"""
def __init__(self, dim, use_bias=True):
super().__init__()
self.use_bias = use_bias
self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
if self.use_bias:
self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(-1, -2), keepdim=True)
Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
if self.use_bias:
return (self.gamma * Nx + 1) * x + self.beta
else:
return (self.gamma * Nx + 1) * x
class DilatedReparamBlock(nn.Module):
"""
Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet)
We assume the inputs to this block are (N, C, H, W)
"""
def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True):
super().__init__()
self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,
padding=kernel_size//2, dilation=1, groups=channels, bias=deploy,
attempt_use_lk_impl=attempt_use_lk_impl)
self.attempt_use_lk_impl = attempt_use_lk_impl
# Default settings. We did not tune them carefully. Different settings may work better.
if kernel_size == 19:
self.kernel_sizes = [5, 7, 9, 9, 3, 3, 3]
self.dilates = [1, 1, 1, 2, 4, 5, 7]
elif kernel_size == 17:
self.kernel_sizes = [5, 7, 9, 3, 3, 3]
self.dilates = [1, 1, 2, 4, 5, 7]
elif kernel_size == 15:
self.kernel_sizes = [5, 7, 7, 3, 3, 3]
self.dilates = [1, 1, 2, 3, 5, 7]
elif kernel_size == 13:
self.kernel_sizes = [5, 7, 7, 3, 3, 3]
self.dilates = [1, 1, 2, 3, 4, 5]
elif kernel_size == 11:
self.kernel_sizes = [5, 7, 5, 3, 3, 3]
self.dilates = [1, 1, 2, 3, 4, 5]
elif kernel_size == 9:
self.kernel_sizes = [5, 7, 5, 3, 3]
self.dilates = [1, 1, 2, 3, 4]
elif kernel_size == 7:
self.kernel_sizes = [5, 3, 3, 3]
self.dilates = [1, 1, 2, 3]
elif kernel_size == 5:
self.kernel_sizes = [3, 3]
self.dilates = [1, 2]
else:
raise ValueError('Dilated Reparam Block requires kernel_size >= 5')
if not deploy:
self.origin_bn = get_bn(channels, use_sync_bn)
for k, r in zip(self.kernel_sizes, self.dilates):
self.__setattr__('dil_conv_k{}_{}'.format(k, r),
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,
padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,
bias=False))
self.__setattr__('dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn))
def forward(self, x):
if not hasattr(self, 'origin_bn'): # deploy mode
return self.lk_origin(x)
out = self.origin_bn(self.lk_origin(x))
for k, r in zip(self.kernel_sizes, self.dilates):
conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))
bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))
out = out + bn(conv(x))
return out
def merge_dilated_branches(self):
if hasattr(self, 'origin_bn'):
origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn)
for k, r in zip(self.kernel_sizes, self.dilates):
conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))
bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))
branch_k, branch_b = fuse_bn(conv, bn)
origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r)
origin_b += branch_b
merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1,
padding=origin_k.size(2)//2, dilation=1, groups=origin_k.size(0), bias=True,
attempt_use_lk_impl=self.attempt_use_lk_impl)
merged_conv.weight.data = origin_k
merged_conv.bias.data = origin_b
self.lk_origin = merged_conv
self.__delattr__('origin_bn')
for k, r in zip(self.kernel_sizes, self.dilates):
self.__delattr__('dil_conv_k{}_{}'.format(k, r))
self.__delattr__('dil_bn_k{}_{}'.format(k, r))
class CTXDownsample(nn.Module):
def __init__(self, dim, h_dim):
super().__init__()
self.x_proj = nn.Sequential(
nn.Conv2d(dim, h_dim, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(h_dim)
)
self.h_proj = nn.Sequential(
nn.Conv2d(h_dim//4, h_dim//4, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(h_dim//4)
)
def forward(self, x, ctx):
x = self.x_proj(x)
ctx = self.h_proj(ctx)
return (x, ctx)
class ResDWConv(nn.Conv2d):
'''
Depthwise convolution with residual connection
'''
def __init__(self, dim, kernel_size=3):
super().__init__(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=dim)
def forward(self, x):
x = x + super().forward(x)
return x
class RepConvBlock(nn.Module):
def __init__(self,
dim=64,
kernel_size=7,
mlp_ratio=4,
ls_init_value=None,
res_scale=False,
drop_path=0,
norm_layer=LayerNorm2d,
use_gemm=False,
deploy=False,
use_checkpoint=False):
super().__init__()
self.res_scale = res_scale
self.use_checkpoint = use_checkpoint
mlp_dim = int(dim*mlp_ratio)
self.dwconv = ResDWConv(dim, kernel_size=3)
self.proj = nn.Sequential(
norm_layer(dim),
DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),
nn.BatchNorm2d(dim),
SEModule(dim),
nn.Conv2d(dim, mlp_dim, kernel_size=1),
nn.GELU(),
ResDWConv(mlp_dim, kernel_size=3),
GRN(mlp_dim),
nn.Conv2d(mlp_dim, dim, kernel_size=1),
DropPath(drop_path) if drop_path > 0 else nn.Identity(),
)
self.ls = LayerScale(dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()
def forward_features(self, x):
x = self.dwconv(x)
if self.res_scale:
x = self.ls(x) + self.proj(x)
else:
drop_path = self.proj[-1]
x = x + drop_path(self.ls(self.proj[:-1](x)))
return x
def forward(self, x):
if self.use_checkpoint and x.requires_grad:
x = checkpoint(self.forward_features, x, use_reentrant=False)
else:
x = self.forward_features(x)
return x
class DynamicConvBlock(nn.Module):
def __init__(self,
dim=64,
ctx_dim=32,
kernel_size=7,
smk_size=5,
num_heads=2,
mlp_ratio=4,
ls_init_value=None,
res_scale=False,
drop_path=0,
norm_layer=LayerNorm2d,
is_first=False,
is_last=False,
use_gemm=False,
deploy=False,
use_checkpoint=False,
**kwargs):
super().__init__()
ctx_dim = ctx_dim // 4
out_dim = dim + ctx_dim
mlp_dim = int(dim*mlp_ratio)
self.kernel_size = kernel_size
self.res_scale = res_scale
self.use_gemm = use_gemm
self.smk_size = smk_size
self.num_heads = num_heads * 2
head_dim = dim // self.num_heads
self.scale = head_dim ** -0.5
self.is_first = is_first
self.is_last = is_last
self.use_checkpoint = use_checkpoint
if not is_first:
self.x_scale = LayerScale(ctx_dim, init_value=1)
self.h_scale = LayerScale(ctx_dim, init_value=1)
self.dwconv1 = ResDWConv(out_dim, kernel_size=3)
self.norm1 = norm_layer(out_dim)
self.fusion = nn.Sequential(
nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, groups=out_dim),
nn.BatchNorm2d(out_dim),
nn.GELU(),
nn.Conv2d(out_dim, dim, kernel_size=1),
GRN(dim),
)
self.weight_query = nn.Sequential(
nn.Conv2d(dim, dim//2, kernel_size=1, bias=False),
nn.BatchNorm2d(dim//2),
)
self.weight_key = nn.Sequential(
nn.AdaptiveAvgPool2d(7),
nn.Conv2d(ctx_dim, dim//2, kernel_size=1, bias=False),
nn.BatchNorm2d(dim//2),
)
self.weight_proj = nn.Conv2d(49, kernel_size**2 + smk_size**2, kernel_size=1)
self.dyconv_proj = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=1, bias=False),
nn.BatchNorm2d(dim),
)
self.lepe = nn.Sequential(
DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),
nn.BatchNorm2d(dim),
)
self.se_layer = SEModule(dim)
self.gate = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=1, bias=False),
nn.BatchNorm2d(dim),
nn.SiLU(),
)
self.proj = nn.Sequential(
nn.BatchNorm2d(dim),
nn.Conv2d(dim, out_dim, kernel_size=1),
)
self.dwconv2 = ResDWConv(out_dim, kernel_size=3)
self.norm2 = norm_layer(out_dim)
self.mlp = nn.Sequential(
nn.Conv2d(out_dim, mlp_dim, kernel_size=1),
nn.GELU(),
ResDWConv(mlp_dim, kernel_size=3),
GRN(mlp_dim),
nn.Conv2d(mlp_dim, out_dim, kernel_size=1),
)
self.ls1 = LayerScale(out_dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()
self.ls2 = LayerScale(out_dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
self.get_rpb()
def get_rpb(self):
self.rpb_size1 = 2 * self.smk_size - 1
self.rpb1 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size1, self.rpb_size1))
self.rpb_size2 = 2 * self.kernel_size - 1
self.rpb2 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size2, self.rpb_size2))
nn.init.zeros_(self.rpb1)
nn.init.zeros_(self.rpb2)
@torch.no_grad()
def generate_idx(self, kernel_size):
rpb_size = 2 * kernel_size - 1
idx_h = torch.arange(0, kernel_size)
idx_w = torch.arange(0, kernel_size)
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).view(-1)
return (idx_h, idx_w, idx_k)
def apply_rpb(self, attn, rpb, height, width, kernel_size, idx_h, idx_w, idx_k):
"""
RPB implementation directly borrowed from https://tinyurl.com/mrbub4t3
"""
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size//2] = height - (kernel_size-1)
num_repeat_w[kernel_size//2] = width - (kernel_size-1)
bias_hw = (idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2*kernel_size-1)) + idx_w.repeat_interleave(num_repeat_w)
bias_idx = bias_hw.unsqueeze(-1) + idx_k
bias_idx = bias_idx.reshape(-1, int(kernel_size**2))
bias_idx = torch.flip(bias_idx, [0])
rpb = torch.flatten(rpb, 1, 2)[:, bias_idx]
rpb = rpb.reshape(1, int(self.num_heads), int(height), int(width), int(kernel_size**2))
return attn + rpb
def _forward_inner(self, x, h_x, h_r):
input_resoltion = x.shape[2:]
B, C, H, W = x.shape
B, C_h, H_h, W_h = h_x.shape
if not self.is_first:
h_x = self.x_scale(h_x) + self.h_scale(h_r)
x_f = torch.cat([x, h_x], dim=1)
x_f = self.dwconv1(x_f)
identity = x_f
x_f = self.norm1(x_f)
x = self.fusion(x_f)
gate = self.gate(x)
lepe = self.lepe(x)
is_pad = False
if min(H, W) < self.kernel_size:
is_pad = True
if H < W:
size = (self.kernel_size, int(self.kernel_size / H * W))
else:
size = (int(self.kernel_size / W * H), self.kernel_size)
x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
x_f = F.interpolate(x_f, size=size, mode='bilinear', align_corners=False)
H, W = size
query, key = torch.split(x_f, split_size_or_sections=[C, C_h], dim=1)
query = self.weight_query(query) * self.scale
key = self.weight_key(key)
query = rearrange(query, 'b (g c) h w -> b g c (h w)', g=self.num_heads)
key = rearrange(key, 'b (g c) h w -> b g c (h w)', g=self.num_heads)
weight = einsum(query, key, 'b g c n, b g c l -> b g n l')
weight = rearrange(weight, 'b g n l -> b l g n').contiguous()
weight = self.weight_proj(weight)
weight = rearrange(weight, 'b l g (h w) -> b g h w l', h=H, w=W)
attn1, attn2 = torch.split(weight, split_size_or_sections=[self.smk_size**2, self.kernel_size**2], dim=-1)
rpb1_idx = self.generate_idx(self.smk_size)
rpb2_idx = self.generate_idx(self.kernel_size)
attn1 = self.apply_rpb(attn1, self.rpb1, H, W, self.smk_size, *rpb1_idx)
attn2 = self.apply_rpb(attn2, self.rpb2, H, W, self.kernel_size, *rpb2_idx)
attn1 = torch.softmax(attn1, dim=-1)
attn2 = torch.softmax(attn2, dim=-1)
value = rearrange(x, 'b (m g c) h w -> m b g h w c', m=2, g=self.num_heads)
x1 = na2d_av(attn1, value[0], kernel_size=self.smk_size)
x2 = na2d_av(attn2, value[1], kernel_size=self.kernel_size)
x = torch.cat([x1, x2], dim=1)
x = rearrange(x, 'b g h w c -> b (g c) h w', h=H, w=W)
if is_pad:
x = F.adaptive_avg_pool2d(x, input_resoltion)
x = self.dyconv_proj(x)
x = x + lepe
x = self.se_layer(x)
x = gate * x
x = self.proj(x)
if self.res_scale:
x = self.ls1(identity) + self.drop_path(x)
else:
x = identity + self.drop_path(self.ls1(x))
x = self.dwconv2(x)
if self.res_scale:
x = self.ls2(x) + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
if self.is_last:
return (x, None)
else:
l_x, h_x = torch.split(x, split_size_or_sections=[C, C_h], dim=1)
return (l_x, h_x)
def forward(self, x, h_x, h_r):
if self.use_checkpoint and x.requires_grad:
x = checkpoint(self._forward_inner, x, h_x, h_r, use_reentrant=False)
else:
x = self._forward_inner(x, h_x, h_r)
return x
class OverLoCK(nn.Module):
'''
An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels
https://arxiv.org/abs/2502.20087
'''
def __init__(self,
depth=[2, 2, 2, 2],
sub_depth=[4, 2],
in_chans=3,
embed_dim=[96, 192, 384, 768],
kernel_size=[7, 7, 7, 7],
mlp_ratio=[4, 4, 4, 4],
sub_mlp_ratio=[4, 4],
sub_num_heads=[4, 8],
ls_init_value=[None, None, 1, 1],
res_scale=True,
smk_size=5,
deploy=False,
use_gemm=True,
use_ds=True,
drop_rate=0,
drop_path_rate=0,
norm_layer=LayerNorm2d,
projection=1024,
num_classes=1000,
use_checkpoint=[0, 0, 0, 0],
):
super().__init__()
fusion_dim = embed_dim[-1] + embed_dim[-1]//4
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.patch_embed1 = stem(in_chans, embed_dim[0])
self.patch_embed2 = downsample(embed_dim[0], embed_dim[1])
self.patch_embed3 = downsample(embed_dim[1], embed_dim[2])
self.patch_embed4 = downsample(embed_dim[2], embed_dim[3])
self.high_level_proj = nn.Conv2d(embed_dim[-1], embed_dim[-1]//4, kernel_size=1)
self.patch_embedx = CTXDownsample(embed_dim[2], embed_dim[3])
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth) + sum(sub_depth))]
self.blocks1 = nn.ModuleList()
self.blocks2 = nn.ModuleList()
self.blocks3 = nn.ModuleList()
self.blocks4 = nn.ModuleList()
self.sub_blocks3 = nn.ModuleList()
self.sub_blocks4 = nn.ModuleList()
for i in range(depth[0]):
self.blocks1.append(
RepConvBlock(
dim=embed_dim[0],
kernel_size=kernel_size[0],
mlp_ratio=mlp_ratio[0],
ls_init_value=ls_init_value[0],
res_scale=res_scale,
drop_path=dpr[i],
norm_layer=norm_layer,
use_gemm=use_gemm,
deploy=deploy,
use_checkpoint=(i 0 else nn.Identity()
)
# Main Cls Head
self.head = nn.Sequential(
nn.Conv2d(fusion_dim, projection, kernel_size=1, bias=False),
nn.BatchNorm2d(projection),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(projection, num_classes, kernel_size=1) if num_classes > 0 else nn.Identity()
)
self.extra_norm = nn.ModuleList()
for idx in range(4):
dim = embed_dim[idx]
if idx >= 2:
dim = dim + embed_dim[-1]//4
self.extra_norm.append(norm_layer(dim))
del self.aux_head
del self.head
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d)):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.BatchNorm1d)):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0)
def _convert_sync_batchnorm(self):
if torch.distributed.is_initialized():
self = nn.SyncBatchNorm.convert_sync_batchnorm(self)
def forward_pre_features(self, x):
outs = []
x = self.patch_embed1(x)
for blk in self.blocks1:
x = blk(x)
outs.append(self.extra_norm[0](x))
x = self.patch_embed2(x)
for blk in self.blocks2:
x = blk(x)
outs.append(self.extra_norm[1](x))
return outs
def forward_base_features(self, x):
x = self.patch_embed3(x)
for blk in self.blocks3:
x = blk(x)
ctx = self.patch_embed4(x)
for blk in self.blocks4:
ctx = blk(ctx)
return (x, ctx)
def forward_sub_features(self, x, ctx):
outs = []
# ctx_cls = ctx
ctx_ori = self.high_level_proj(ctx)
ctx_up = F.interpolate(ctx_ori, size=x.shape[2:], mode='bilinear', align_corners=False)
for idx, blk in enumerate(self.sub_blocks3):
if idx == 0:
ctx = ctx_up
x, ctx = blk(x, ctx, ctx_up)
outs.append(self.extra_norm[2](torch.cat([x, ctx], dim=1)))
x, ctx = self.patch_embedx(x, ctx)
for idx, blk in enumerate(self.sub_blocks4):
x, ctx = blk(x, ctx, ctx_ori)
outs.append(self.extra_norm[3](x))
return outs
def forward_features(self, x):
x0, x1 = self.forward_pre_features(x)
x, ctx = self.forward_base_features(x1)
x2, x3 = self.forward_sub_features(x, ctx)
return (x0, x1, x2, x3)
def forward(self, x):
x = self.forward_features(x)
return x
@MODELS.register_module()
def overlock_xt(pretrained=False, pretrained_cfg=None, **kwargs):
model = OverLoCK(
depth=[2, 2, 3, 2],
sub_depth=[6, 2],
embed_dim=[56, 112, 256, 336],
kernel_size=[17, 15, 13, 7],
mlp_ratio=[4, 4, 4, 4],
sub_num_heads=[4, 6],
sub_mlp_ratio=[3, 3],
**kwargs
)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_xt_in1k_224.pth'
logger = get_root_logger()
load_checkpoint(model, pretrained, logger=logger)
model._convert_sync_batchnorm()
return model
@MODELS.register_module()
def overlock_t(pretrained=False, pretrained_cfg=None, **kwargs):
model = OverLoCK(
depth=[4, 4, 6, 2],
sub_depth=[12, 2],
embed_dim=[64, 128, 256, 512],
kernel_size=[17, 15, 13, 7],
mlp_ratio=[4, 4, 4, 4],
sub_num_heads=[4, 8],
sub_mlp_ratio=[3, 3],
**kwargs
)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224.pth'
logger = get_root_logger()
load_checkpoint(model, pretrained, logger=logger)
model._convert_sync_batchnorm()
return model
@MODELS.register_module()
def overlock_s(pretrained=False, pretrained_cfg=None, **kwargs):
model = OverLoCK(
depth=[6, 6, 8, 3],
sub_depth=[16, 3],
embed_dim=[64, 128, 320, 512],
kernel_size=[17, 15, 13, 7],
mlp_ratio=[4, 4, 4, 4],
sub_num_heads=[8, 16],
sub_mlp_ratio=[3, 3],
**kwargs
)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224.pth'
logger = get_root_logger()
load_checkpoint(model, pretrained, logger=logger)
model._convert_sync_batchnorm()
return model
@MODELS.register_module()
def overlock_b(pretrained=None, pretrained_cfg=None, **kwargs):
model = OverLoCK(
depth=[8, 8, 10, 4],
sub_depth=[20, 4],
embed_dim=[80, 160, 384, 576],
kernel_size=[17, 15, 13, 7],
mlp_ratio=[4, 4, 4, 4],
sub_num_heads=[6, 9],
sub_mlp_ratio=[3, 3],
**kwargs
)
if pretrained:
pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224.pth'
logger = get_root_logger()
load_checkpoint(model, pretrained, logger=logger)
model._convert_sync_batchnorm()
return model
================================================
FILE: segmentation/readme.md
================================================
# Applying OverLoCK to Semantic Segmentation
## 1. Requirements
```
pip install mmcv-full==1.7.2 --no-cache-dir
pip install mmsegmentation==0.30.0 --no-cache-dir
```
💡 To enable torch>=2.1.0 to support mmcv 1.7.2, you need to make the following changes:
> 1️⃣ https://goo.su/XhU5vWr
> 2️⃣ https://goo.su/ogm4yO
## 2. Data Preparation
Prepare ADE20K dataset according to the [guidelines](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md).
## 3. Main Results on ADE20K using UperNet framework
| Backbone | Pretrain | Schedule | mIoU | Config | Download |
|:-------------:|:-----------:|:--------:|--------|:-------------------------------------------------------:|:----------:|
| OverLoCK-T | [ImageNet-1K](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224.pth)| 160K | 50.3 | [config](configs/overlock/upernet_overlock_tiny_ade20k_8xb2.py) |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/upernet_overlock_tiny_ade20k.pth) |
| OverLoCK-S | [ImageNet-1K](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224.pth)| 160K |51.3 | [config](configs/overlock/upernet_overlock_small_ade20k_8xb2.py) |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/upernet_overlock_small_ade20k.pth) |
| OverLoCK-B | [ImageNet-1K](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224.pth) | 160K |51.7 | [config](configs/overlock/upernet_overlock_base_ade20k_8xb2.py) |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/upernet_overlock_base_ade20k.pth) |
## 4. Train
To train ``OverLoCK-T + UperNet`` model on ADE20K dataset with 8 gpus (single node), run:
```
bash scripts/dist_train.sh configs/overlock/upernet_overlock_tiny_ade20k_8xb2.py 8
```
## 5. Validation
To evaluate ``OverLoCK-T + UperNet`` model on ADE20K dataset, run:
```
bash scripts/dist_test.sh configs/overlock/upernet_overlock_tiny_ade20k_8xb2.py path-to-checkpoint 8 --eval mIoU
```
## Citation
If you find this project useful for your research, please consider citing:
```
@inproceedings{lou2025overlock,
title={OverLoCK: An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels},
author={Lou, Meng and Yu, Yizhou},
booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
pages={128--138},
year={2025}
}
```
================================================
FILE: segmentation/scripts/dist_test.sh
================================================
#!/usr/bin/env bash
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=$((RANDOM+10000))
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
torchrun --nproc_per_node=$GPUS --master_port=$PORT test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
================================================
FILE: segmentation/scripts/dist_train.sh
================================================
#!/usr/bin/env bash
CONFIG=$1
GPUS=$2
PORT=$((RANDOM+10000))
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
torchrun --nproc_per_node=$GPUS --master_port=$PORT train.py $CONFIG --launcher pytorch ${@:3}
================================================
FILE: segmentation/test.py
================================================
import warnings
import argparse
import os
import mmcv
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)
from mmcv.utils import DictAction
from mmseg.apis import multi_gpu_test, single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor
import models
import mmseg_custom
def parse_args():
parser = argparse.ArgumentParser(
description='mmseg test (and eval) a model')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--aug-test', action='store_true', help='Use Flip and Multi scale aug')
parser.add_argument('--out', help='output result file in pickle format')
parser.add_argument(
'--format-only',
action='store_true',
help='Format the output results without perform evaluation. It is'
'useful when you want to format the result to a specific format and '
'submit it to the test server')
parser.add_argument(
'--eval',
type=str,
nargs='+',
help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
' for generic datasets, and "cityscapes" for Cityscapes')
parser.add_argument('--show', action='store_true', help='show results')
parser.add_argument(
'--show-dir', help='directory where painted images will be saved')
parser.add_argument(
'--gpu-collect',
action='store_true',
help='whether to use gpu to collect results.')
parser.add_argument(
'--tmpdir',
help='tmp directory used for collecting results from multiple '
'workers, available when gpu_collect is not specified')
parser.add_argument(
'--options', nargs='+', action=DictAction, help='custom options')
parser.add_argument(
'--eval-options',
nargs='+',
action=DictAction,
help='custom options for evaluation')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='Opacity of painted segmentation map. In (0, 1] range.')
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def main():
args = parse_args()
assert args.out or args.eval or args.format_only or args.show \
or args.show_dir, \
('Please specify at least one operation (save/eval/format/show the '
'results / save the results) with the argument "--out", "--eval"'
', "--format-only", "--show" or "--show-dir"')
if args.eval and args.format_only:
raise ValueError('--eval and --format_only cannot be both specified')
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
raise ValueError('The output file must be a pkl file.')
cfg = mmcv.Config.fromfile(args.config)
if args.options is not None:
cfg.merge_from_dict(args.options)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
if args.aug_test:
# hard code index
cfg.data.test.pipeline[1].img_ratios = [
0.5, 0.75, 1.0, 1.25, 1.5, 1.75
]
cfg.data.test.pipeline[1].flip = True
cfg.model.pretrained = None
cfg.data.test.test_mode = True
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# build the dataloader
# TODO: support multiple images per gpu (only minor changes are needed)
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=1,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
# build the model and load checkpoint
cfg.model.train_cfg = None
model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
try:
model.CLASSES = checkpoint['meta']['CLASSES']
model.PALETTE = checkpoint['meta']['PALETTE']
except:
warnings.warn("'CLASSES' and 'PALETTE' are not in the checkpoint.")
print(f'Config:\n{cfg.pretty_text}')
efficient_test = False
if args.eval_options is not None:
efficient_test = args.eval_options.get('efficient_test', False)
if not distributed:
model = MMDataParallel(model, device_ids=[0])
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
efficient_test, args.opacity)
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
outputs = multi_gpu_test(model, data_loader, args.tmpdir, args.gpu_collect, efficient_test)
rank, _ = get_dist_info()
if rank == 0:
if args.out:
print(f'\nwriting results to {args.out}')
mmcv.dump(outputs, args.out)
kwargs = {} if args.eval_options is None else args.eval_options
if args.format_only:
dataset.format_results(outputs, **kwargs)
if args.eval:
dataset.evaluate(outputs, args.eval, **kwargs)
if __name__ == '__main__':
main()
================================================
FILE: segmentation/train.py
================================================
import argparse
import warnings
import copy
import os
import os.path as osp
import time
import mmcv
import torch
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import Config, DictAction, get_git_hash
from mmseg import __version__
from mmseg.apis import set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_root_logger
import models
import mmseg_custom
def parse_args():
parser = argparse.ArgumentParser(description='Train a segmentor')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--load-from', help='the checkpoint file to load weights from')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
parser.add_argument(
'--no-validate',
action='store_true',
help='whether not to evaluate the checkpoint during training')
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
type=int,
help='number of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='ids of gpus to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--options', nargs='+', action=DictAction, help='custom options')
parser.add_argument(
'--drop-path',
default=-1,
type=float,
help='drop-path-rate of the backbone network')
parser.add_argument(
'--freeze-bn',
action='store_true',
default=False,
help='freeze the BN layer of the backbone model during training')
parser.add_argument(
'--amp',
action='store_true',
default=False,
help='mixed precision training')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.options is not None:
cfg.merge_from_dict(args.options)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
if args.load_from is not None:
cfg.load_from = args.load_from
if args.resume_from is not None:
cfg.resume_from = args.resume_from
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids
else:
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
if args.drop_path >= 0:
try:
cfg.model.backbone.drop_path_rate = args.drop_path
except:
logger.info('drop_path is not defined in the config file')
if args.freeze_bn:
try:
cfg.model.backbone.freeze_bn = True
except:
logger.info('freeze_bn is not defined in the config file')
if args.amp:
loss_scale = 'dynamic'
if cfg.get('fp16', None) is None:
cfg.fp16 = dict()
else:
warnings.warn('fp16 has been defined in the config file')
cfg.optimizer_config.type = 'Fp16OptimizerHook'
cfg.optimizer_config.loss_scale = loss_scale
# init distributed env first, since logger depends on the dist info.
cfg.device = 'cuda'
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# dump config
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info
# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, deterministic: '
f'{args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
meta['exp_name'] = osp.basename(args.config)
model = build_segmentor(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
logger.info(model)
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
val_dataset.pipeline = cfg.data.train.pipeline
datasets.append(build_dataset(val_dataset))
if cfg.checkpoint_config is not None:
# save mmseg version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
config=cfg.pretty_text,
CLASSES=datasets[0].CLASSES,
PALETTE=datasets[0].PALETTE)
# add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
train_segmentor(
model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
if __name__ == '__main__':
main()
================================================
FILE: train.py
================================================
#!/usr/bin/env python3
""" ImageNet Training Script
This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)
NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import datetime
import json
import logging
import os
import time
from collections import OrderedDict
from contextlib import suppress
import torch
import torchvision
import yaml
from timm.data import (AugMixDataset, FastCollateMixup, Mixup, create_dataset,
create_loader, resolve_data_config)
from timm.loss import *
from timm.models import (convert_splitbn_model, create_model, load_checkpoint,
model_parameters, resume_checkpoint, safe_model_name)
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import *
from timm.utils import ApexScaler, NativeScaler
from torch import nn
from torch.nn.parallel import DistributedDataParallel as NativeDDP
try:
from mmcv.runner import load_checkpoint as load_ckpt
except:
from mmengine.runner import load_checkpoint as load_ckpt
try:
from apex import amp
from apex.parallel import DistributedDataParallel as ApexDDP
from apex.parallel import convert_syncbn_model
has_apex = True
except ImportError:
has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try:
import wandb
has_wandb = True
except ImportError:
has_wandb = False
import models
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
def get_args_parser():
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
# config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
# parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', help='YAML config file specifying default arguments')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training', add_help=False)
# Dataset parameters
parser.add_argument('--data-dir', metavar='DIR', help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
# Model parameters
parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
help='Name of model to train (default: "resnet50"')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default=None, type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--finetune', default=None, type=str, metavar='PATH',
help='Fine-tune model from this checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes (Model default if None)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image patch size (default: None => model default)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='input batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
help='validation batch size override (default: None)')
parser.add_argument('--grad-checkpoint', action='store_true', default=False,
help='Using gradient checkpointing for saving GPU memory')
parser.add_argument('--ckpt-stg', default=[0, 0, 0, 0], type=int, nargs='+',
help='stage for using grad checkpoint')
parser.add_argument('--val-freq', default=1, type=int,
help='do evaluation per n epoch')
parser.add_argument('--val-start-epoch', default=0, type=int,
help='do evaluation per epoch after n-th epoch')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
parser.add_argument('--aux-loss-ratio', type=float, default=0.4,
help='Aux loss weight')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step"')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
help='learning rate (default: 1e-3)')
parser.add_argument('--auto-lr', action='store_true',
default=False, help='auto scaling learning rate')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
help='amount to decay each learning rate cycle (default: 0.5)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit, cycles enabled if > 1')
parser.add_argument('--lr-k-decay', type=float, default=1.0,
help='learning rate k-decay for cosine/poly (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=300, metavar='N',
help='number of epochs to train (default: 300)')
parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=100, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Augmentation & regularization parameters
parser.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default="rand-m9-mstd0.5-inc1", metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)'),
parser.add_argument('--aug-repeats', type=int, default=0,
help='Number of augmentation repetitions (distributed training only) (default: 0)')
parser.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd-loss', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--bce-loss', action='store_true', default=False,
help='Enable BCE loss w/ Mixup/CutMix use.')
parser.add_argument('--bce-target-thresh', type=float, default=None,
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.8,
help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
parser.add_argument('--cutmix', type=float, default=1.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
parser.add_argument('--no-sync-bn', action='store_true', default=False,
help='Disable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='reduce',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.99984,
help='decay factor for model weights moving average (default: 0.99996)')
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--worker-seeding', type=str, default='all',
help='worker seed mode (default: all)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('--checkpoint-hist', type=int, default=5, metavar='N',
help='number of checkpoints to keep')
parser.add_argument('--checkpoint-freq', type=int, default=50, metavar='N',
help='freq of saving checkpoints')
parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
help='how many training processes to use (default: 8)')
parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--compile', action='store_true', default=False,
help='Use torch.compile to accelerate training')
parser.add_argument('--debug-loss', action='store_true', default=False,
help='Use Anomaly Detection')
parser.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--no-pin-mem', action='store_true', default=False,
help='Disable Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
parser.add_argument('--experiment', default='', type=str, metavar='NAME',
help='name of train experiment, name of sub-folder for output')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1"')
parser.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
# parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
parser.add_argument('--gpu-limit', default=1, type=float)
parser.add_argument('--local-rank', default=0, type=int)
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
return parser
def main(args):
# Cache the args as a text string to save them in the output dir later
args.device_name = torch.cuda.get_device_name()
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
setup_default_logging()
# args, args_text = _parse_args()
if args.local_rank == 0:
_logger.info(args_text)
if args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning("You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
args.prefetcher = not args.no_prefetcher
args.distributed = False
# args.device = 'cuda:0'
# args.world_size = 1
# args.rank = 0 # global rank
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.distributed = True
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
# args.device = 'cuda:%d' % args.local_rank
elif 'SLURM_PROCID' in os.environ and 'WORLD_SIZE' in os.environ:
args.distributed = True
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
if args.distributed:
torch.cuda.set_device(args.gpu)
torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
# args.world_size = torch.distributed.get_world_size()
# args.rank = torch.distributed.get_rank()
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
else:
args.device = 'cuda'
args.world_size = 1
args.rank = 0 # global rank
_logger.info('Training with a single process on 1 GPUs.')
assert args.rank >= 0
if args.gpu_limit < 1:
torch.cuda.set_per_process_memory_fraction(args.gpu_limit)
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
if args.amp:
# `--amp` chooses native amp before apex (APEX ver not actively maintained)
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
if args.apex_amp and has_apex:
use_amp = 'apex'
elif args.native_amp and has_native_amp:
use_amp = 'native'
elif args.apex_amp or args.native_amp:
_logger.warning("Neither APEX or native PyTorch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6")
random_seed(args.seed, args.rank)
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
# img_size=args.input_size,
use_checkpoint=args.ckpt_stg if args.grad_checkpoint else [0] * 4,
)
if args.finetune:
load_ckpt(model, args.finetune)
if args.num_classes is None:
assert hasattr(
model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
# FIXME handle model default vs config num_classes more elegantly
args.num_classes = model.num_classes
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0
if args.aug_splits > 0:
assert args.aug_splits > 1, 'A split of 1 makes no sense'
num_aug_splits = args.aug_splits
# enable split bn (separate bn stats per batch-portion)
if args.split_bn:
assert num_aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(num_aug_splits, 2))
if args.local_rank == 0:
_logger.info(model)
_logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
model_str = str(model)
# move model to GPU, enable channels last layout if set
model.cuda()
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
# setup synchronized BatchNorm for distributed training
if args.distributed and not args.no_sync_bn:
assert not args.split_bn
if has_apex and use_amp == 'apex':
# Apex SyncBN preferred unless native amp is activated
model = convert_syncbn_model(model)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.local_rank == 0:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
if args.torchscript:
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)
if args.auto_lr:
args.lr = (args.batch_size * args.world_size / 1024) * 1e-3
args_text = yaml.load(args_text, Loader=yaml.FullLoader)
args_text['lr'] = args.lr
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
if args.local_rank == 0:
_logger.info(f'Initail learning rate: {args.lr}.')
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == 'apex':
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
loss_scaler = ApexScaler()
if args.local_rank == 0:
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native':
amp_autocast = torch.cuda.amp.autocast
loss_scaler = NativeScaler()
if args.local_rank == 0:
_logger.info(
'Using native PyTorch AMP. Training in mixed precision.')
else:
if args.local_rank == 0:
_logger.info('AMP not enabled. Training in float32.')
if args.grad_checkpoint and args.local_rank == 0:
_logger.info(
f'Using gradient checkpointing, checkpoint stage: {args.ckpt_stg}.')
# optionally resume from a checkpoint
resume_epoch = None
if args.resume:
resume_epoch = resume_checkpoint(
model, args.resume,
optimizer=None if args.no_resume_opt else optimizer,
loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=args.local_rank == 0)
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
# model_ema_decay = args.model_ema_decay ** (args.batch_size * args.world_size / 512.0)
model_ema = ModelEmaV2(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
if args.resume:
load_checkpoint(model_ema.module, args.resume, use_ema=True)
# setup distributed training
if args.distributed:
if has_apex and use_amp == 'apex':
# Apex DDP preferred unless native amp is activated
model = ApexDDP(model, delay_allreduce=True)
if args.local_rank == 0:
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
else:
model = NativeDDP(model, device_ids=[args.gpu], broadcast_buffers=not args.no_ddp_bb)
if args.local_rank == 0:
_logger.info("Using native PyTorch DistributedDataParallel.")
# NOTE: EMA model does not need to be wrapped by DDP
# setup learning rate schedule and starting epoch
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0:
lr_scheduler.step(start_epoch)
_logger.info("Creating Dataset ...")
# create the train and eval datasets
dataset_train = create_dataset(
args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
repeats=args.epoch_repeats)
dataset_eval = create_dataset(
args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size)
# setup mixup / cutmix
collate_fn = None
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.num_classes)
if args.prefetcher:
# collate conflict (need to support deinterleaving in collate mixup)
assert not num_aug_splits
collate_fn = FastCollateMixup(**mixup_args)
else:
mixup_fn = Mixup(**mixup_args)
# wrap dataset in AugMix helper
if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
# create data loaders w/ augmentation pipeiine
train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation:
train_interpolation = data_config['interpolation']
_logger.info("Creating Dataloader ...")
loader_train = create_loader(
dataset_train,
input_size=data_config['input_size'],
batch_size=args.batch_size,
is_training=True,
use_prefetcher=args.prefetcher,
no_aug=args.no_aug,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
re_split=args.resplit,
scale=args.scale,
ratio=args.ratio,
hflip=args.hflip,
vflip=args.vflip,
color_jitter=args.color_jitter,
auto_augment=args.aa,
num_aug_repeats=args.aug_repeats,
num_aug_splits=num_aug_splits,
interpolation=train_interpolation,
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=args.distributed,
collate_fn=collate_fn,
pin_memory=not args.no_pin_mem,
use_multi_epochs_loader=args.use_multi_epochs_loader,
worker_seeding=args.worker_seeding,
)
loader_eval = create_loader(
dataset_eval,
input_size=data_config['input_size'],
batch_size=args.validation_batch_size or args.batch_size,
is_training=False,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=args.distributed,
crop_pct=data_config['crop_pct'],
pin_memory=not args.no_pin_mem,
)
# setup loss function
if args.jsd_loss:
assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(
num_splits=num_aug_splits, smoothing=args.smoothing)
elif mixup_active:
# smoothing is handled with mixup target transform which outputs sparse, soft targets
if args.bce_loss:
train_loss_fn = BinaryCrossEntropy(
target_threshold=args.bce_target_thresh)
else:
train_loss_fn = SoftTargetCrossEntropy()
elif args.smoothing:
if args.bce_loss:
train_loss_fn = BinaryCrossEntropy(
smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
else:
# train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
train_loss_fn = nn.CrossEntropyLoss(label_smoothing=args.smoothing)
else:
train_loss_fn = nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric
if args.model_ema:
ema_save_tag = 'ema'
eval_metric_ema = f'{args.eval_metric}_{ema_save_tag}'
best_metric = None
best_metric_ema = None
best_epoch = None
best_epoch_ema = None
saver = None
saver_ema = None
output_dir = None
resume_mode = args.resume
if args.rank == 0:
if args.experiment:
exp_name = args.experiment
else:
exp_name = '-'.join([
datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
safe_model_name(args.model),
str(data_config['input_size'][-1])
])
output_dir = get_outdir(
args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False
checkpoint_dir = f'{output_dir}/checkpoints/'
os.makedirs(checkpoint_dir, exist_ok=True)
saver = CheckpointSaver(model=model, optimizer=optimizer,
args=args, model_ema=model_ema,
amp_scaler=loss_scaler, checkpoint_dir=checkpoint_dir,
recovery_dir=output_dir, decreasing=decreasing,
max_history=args.checkpoint_hist)
if model_ema is not None:
checkpoint_dir = f'{output_dir}/checkpoints_ema/'
os.makedirs(checkpoint_dir, exist_ok=True)
saver_ema = CheckpointSaver(model=model, optimizer=optimizer,
args=args, model_ema=model_ema,
checkpoint_prefix='checkpoint-ema',
amp_scaler=loss_scaler, checkpoint_dir=checkpoint_dir,
recovery_dir=output_dir, decreasing=decreasing,
max_history=args.checkpoint_hist)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
with open(os.path.join(output_dir, 'model.log'), 'w') as f:
f.write(model_str)
if args.compile:
_logger.info("Compiling model ...")
model = torch.compile(model)
_logger.info("Start Training.")
_logger.info('Scheduled epochs: {}'.format(num_epochs))
try:
for epoch in range(start_epoch, num_epochs):
if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)
train_metrics = train_one_epoch(
epoch, model, loader_train,
optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler,
saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast,
loss_scaler=loss_scaler,
model_ema=model_ema,
mixup_fn=mixup_fn,
eta_meter=AverageMeter())
if (epoch % args.val_freq == 0) or epoch > args.val_start_epoch or resume_mode:
resume_mode = False
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
_logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, args.world_size,
args.dist_bn == 'reduce')
eval_metrics = validate(model, loader_eval, validate_loss_fn,
args, amp_autocast=amp_autocast)
if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(model_ema, args.world_size,
args.dist_bn == 'reduce')
ema_eval_metrics = validate(model_ema.module, loader_eval,
validate_loss_fn, args,
amp_autocast=amp_autocast,
log_suffix=f'_{ema_save_tag}')
eval_metrics.update(ema_eval_metrics)
else:
eval_metrics = {key: float('-inf') for key in eval_metrics}
if lr_scheduler is not None:
# step LR for next epoch
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
if output_dir is not None:
update_summary(epoch,
train_metrics,
eval_metrics,
os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None,
log_wandb=args.log_wandb and has_wandb)
if saver is not None:
# save proper checkpoint with eval metric
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
if saver_ema is not None:
# save proper checkpoint with eval metric (ema)
save_metric = eval_metrics[eval_metric_ema]
best_metric_ema, best_epoch_ema = saver_ema.save_checkpoint(epoch, metric=save_metric)
if best_metric is not None:
if args.local_rank == 0:
_logger.info(f"Currently Best Accuracy: {best_metric:.2f} at Epoch {best_epoch}")
if best_metric_ema is not None:
_logger.info(f"Currently Best Accuracy (EMA): {best_metric_ema:.2f} at Epoch {best_epoch_ema}")
_logger.info('\n')
with open(os.path.join(output_dir, 'best-metric.json'), 'w') as f:
best_metric_info = {
eval_metric: best_metric,
'epoch': best_epoch,
}
if best_metric_ema is not None:
best_metric_info[eval_metric_ema] = best_metric_ema
best_metric_info['epoch_ema'] = best_epoch_ema
json.dump(best_metric_info, f, indent=4)
except KeyboardInterrupt:
pass
if best_metric is not None:
# result_info = f'*** Best metric ({eval_metric}): {best_metric} (epoch {best_epoch}) ***,
# *** Best metric (ema) ({eval_metric_ema}): {best_metric_ema} (epoch {best_epoch_ema}) ***'
best_metric_info = {
eval_metric: best_metric,
'epoch': best_epoch,
}
if best_metric_ema is not None:
best_metric_info[eval_metric_ema] = best_metric_ema
best_metric_info['epoch_ema'] = best_epoch_ema
# _logger.info(f'best_metric: {best_metric_info}')
print(f'best_metric: {best_metric_info}')
def train_one_epoch(epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir=None,
amp_autocast=suppress, loss_scaler=None,
model_ema=None, mixup_fn=None, eta_meter=AverageMeter()):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False
elif mixup_fn is not None:
mixup_fn.mixup_enabled = False
second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
losses_m = AverageMeter()
losses_aux_m = AverageMeter()
is_ds = False
model.train()
# torch.cuda.empty_cache()
end = time.time()
end_epoch = time.time()
last_idx = len(loader) - 1
num_updates = epoch * len(loader)
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end)
if not args.prefetcher:
input = input.cuda()
target = target.cuda()
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
if args.debug_loss:
detect_anomaly = torch.autograd.detect_anomaly
else:
detect_anomaly = suppress
with amp_autocast():
with detect_anomaly():
output = model(input)
if isinstance(output, (tuple, list, dict)):
is_ds = True
output_main = output['main']
output_aux = output['aux']
loss_main = loss_fn(output_main, target)
loss_aux = loss_fn(output_aux, target)
loss = loss_main + args.aux_loss_ratio * loss_aux
else:
loss = loss_fn(output, target)
if not args.distributed:
if is_ds:
losses_m.update(loss_main.item(), input.size(0))
losses_aux_m.update(loss_aux.item(), input.size(0))
else:
losses_m.update(loss.item(), input.size(0))
optimizer.zero_grad()
if loss_scaler is not None:
loss_scaler(
loss, optimizer,
clip_grad=args.clip_grad, clip_mode=args.clip_mode,
parameters=model_parameters(
model, exclude_head='agc' in args.clip_mode),
create_graph=second_order)
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
dispatch_clip_grad(
model_parameters(
model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad, mode=args.clip_mode)
optimizer.step()
if model_ema is not None:
model_ema.update(model)
torch.cuda.synchronize()
num_updates += 1
batch_time_m.update(time.time() - end)
if eta_meter is not None:
eta_meter.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0:
if args.distributed:
if is_ds:
reduced_loss = reduce_tensor(loss_main.data, args.world_size)
reduced_loss_aux = reduce_tensor(loss_aux.data, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0))
losses_aux_m.update(reduced_loss_aux.item(), input.size(0))
else:
reduced_loss = reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0))
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
lr = sum(lrl) / len(lrl)
eta_remaining_batches = (args.epochs + args.cooldown_epochs - epoch - 1) * len(loader) + len(loader) - batch_idx - 1
batch_time_avg = eta_meter.avg if eta_meter is not None else batch_time_m.avg
eta_remaining_time = eta_remaining_batches * batch_time_avg
eta_td = datetime.timedelta(seconds=eta_remaining_time)
eta_str = str(eta_td)
current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
gpu_memory = torch.cuda.max_memory_allocated()
if args.local_rank == 0:
base_info = (
'[{}] Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'LR: {lr:.3e} '
'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
)
aux_info = (
'Loss (Aux): {loss_aux.val:#.4g} ({loss_aux.avg:#.3g}) '
)
other_info = (
'ETA: {eta} '
'Memory: {gpu_memory:d} '
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
'({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
)
if is_ds:
base_info += aux_info
base_info += other_info
_logger.info(
base_info.format(
current_time,
epoch,
batch_idx, len(loader),
100. * batch_idx / last_idx,
loss=losses_m,
loss_aux=losses_aux_m if is_ds else None,
batch_time=batch_time_m,
rate=input.size(0) * args.world_size / batch_time_m.val,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
lr=lr,
data_time=data_time_m,
eta=eta_str.split('.')[0],
gpu_memory=int(gpu_memory/(1024**2))
)
)
if args.save_images and output_dir:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True)
if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
saver.save_recovery(epoch, batch_idx=batch_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
end = time.time()
# end for
epoch_time = time.time() - end_epoch
if hasattr(optimizer, 'sync_lookahead'):
optimizer.sync_lookahead()
train_info_dict = {
'epoch_time(min)': round(epoch_time / 60, 1),
'lr': lr,
'loss': losses_m.avg,
}
if is_ds:
train_info_dict['loss_aux'] = losses_aux_m.avg
return OrderedDict(train_info_dict)
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
batch_time_m = AverageMeter()
losses_m = AverageMeter()
top1_m = AverageMeter()
top5_m = AverageMeter()
model.eval()
end = time.time()
last_idx = len(loader) - 1
with torch.no_grad():
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.cuda()
target = target.cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
output = model(input)
if isinstance(output, (tuple, list)):
output = output[0]
loss = loss_fn(output, target)
# augmentation reduction
reduce_factor = args.tta
if reduce_factor > 1:
output = output.unfold(
0, reduce_factor, reduce_factor).mean(dim=2)
target = target[0:target.size(0):reduce_factor]
acc1, acc5 = accuracy(output, target, topk=(1, 5))
if args.distributed:
reduced_loss = reduce_tensor(loss.data, args.world_size)
acc1 = reduce_tensor(acc1, args.world_size)
acc5 = reduce_tensor(acc5, args.world_size)
else:
reduced_loss = loss.data
torch.cuda.synchronize()
losses_m.update(reduced_loss.item(), input.size(0))
top1_m.update(acc1.item(), output.size(0))
top5_m.update(acc5.item(), output.size(0))
batch_time_m.update(time.time() - end)
end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix
_logger.info(
'{0}: [{1:>4d}/{2}] '
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
log_name, batch_idx, last_idx, batch_time=batch_time_m,
loss=losses_m, top1=top1_m, top5=top5_m))
metrics = OrderedDict([(f'loss{log_suffix}', losses_m.avg),
(f'top1{log_suffix}', top1_m.avg),
(f'top5{log_suffix}', top5_m.avg)])
return metrics
if __name__ == '__main__':
torch.cuda.empty_cache()
args = get_args_parser().parse_args()
main(args)
================================================
FILE: validate.py
================================================
#!/usr/bin/env python3
""" ImageNet Validation Script
This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import os
import csv
import glob
import time
import torch
import logging
import argparse
from tqdm import tqdm
from torch import nn
from contextlib import suppress
from collections import OrderedDict
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
import models
has_apex = False
try:
from apex import amp
has_apex = True
except ImportError:
pass
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop pct')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=25, type=int,
metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
help='enable test time pool')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=True,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
parser.add_argument('--results-file', default=None, type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
help='Valid label indices txt file for validation of partial label space')
def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
amp_autocast = suppress # do nothing
if args.amp:
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
else:
_logger.warning("Neither APEX or Native Torch AMP is available.")
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
if args.native_amp:
amp_autocast = torch.cuda.amp.autocast
_logger.info('Validating in mixed precision with native PyTorch AMP.')
elif args.apex_amp:
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
else:
_logger.info('Validating in float32. AMP not enabled.')
if args.legacy_jit:
set_jit_legacy()
# create model
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
in_chans=3,
global_pool=args.gp,
scriptable=args.torchscript)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes
if args.checkpoint:
load_checkpoint(model, args.checkpoint, args.use_ema)
param_count = sum([m.numel() for m in model.parameters()])
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
test_time_pool = False
if args.test_pool:
model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
if args.torchscript:
torch.jit.optimized_execution(True)
model = torch.jit.script(model)
model = model.cuda()
if args.apex_amp:
model = amp.initialize(model, opt_level='O1')
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
criterion = nn.CrossEntropyLoss().cuda()
dataset = create_dataset(
root=args.data, name=args.dataset, split=args.split,
download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
if args.valid_labels:
with open(args.valid_labels, 'r') as f:
valid_labels = {int(line.rstrip()) for line in f}
valid_labels = [i in valid_labels for i in range(args.num_classes)]
else:
valid_labels = None
if args.real_labels:
real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
else:
real_labels = None
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
loader = create_loader(
dataset,
input_size=data_config['input_size'],
batch_size=args.batch_size,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
crop_pct=crop_pct,
pin_memory=args.pin_mem,
tf_preprocessing=args.tf_preprocessing)
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
model.eval()
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
model(input)
pbar = tqdm(total=len(dataset))
end = time.time()
for input, target in loader:
if args.no_prefetcher:
target = target.cuda(non_blocking=args.pin_mem)
input = input.cuda(non_blocking=args.pin_mem)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
# compute output
with amp_autocast():
output = model(input)
if valid_labels is not None:
output = output[:, valid_labels]
loss = criterion(output, target)
if real_labels is not None:
real_labels.add_result(output)
# measure accuracy and record loss
acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(acc1.item(), input.size(0))
top5.update(acc5.item(), input.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
pbar.update(args.batch_size)
pbar.close()
if real_labels is not None:
# real labels mode replaces topk values at the end
top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
else:
top1a, top5a = top1.avg, top5.avg
results = OrderedDict(top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
param_count=round(param_count / 1e6, 2),
img_size=data_config['input_size'][-1],
crop_pct=crop_pct,
interpolation=data_config['interpolation'])
_logger.info(
dict(model=args.model,
loss=losses.avg,
top1=top1.avg,
top5=top5.avg,)
)
_logger.info(f"Accuracy@top-1 of the network on the {len(dataset)} test images: {top1a:.1f}%")
return results
def main():
setup_default_logging()
args = parser.parse_args()
model_cfgs = []
model_names = []
if os.path.isdir(args.checkpoint):
# validate all checkpoints in a path with same model
checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
checkpoints += glob.glob(args.checkpoint + '/*.pth')
model_names = list_models(args.model)
model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
else:
if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints
args.pretrained = True
model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k'])
model_cfgs = [(n, '') for n in model_names]
elif not is_model(args.model):
# model name doesn't exist, try as wildcard filter
model_names = list_models(args.model)
model_cfgs = [(n, '') for n in model_names]
if not model_cfgs and os.path.isfile(args.model):
with open(args.model) as f:
model_names = [line.rstrip() for line in f]
model_cfgs = [(n, None) for n in model_names if n]
if len(model_cfgs):
results_file = args.results_file or './results-all.csv'
_logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
results = []
try:
start_batch_size = args.batch_size
for m, c in model_cfgs:
batch_size = start_batch_size
args.model = m
args.checkpoint = c
result = OrderedDict(model=args.model)
r = {}
while not r and batch_size >= args.num_gpu:
torch.cuda.empty_cache()
try:
args.batch_size = batch_size
print('Validating with batch size: %d' % args.batch_size)
r = validate(args)
except RuntimeError as e:
if batch_size <= args.num_gpu:
print("Validation failed with no ability to reduce batch size. Exiting.")
raise e
batch_size = max(batch_size // 2, args.num_gpu)
print("Validation failed, reducing batch size by 50%")
result.update(r)
if args.checkpoint:
result['checkpoint'] = args.checkpoint
results.append(result)
except KeyboardInterrupt as e:
pass
results = sorted(results, key=lambda x: x['top1'], reverse=True)
if len(results):
write_results(results_file, results)
else:
validate(args)
def write_results(results_file, results):
with open(results_file, mode='w') as cf:
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
dw.writeheader()
for r in results:
dw.writerow(r)
cf.flush()
if __name__ == '__main__':
main()