[
  {
    "path": "LICENSE.md",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# [[CVPR 2025 Oral] OverLoCK: An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels](https://arxiv.org/abs/2502.20087)\n\nThis is an official PyTorch implementation of \"[**OverLoCK: An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels**](https://arxiv.org/abs/2502.20087)\".\n\n# Introduction\nTop-down attention plays a crucial role in the human vision system, wherein the brain initially obtains a rough overview of a scene to discover salient cues (i.e., overview first), followed by a more careful finer-grained examination (i.e., look closely next). However, modern ConvNets remain confined to a pyramid structure that successively downsamples the feature map for receptive field expansion, neglecting this crucial biomimetic principle. We present OverLoCK, the first pure ConvNet backbone architecture that explicitly incorporates a top-down attention mechanism. Unlike pyramid backbone networks, our design features a branched architecture with three synergistic sub-networks: 1) a Base-Net that encodes low/mid-level features; 2) a lightweight Overview-Net that generates dynamic top-down attention through coarse global context modeling (i.e., overview first); and 3) a robust Focus-Net that performs finer-grained perception guided by top-down attention (i.e., look closely next). To fully unleash the power of top-down attention, we further propose a novel context-mixing dynamic convolution (ContMix) that effectively models long-range dependencies while preserving inherent local inductive biases even when the input resolution increases, addressing critical limitations in existing convolutions. Our OverLoCK exhibits a notable performance improvement over existing methods.\n<center> \n<img src=\"images/img.jpg\" width=\"70%\" height=\"auto\">\n</center>\n\n# News\n- **Dec. 25, 2025**: **To improve inference speed and reduce memory consumption**, we provide **reparameterized versions of the OverLoCK models with pre-trained weights**. These variants achieve **identical performance to their original counterparts on ImageNet-1K evaluation**. However, if you further fine-tune these reparameterized models, they may yield slightly lower accuracy compared to the original versions. Please choose the model variant during fine-tuning based on memory and accuracy requirements on your side ([More Details](https://github.com/LMMMEng/OverLoCK/blob/81dd7b216e7aa66ff5a95b07021f299dc2d4d14b/models/overlock.py#L941C13-L941C14)).\n  \n- **May. 16, 2025**: A **plug-and-play implementation of the [ContMix Block](models/contmix.py)** is now available.\n\n# Image Classification\n\n## 1. Requirements\nWe highly suggest using our provided dependencies to ensure reproducibility:\n```\n# Environments:\ncuda==12.1\npython==3.10\n# Dependencies:\npip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121\npip install natten==0.17.1+torch230cu121 -f https://shi-labs.com/natten/wheels/\npip install timm==0.6.12\npip install mmengine==0.2.0\n```\n>💡 To accelerate training and inference, we utilize the efficient large-kernel convolution proposed in [RepLKNet](https://github.com/DingXiaoH/RepLKNet-pytorch#use-our-efficient-large-kernel-convolution-with-pytorch). 
Please follow this [**guideline**](https://github.com/VITA-Group/SLaK#installation) to install the ``depthwise_conv2d_implicit_gemm`` function.\n>\n>💡 If you encounter network issues during the installation of ``natten``, please download this [**package**](https://github.com/LMMMEng/OverLoCK/releases/download/v1/natten-0.17.1+torch230cu121-cp310-cp310-linux_x86_64.whl) and install it locally.\n\n\n## 2. Data Preparation\nPrepare [ImageNet](https://image-net.org/) with the following folder structure; you can extract ImageNet using this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).\n\n```\n│imagenet/\n├──train/\n│  ├── n01440764\n│  │   ├── n01440764_10026.JPEG\n│  │   ├── n01440764_10027.JPEG\n│  │   ├── ......\n│  ├── ......\n├──val/\n│  ├── n01440764\n│  │   ├── ILSVRC2012_val_00000293.JPEG\n│  │   ├── ILSVRC2012_val_00002138.JPEG\n│  │   ├── ......\n│  ├── ......\n```\n\n## 3. Main Results on ImageNet-1K with Pretrained Models\n\n| Models      | Input Size | FLOPs (G) | Params (M) | Top-1 (%) | Download |\n|:-----------:|:----------:|:---------:|:----------:|:----------:|:----------:|\n| OverLoCK-XT | 224x224    | 2.6       | 16       | 82.7       | [model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_xt_in1k_224.pth)     |\n| OverLoCK-T | 224x224    | 5.5       | 33      | 84.2       | [model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224.pth)     |\n| OverLoCK-S | 224x224    | 9.7      | 56       | 84.8       | [model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224.pth)     |\n| OverLoCK-B | 224x224    | 16.7       | 95       | 85.1       | [model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224.pth)     |\n\n## 4. Train\nTo train ```OverLoCK``` models on ImageNet-1K with 8 GPUs (single node), run:\n```\nbash scripts/train_xt_model.sh # train OverLoCK-XT\nbash scripts/train_t_model.sh  # train OverLoCK-T\nbash scripts/train_s_model.sh  # train OverLoCK-S\nbash scripts/train_b_model.sh  # train OverLoCK-B\n```  \n> 💡If you encounter NaN loss, please remove ``--native-amp`` to disable AMP training and resume from the checkpoint saved before the NaN loss occurred.\n>   \n> 💡If your **GPU memory** is insufficient during training, you can enable gradient checkpointing by adding the following arguments: ``--grad-checkpoint --ckpt-stg 4 0 0 0``. If you're still experiencing memory issues, you can increase these values, but be aware that this may slow down training.\n\n## 5. Validation\nTo evaluate ```OverLoCK``` on ImageNet-1K, run:\n```\nMODEL=overlock_xt # overlock_{xt, t, s, b}\npython3 validate.py \\\n/path/to/imagenet \\\n--model $MODEL -b 128 \\\n--pretrained # or --checkpoint /path/to/checkpoint \n```\n>💡 To accelerate inference, OverLoCK utilizes [**Structural Re-parameterization**](https://github.com/AILab-CVC/UniRepLKNet/tree/main). 
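Please refer to [**here**](https://github.com/LMMMEng/OverLoCK/blob/540bf6ed9cca99eab78fc8ab935b71f2a4aa2a2c/models/overlock.py#L945) for a simple usage instruction.\n\nAs a rough, non-authoritative sketch of an evaluation workflow: it assumes the ``overlock_{xt, t, s, b}`` variants are registered with timm (as ``validate.py`` suggests), and the ``reparameterize`` method name is likewise an assumption, so check the linked line in ``models/overlock.py`` for the actual entry point:\n```\nimport torch\nimport timm\n\nimport models  # assumption: importing the repo's model package registers OverLoCK with timm\n\n# Build a pretrained OverLoCK-XT and switch to eval mode.\nmodel = timm.create_model('overlock_xt', pretrained=True).eval()\n\n# Hypothetical deploy-mode switch in the spirit of UniRepLKNet's structural\n# re-parameterization; verify the real method name at models/overlock.py#L945.\nif hasattr(model, 'reparameterize'):\n    model.reparameterize()\n\nwith torch.no_grad():\n    logits = model(torch.randn(1, 3, 224, 224))\nprint(logits.shape)  # torch.Size([1, 1000]) for ImageNet-1K\n```\n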
\n# Citation\nIf you find this project useful for your research, please consider citing:\n```\n@inproceedings{lou2025overlock,\n  title={OverLoCK: An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels},\n  author={Lou, Meng and Yu, Yizhou},\n  booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n  pages={128--138},\n  year={2025}\n}\n```\n\n# Dense Predictions\n[Object Detection](detection)  \n[Semantic Segmentation](segmentation)   \n\n# Acknowledgment\nOur implementation is mainly based on the following codebases. We gratefully thank the authors for their wonderful work.  \n> [timm](https://github.com/rwightman/pytorch-image-models), [natten](https://github.com/SHI-Labs/NATTEN), [unireplknet](https://github.com/AILab-CVC/UniRepLKNet), [mmcv](https://github.com/open-mmlab/mmcv), [mmdet](https://github.com/open-mmlab/mmdetection), [mmseg](https://github.com/open-mmlab/mmsegmentation)  \n\n# Contact\nIf you have any questions, please feel free to [create issues❓](https://github.com/LMMMEng/OverLoCK/issues) or [contact me 📧](mailto:lmzmm.0921@gmail.com).\n"
  },
  {
    "path": "detection/configs/_base_/datasets/coco_detection.py",
    "content": "# dataset settings\ndataset_type = 'CocoDataset'\ndata_root = '/grp01/cs_yzyu/dataset/coco2017/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True),\n    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(1333, 800),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=4,\n    train=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_train2017.json',\n        img_prefix=data_root + 'train2017/',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline))\nevaluation = dict(interval=1, metric='bbox', classwise=True)"
  },
  {
    "path": "detection/configs/_base_/datasets/coco_instance.py",
    "content": "# dataset settings\ndataset_type = 'CocoDataset'\ndata_root = '/grp01/cs_yzyu/dataset/coco2017/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),\n    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(1333, 800),\n        flip=False,\n        transforms=[\n            dict(type='Resize', keep_ratio=True),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='Pad', size_divisor=32),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=2,\n    workers_per_gpu=4,\n    train=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_train2017.json',\n        img_prefix=data_root + 'train2017/',\n        pipeline=train_pipeline),\n    val=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        ann_file=data_root + 'annotations/instances_val2017.json',\n        img_prefix=data_root + 'val2017/',\n        pipeline=test_pipeline))\nevaluation = dict(metric=['bbox', 'segm'], classwise=True)\n"
  },
  {
    "path": "detection/configs/_base_/default_runtime.py",
    "content": "checkpoint_config = dict(interval=1)\n# yapf:disable\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook'),\n        # dict(type='TensorboardLoggerHook')\n    ])\n# yapf:enable\ncustom_hooks = [dict(type='NumClassCheckHook')]\n\ndist_params = dict(backend='nccl')\nlog_level = 'INFO'\nload_from = None\nresume_from = None\nworkflow = [('train', 1)]\n"
  },
  {
    "path": "detection/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py",
    "content": "# model settings\nmodel = dict(\n    type='CascadeRCNN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        num_outs=5),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=256,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[8],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[4, 8, 16, 32, 64]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),\n    roi_head=dict(\n        type='CascadeRoIHead',\n        num_stages=3,\n        stage_loss_weights=[1, 0.5, 0.25],\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        bbox_head=[\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.1, 0.1, 0.2, 0.2]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,\n                               loss_weight=1.0)),\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.05, 0.05, 0.1, 0.1]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,\n                               loss_weight=1.0)),\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.033, 0.033, 0.067, 0.067]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))\n        ],\n        mask_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        mask_head=dict(\n            type='FCNMaskHead',\n            num_convs=4,\n            in_channels=256,\n            conv_out_channels=256,\n            num_classes=80,\n            loss_mask=dict(\n                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=0,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_pre=2000,\n            max_per_img=2000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.5,\n                    neg_iou_thr=0.5,\n                    min_pos_iou=0.5,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False),\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.6,\n                    neg_iou_thr=0.6,\n                    min_pos_iou=0.6,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False),\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.7,\n                    neg_iou_thr=0.7,\n                    min_pos_iou=0.7,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False)\n        ]),\n    test_cfg=dict(\n        rpn=dict(\n            nms_pre=1000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n   
         max_per_img=100,\n            mask_thr_binary=0.5)))"
  },
  {
    "path": "detection/configs/_base_/models/cascade_mask_rcnn_r50_fpn_crowdhuman.py",
    "content": "# model settings\nmodel = dict(\n    type='CascadeRCNN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        num_outs=5),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=256,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[8],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[4, 8, 16, 32, 64]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),\n    roi_head=dict(\n        type='CascadeRoIHead',\n        num_stages=3,\n        stage_loss_weights=[1, 0.5, 0.25],\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        bbox_head=[\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.1, 0.1, 0.2, 0.2]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,\n                               loss_weight=1.0)),\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.05, 0.05, 0.1, 0.1]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,\n                               loss_weight=1.0)),\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.033, 0.033, 0.067, 0.067]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))\n        ],),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=0,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_pre=2000,\n            max_per_img=2000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.5,\n                    neg_iou_thr=0.5,\n                    min_pos_iou=0.5,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False),\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.6,\n                    neg_iou_thr=0.6,\n                    min_pos_iou=0.6,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False),\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.7,\n                    neg_iou_thr=0.7,\n                    min_pos_iou=0.7,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                mask_size=28,\n                pos_weight=-1,\n                debug=False)\n        ]),\n    test_cfg=dict(\n        rpn=dict(\n            nms_pre=1000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n            max_per_img=100,\n            mask_thr_binary=0.5)))\n"
  },
  {
    "path": "detection/configs/_base_/models/cascade_rcnn_r50_fpn.py",
    "content": "# model settings\nmodel = dict(\n    type='CascadeRCNN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        num_outs=5),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=256,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[8],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[4, 8, 16, 32, 64]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),\n    roi_head=dict(\n        type='CascadeRoIHead',\n        num_stages=3,\n        stage_loss_weights=[1, 0.5, 0.25],\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        bbox_head=[\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.1, 0.1, 0.2, 0.2]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,\n                               loss_weight=1.0)),\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.05, 0.05, 0.1, 0.1]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,\n                               loss_weight=1.0)),\n            dict(\n                type='Shared2FCBBoxHead',\n                in_channels=256,\n                fc_out_channels=1024,\n                roi_feat_size=7,\n                num_classes=80,\n                bbox_coder=dict(\n                    type='DeltaXYWHBBoxCoder',\n                    target_means=[0., 0., 0., 0.],\n                    target_stds=[0.033, 0.033, 0.067, 0.067]),\n                reg_class_agnostic=True,\n                loss_cls=dict(\n                    type='CrossEntropyLoss',\n                    use_sigmoid=False,\n                    loss_weight=1.0),\n                
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))\n        ]),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=0,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_pre=2000,\n            max_per_img=2000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=[\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.5,\n                    neg_iou_thr=0.5,\n                    min_pos_iou=0.5,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                pos_weight=-1,\n                debug=False),\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.6,\n                    neg_iou_thr=0.6,\n                    min_pos_iou=0.6,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                pos_weight=-1,\n                debug=False),\n            dict(\n                assigner=dict(\n                    type='MaxIoUAssigner',\n                    pos_iou_thr=0.7,\n                    neg_iou_thr=0.7,\n                    min_pos_iou=0.7,\n                    match_low_quality=False,\n                    ignore_iof_thr=-1),\n                sampler=dict(\n                    type='RandomSampler',\n                    num=512,\n                    pos_fraction=0.25,\n                    neg_pos_ub=-1,\n                    add_gt_as_proposals=True),\n                pos_weight=-1,\n                debug=False)\n        ]),\n    test_cfg=dict(\n        rpn=dict(\n            nms_pre=1000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n            max_per_img=100)))\n"
  },
  {
    "path": "detection/configs/_base_/models/fast_rcnn_r50_fpn.py",
    "content": "# model settings\nmodel = dict(\n    type='FastRCNN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        num_outs=5),\n    roi_head=dict(\n        type='StandardRoIHead',\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        bbox_head=dict(\n            type='Shared2FCBBoxHead',\n            in_channels=256,\n            fc_out_channels=1024,\n            roi_feat_size=7,\n            num_classes=80,\n            bbox_coder=dict(\n                type='DeltaXYWHBBoxCoder',\n                target_means=[0., 0., 0., 0.],\n                target_stds=[0.1, 0.1, 0.2, 0.2]),\n            reg_class_agnostic=False,\n            loss_cls=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),\n            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg=dict(\n        rcnn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.5,\n                neg_iou_thr=0.5,\n                min_pos_iou=0.5,\n                match_low_quality=False,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=512,\n                pos_fraction=0.25,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=True),\n            pos_weight=-1,\n            debug=False)),\n    test_cfg=dict(\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n            max_per_img=100)))\n"
  },
  {
    "path": "detection/configs/_base_/models/faster_rcnn_r50_caffe_c4.py",
    "content": "# model settings\nnorm_cfg = dict(type='BN', requires_grad=False)\nmodel = dict(\n    type='FasterRCNN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=3,\n        strides=(1, 2, 2),\n        dilations=(1, 1, 1),\n        out_indices=(2, ),\n        frozen_stages=1,\n        norm_cfg=norm_cfg,\n        norm_eval=True,\n        style='caffe',\n        init_cfg=dict(\n            type='Pretrained',\n            checkpoint='open-mmlab://detectron2/resnet50_caffe')),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=1024,\n        feat_channels=1024,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[2, 4, 8, 16, 32],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[16]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n    roi_head=dict(\n        type='StandardRoIHead',\n        shared_head=dict(\n            type='ResLayer',\n            depth=50,\n            stage=3,\n            stride=2,\n            dilation=1,\n            style='caffe',\n            norm_cfg=norm_cfg,\n            norm_eval=True),\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),\n            out_channels=1024,\n            featmap_strides=[16]),\n        bbox_head=dict(\n            type='BBoxHead',\n            with_avg_pool=True,\n            roi_feat_size=7,\n            in_channels=2048,\n            num_classes=80,\n            bbox_coder=dict(\n                type='DeltaXYWHBBoxCoder',\n                target_means=[0., 0., 0., 0.],\n                target_stds=[0.1, 0.1, 0.2, 0.2]),\n            reg_class_agnostic=False,\n            loss_cls=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),\n            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=0,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_pre=12000,\n            max_per_img=2000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.5,\n                neg_iou_thr=0.5,\n                min_pos_iou=0.5,\n                match_low_quality=False,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=512,\n                pos_fraction=0.25,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=True),\n            pos_weight=-1,\n            debug=False)),\n    test_cfg=dict(\n        rpn=dict(\n         
   nms_pre=6000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n            max_per_img=100)))\n"
  },
  {
    "path": "detection/configs/_base_/models/faster_rcnn_r50_caffe_dc5.py",
    "content": "# model settings\nnorm_cfg = dict(type='BN', requires_grad=False)\nmodel = dict(\n    type='FasterRCNN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        strides=(1, 2, 2, 1),\n        dilations=(1, 1, 1, 2),\n        out_indices=(3, ),\n        frozen_stages=1,\n        norm_cfg=norm_cfg,\n        norm_eval=True,\n        style='caffe',\n        init_cfg=dict(\n            type='Pretrained',\n            checkpoint='open-mmlab://detectron2/resnet50_caffe')),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=2048,\n        feat_channels=2048,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[2, 4, 8, 16, 32],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[16]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n    roi_head=dict(\n        type='StandardRoIHead',\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=2048,\n            featmap_strides=[16]),\n        bbox_head=dict(\n            type='Shared2FCBBoxHead',\n            in_channels=2048,\n            fc_out_channels=1024,\n            roi_feat_size=7,\n            num_classes=80,\n            bbox_coder=dict(\n                type='DeltaXYWHBBoxCoder',\n                target_means=[0., 0., 0., 0.],\n                target_stds=[0.1, 0.1, 0.2, 0.2]),\n            reg_class_agnostic=False,\n            loss_cls=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),\n            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=0,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_pre=12000,\n            max_per_img=2000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.5,\n                neg_iou_thr=0.5,\n                min_pos_iou=0.5,\n                match_low_quality=False,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=512,\n                pos_fraction=0.25,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=True),\n            pos_weight=-1,\n            debug=False)),\n    test_cfg=dict(\n        rpn=dict(\n            nms=dict(type='nms', iou_threshold=0.7),\n            nms_pre=6000,\n            max_per_img=1000,\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', 
iou_threshold=0.5),\n            max_per_img=100)))\n"
  },
  {
    "path": "detection/configs/_base_/models/faster_rcnn_r50_fpn.py",
    "content": "# model settings\nmodel = dict(\n    type='FasterRCNN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        num_outs=5),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=256,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[8],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[4, 8, 16, 32, 64]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n    roi_head=dict(\n        type='StandardRoIHead',\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        bbox_head=dict(\n            type='Shared2FCBBoxHead',\n            in_channels=256,\n            fc_out_channels=1024,\n            roi_feat_size=7,\n            num_classes=80,\n            bbox_coder=dict(\n                type='DeltaXYWHBBoxCoder',\n                target_means=[0., 0., 0., 0.],\n                target_stds=[0.1, 0.1, 0.2, 0.2]),\n            reg_class_agnostic=False,\n            loss_cls=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),\n            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=-1,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_pre=2000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.5,\n                neg_iou_thr=0.5,\n                min_pos_iou=0.5,\n                match_low_quality=False,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=512,\n                pos_fraction=0.25,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=True),\n            pos_weight=-1,\n            debug=False)),\n    test_cfg=dict(\n        rpn=dict(\n            nms_pre=1000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            
nms=dict(type='nms', iou_threshold=0.5),\n            max_per_img=100)\n        # soft-nms is also supported for rcnn testing\n        # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)\n    ))\n"
  },
  {
    "path": "detection/configs/_base_/models/mask_rcnn_convnext_fpn.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\n# model settings\nmodel = dict(\n    type='MaskRCNN',\n    pretrained=None,\n    backbone=dict(\n        type='ConvNeXt',\n        in_chans=3,\n        depths=[3, 3, 9, 3],\n        dims=[96, 192, 384, 768],\n        drop_path_rate=0.2,\n        layer_scale_init_value=1e-6,\n        out_indices=[0, 1, 2, 3],\n    ),\n    neck=dict(\n        type='FPN',\n        in_channels=[128, 256, 512, 1024],\n        out_channels=256,\n        num_outs=5),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=256,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[8],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[4, 8, 16, 32, 64]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n    roi_head=dict(\n        type='StandardRoIHead',\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        bbox_head=dict(\n            type='Shared2FCBBoxHead',\n            in_channels=256,\n            fc_out_channels=1024,\n            roi_feat_size=7,\n            num_classes=80,\n            bbox_coder=dict(\n                type='DeltaXYWHBBoxCoder',\n                target_means=[0., 0., 0., 0.],\n                target_stds=[0.1, 0.1, 0.2, 0.2]),\n            reg_class_agnostic=False,\n            loss_cls=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),\n            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n        mask_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        mask_head=dict(\n            type='FCNMaskHead',\n            num_convs=4,\n            in_channels=256,\n            conv_out_channels=256,\n            num_classes=80,\n            loss_mask=dict(\n                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=-1,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_pre=2000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.5,\n            
    neg_iou_thr=0.5,\n                min_pos_iou=0.5,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=512,\n                pos_fraction=0.25,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=True),\n            mask_size=28,\n            pos_weight=-1,\n            debug=False)),\n    test_cfg=dict(\n        rpn=dict(\n            nms_pre=1000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n            max_per_img=100,\n            mask_thr_binary=0.5)))"
  },
  {
    "path": "detection/configs/_base_/models/mask_rcnn_r50_caffe_c4.py",
    "content": "# model settings\nnorm_cfg = dict(type='BN', requires_grad=False)\nmodel = dict(\n    type='MaskRCNN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=3,\n        strides=(1, 2, 2),\n        dilations=(1, 1, 1),\n        out_indices=(2, ),\n        frozen_stages=1,\n        norm_cfg=norm_cfg,\n        norm_eval=True,\n        style='caffe',\n        init_cfg=dict(\n            type='Pretrained',\n            checkpoint='open-mmlab://detectron2/resnet50_caffe')),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=1024,\n        feat_channels=1024,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[2, 4, 8, 16, 32],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[16]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n    roi_head=dict(\n        type='StandardRoIHead',\n        shared_head=dict(\n            type='ResLayer',\n            depth=50,\n            stage=3,\n            stride=2,\n            dilation=1,\n            style='caffe',\n            norm_cfg=norm_cfg,\n            norm_eval=True),\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),\n            out_channels=1024,\n            featmap_strides=[16]),\n        bbox_head=dict(\n            type='BBoxHead',\n            with_avg_pool=True,\n            roi_feat_size=7,\n            in_channels=2048,\n            num_classes=80,\n            bbox_coder=dict(\n                type='DeltaXYWHBBoxCoder',\n                target_means=[0., 0., 0., 0.],\n                target_stds=[0.1, 0.1, 0.2, 0.2]),\n            reg_class_agnostic=False,\n            loss_cls=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),\n            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n        mask_roi_extractor=None,\n        mask_head=dict(\n            type='FCNMaskHead',\n            num_convs=0,\n            in_channels=2048,\n            conv_out_channels=256,\n            num_classes=80,\n            loss_mask=dict(\n                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=0,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_pre=12000,\n            max_per_img=2000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.5,\n                neg_iou_thr=0.5,\n                min_pos_iou=0.5,\n                match_low_quality=False,\n                
ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=512,\n                pos_fraction=0.25,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=True),\n            mask_size=14,\n            pos_weight=-1,\n            debug=False)),\n    test_cfg=dict(\n        rpn=dict(\n            nms_pre=6000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            max_per_img=1000,\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n            max_per_img=100,\n            mask_thr_binary=0.5)))\n"
  },
  {
    "path": "detection/configs/_base_/models/mask_rcnn_r50_fpn.py",
    "content": "# model settings\nmodel = dict(\n    type='MaskRCNN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        num_outs=5),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=256,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[8],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[4, 8, 16, 32, 64]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n    roi_head=dict(\n        type='StandardRoIHead',\n        bbox_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        bbox_head=dict(\n            type='Shared2FCBBoxHead',\n            in_channels=256,\n            fc_out_channels=1024,\n            roi_feat_size=7,\n            num_classes=80,\n            bbox_coder=dict(\n                type='DeltaXYWHBBoxCoder',\n                target_means=[0., 0., 0., 0.],\n                target_stds=[0.1, 0.1, 0.2, 0.2]),\n            reg_class_agnostic=False,\n            loss_cls=dict(\n                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),\n            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n        mask_roi_extractor=dict(\n            type='SingleRoIExtractor',\n            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),\n            out_channels=256,\n            featmap_strides=[4, 8, 16, 32]),\n        mask_head=dict(\n            type='FCNMaskHead',\n            num_convs=4,\n            in_channels=256,\n            conv_out_channels=256,\n            num_classes=80,\n            loss_mask=dict(\n                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=-1,\n            pos_weight=-1,\n            debug=False),\n        rpn_proposal=dict(\n            nms_pre=2000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.5,\n                neg_iou_thr=0.5,\n                min_pos_iou=0.5,\n                match_low_quality=True,\n                ignore_iof_thr=-1),\n   
         sampler=dict(\n                type='RandomSampler',\n                num=512,\n                pos_fraction=0.25,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=True),\n            mask_size=28,\n            pos_weight=-1,\n            debug=False)),\n    test_cfg=dict(\n        rpn=dict(\n            nms_pre=1000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0),\n        rcnn=dict(\n            score_thr=0.05,\n            nms=dict(type='nms', iou_threshold=0.5),\n            max_per_img=100,\n            mask_thr_binary=0.5)))\n"
  },
  {
    "path": "detection/configs/_base_/models/retinanet_r50_fpn.py",
    "content": "# model settings\nmodel = dict(\n    type='RetinaNet',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        start_level=1,\n        add_extra_convs='on_input',\n        num_outs=5),\n    bbox_head=dict(\n        type='RetinaHead',\n        num_classes=80,\n        in_channels=256,\n        stacked_convs=4,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            octave_base_scale=4,\n            scales_per_octave=3,\n            ratios=[0.5, 1.0, 2.0],\n            strides=[8, 16, 32, 64, 128]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='FocalLoss',\n            use_sigmoid=True,\n            gamma=2.0,\n            alpha=0.25,\n            loss_weight=1.0),\n        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n    # model training and testing settings\n    train_cfg=dict(\n        assigner=dict(\n            type='MaxIoUAssigner',\n            pos_iou_thr=0.5,\n            neg_iou_thr=0.4,\n            min_pos_iou=0,\n            ignore_iof_thr=-1),\n        allowed_border=-1,\n        pos_weight=-1,\n        debug=False),\n    test_cfg=dict(\n        nms_pre=1000,\n        min_bbox_size=0,\n        score_thr=0.05,\n        nms=dict(type='nms', iou_threshold=0.5),\n        max_per_img=100))\n"
  },
  {
    "path": "detection/configs/_base_/models/rpn_r50_caffe_c4.py",
    "content": "# model settings\nmodel = dict(\n    type='RPN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=3,\n        strides=(1, 2, 2),\n        dilations=(1, 1, 1),\n        out_indices=(2, ),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=False),\n        norm_eval=True,\n        style='caffe',\n        init_cfg=dict(\n            type='Pretrained',\n            checkpoint='open-mmlab://detectron2/resnet50_caffe')),\n    neck=None,\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=1024,\n        feat_channels=1024,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[2, 4, 8, 16, 32],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[16]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=0,\n            pos_weight=-1,\n            debug=False)),\n    test_cfg=dict(\n        rpn=dict(\n            nms_pre=12000,\n            max_per_img=2000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0)))\n"
  },
  {
    "path": "detection/configs/_base_/models/rpn_r50_fpn.py",
    "content": "# model settings\nmodel = dict(\n    type='RPN',\n    backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        num_outs=5),\n    rpn_head=dict(\n        type='RPNHead',\n        in_channels=256,\n        feat_channels=256,\n        anchor_generator=dict(\n            type='AnchorGenerator',\n            scales=[8],\n            ratios=[0.5, 1.0, 2.0],\n            strides=[4, 8, 16, 32, 64]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[1.0, 1.0, 1.0, 1.0]),\n        loss_cls=dict(\n            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),\n        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),\n    # model training and testing settings\n    train_cfg=dict(\n        rpn=dict(\n            assigner=dict(\n                type='MaxIoUAssigner',\n                pos_iou_thr=0.7,\n                neg_iou_thr=0.3,\n                min_pos_iou=0.3,\n                ignore_iof_thr=-1),\n            sampler=dict(\n                type='RandomSampler',\n                num=256,\n                pos_fraction=0.5,\n                neg_pos_ub=-1,\n                add_gt_as_proposals=False),\n            allowed_border=0,\n            pos_weight=-1,\n            debug=False)),\n    test_cfg=dict(\n        rpn=dict(\n            nms_pre=2000,\n            max_per_img=1000,\n            nms=dict(type='nms', iou_threshold=0.7),\n            min_bbox_size=0)))\n"
  },
  {
    "path": "detection/configs/_base_/models/ssd300.py",
    "content": "# model settings\ninput_size = 300\nmodel = dict(\n    type='SingleStageDetector',\n    backbone=dict(\n        type='SSDVGG',\n        depth=16,\n        with_last_pool=False,\n        ceil_mode=True,\n        out_indices=(3, 4),\n        out_feature_indices=(22, 34),\n        init_cfg=dict(\n            type='Pretrained', checkpoint='open-mmlab://vgg16_caffe')),\n    neck=dict(\n        type='SSDNeck',\n        in_channels=(512, 1024),\n        out_channels=(512, 1024, 512, 256, 256, 256),\n        level_strides=(2, 2, 1, 1),\n        level_paddings=(1, 1, 0, 0),\n        l2_norm_scale=20),\n    bbox_head=dict(\n        type='SSDHead',\n        in_channels=(512, 1024, 512, 256, 256, 256),\n        num_classes=80,\n        anchor_generator=dict(\n            type='SSDAnchorGenerator',\n            scale_major=False,\n            input_size=input_size,\n            basesize_ratio_range=(0.15, 0.9),\n            strides=[8, 16, 32, 64, 100, 300],\n            ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]]),\n        bbox_coder=dict(\n            type='DeltaXYWHBBoxCoder',\n            target_means=[.0, .0, .0, .0],\n            target_stds=[0.1, 0.1, 0.2, 0.2])),\n    # model training and testing settings\n    train_cfg=dict(\n        assigner=dict(\n            type='MaxIoUAssigner',\n            pos_iou_thr=0.5,\n            neg_iou_thr=0.5,\n            min_pos_iou=0.,\n            ignore_iof_thr=-1,\n            gt_max_assign_all=False),\n        smoothl1_beta=1.,\n        allowed_border=-1,\n        pos_weight=-1,\n        neg_pos_ratio=3,\n        debug=False),\n    test_cfg=dict(\n        nms_pre=1000,\n        nms=dict(type='nms', iou_threshold=0.45),\n        min_bbox_size=0,\n        score_thr=0.02,\n        max_per_img=200))\ncudnn_benchmark = True\n"
  },
  {
    "path": "detection/configs/_base_/schedules/schedule_1x.py",
    "content": "# optimizer\noptimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=500,\n    warmup_ratio=0.001,\n    step=[8, 11])\nrunner = dict(type='EpochBasedRunner', max_epochs=12)\n"
  },
  {
    "path": "detection/configs/_base_/schedules/schedule_3x.py",
    "content": "# optimizer\noptimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)\noptimizer_config = dict(grad_clip=None)\n# learning policy\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=500,\n    warmup_ratio=0.001,\n    step=[27, 33])\nrunner = dict(type='EpochBasedRunner', max_epochs=36)\n"
  },
  {
    "path": "detection/configs/maskrcnn_overlock/mask_rcnn_overlock_b_in1k_fpn_1x_coco.py",
    "content": "_base_ = [\n    '../_base_/models/mask_rcnn_r50_fpn.py',\n    '../_base_/datasets/coco_instance.py',\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py'\n]\ndims = [80, 160, 528, 720]\nmodel = dict(\n    backbone=dict(\n        _delete_=True,\n        type='overlock_b',\n        pretrained=True,\n        drop_path_rate=0.6\n    ),\n    neck=dict(\n        type='FPN',\n        in_channels=dims,\n        out_channels=256,\n        num_outs=5))\n\n###########################################################################################################\n# https://github.com/Sense-X/UniFormer/blob/main/object_detection/exp/mask_rcnn_1x_hybrid_small/config.py\n# We follow uniformer's optimizer and lr schedule\noptimizer = dict(_delete_=True, type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,\n                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),\n                                                 'relative_position_bias_table': dict(decay_mult=0.),\n                                                 'norm': dict(decay_mult=0.)}))\n\nevaluation = dict(save_best='auto')\ncheckpoint_config = dict(interval=1, max_keep_ckpts=1, save_last=True)"
  },
  {
    "path": "detection/configs/maskrcnn_overlock/mask_rcnn_overlock_b_in1k_fpn_3x_coco.py",
    "content": "_base_ = [\n    '../_base_/models/mask_rcnn_r50_fpn.py',\n    '../_base_/datasets/coco_instance.py',\n    '../_base_/schedules/schedule_3x.py',\n    '../_base_/default_runtime.py'\n]\ndims = [80, 160, 528, 720]\nmodel = dict(\n    backbone=dict(\n        _delete_=True,\n        type='overlock_b',\n        pretrained=True,\n        drop_path_rate=0.6\n    ),\n    neck=dict(\n        type='FPN',\n        in_channels=dims,\n        out_channels=256,\n        num_outs=5))\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n# augmentation strategy originates from DETR / Sparse RCNN\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='AutoAugment',\n         policies=[\n             [\n                 dict(type='Resize',\n                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),\n                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),\n                                 (736, 1333), (768, 1333), (800, 1333)],\n                      multiscale_mode='value',\n                      keep_ratio=True)\n             ],\n             [\n                 dict(type='Resize',\n                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],\n                      multiscale_mode='value',\n                      keep_ratio=True),\n                 dict(type='RandomCrop',\n                      crop_type='absolute_range',\n                      crop_size=(384, 600),\n                      allow_negative_crop=True),\n                 dict(type='Resize',\n                      img_scale=[(480, 1333), (512, 1333), (544, 1333),\n                                 (576, 1333), (608, 1333), (640, 1333),\n                                 (672, 1333), (704, 1333), (736, 1333),\n                                 (768, 1333), (800, 1333)],\n                      multiscale_mode='value',\n                      override=True,\n                      keep_ratio=True)\n             ]\n         ]),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),\n]\n\n# We use 8 GPUs to train this model so that the total batch size was 16\ndata = dict(samples_per_gpu=2, train=dict(pipeline=train_pipeline))\n\noptimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,\n                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),\n                                                 'relative_position_bias_table': dict(decay_mult=0.),\n                                                 'norm': dict(decay_mult=0.)}))\n\nevaluation = dict(save_best='auto')\ncheckpoint_config = dict(interval=1, max_keep_ckpts=1, save_last=True)\n\n# # AMP (faster but may meet nan loss) ->\n# fp16 = dict(loss_scale='dynamic')"
  },
  {
    "path": "detection/configs/maskrcnn_overlock/mask_rcnn_overlock_s_in1k_fpn_1x_coco.py",
    "content": "_base_ = [\n    '../_base_/models/mask_rcnn_r50_fpn.py',\n    '../_base_/datasets/coco_instance.py',\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py'\n]\ndims = [64, 128, 448, 640]\nmodel = dict(\n    backbone=dict(\n        _delete_=True,\n        type='overlock_s',\n        pretrained=True,\n        drop_path_rate=0.4\n    ),\n    neck=dict(\n        type='FPN',\n        in_channels=dims,\n        out_channels=256,\n        num_outs=5))\n\n###########################################################################################################\n# https://github.com/Sense-X/UniFormer/blob/main/object_detection/exp/mask_rcnn_1x_hybrid_small/config.py\n# We follow uniformer's optimizer and lr schedule\noptimizer = dict(_delete_=True, type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,\n                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),\n                                                 'relative_position_bias_table': dict(decay_mult=0.),\n                                                 'norm': dict(decay_mult=0.)}))\n\nevaluation = dict(save_best='auto')\ncheckpoint_config = dict(interval=1, max_keep_ckpts=1, save_last=True)"
  },
  {
    "path": "detection/configs/maskrcnn_overlock/mask_rcnn_overlock_s_in1k_fpn_3x_coco.py",
    "content": "_base_ = [\n    '../_base_/models/mask_rcnn_r50_fpn.py',\n    '../_base_/datasets/coco_instance.py',\n    '../_base_/schedules/schedule_3x.py',\n    '../_base_/default_runtime.py'\n]\ndims = [64, 128, 448, 640]\nmodel = dict(\n    backbone=dict(\n        _delete_=True,\n        type='overlock_s',\n        pretrained=True,\n        drop_path_rate=0.4\n    ),\n    neck=dict(\n        type='FPN',\n        in_channels=dims,\n        out_channels=256,\n        num_outs=5))\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n# augmentation strategy originates from DETR / Sparse RCNN\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='AutoAugment',\n         policies=[\n             [\n                 dict(type='Resize',\n                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),\n                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),\n                                 (736, 1333), (768, 1333), (800, 1333)],\n                      multiscale_mode='value',\n                      keep_ratio=True)\n             ],\n             [\n                 dict(type='Resize',\n                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],\n                      multiscale_mode='value',\n                      keep_ratio=True),\n                 dict(type='RandomCrop',\n                      crop_type='absolute_range',\n                      crop_size=(384, 600),\n                      allow_negative_crop=True),\n                 dict(type='Resize',\n                      img_scale=[(480, 1333), (512, 1333), (544, 1333),\n                                 (576, 1333), (608, 1333), (640, 1333),\n                                 (672, 1333), (704, 1333), (736, 1333),\n                                 (768, 1333), (800, 1333)],\n                      multiscale_mode='value',\n                      override=True,\n                      keep_ratio=True)\n             ]\n         ]),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),\n]\n\n# We use 8 GPUs to train this model so that the total batch size was 16\ndata = dict(samples_per_gpu=2, train=dict(pipeline=train_pipeline))\n\noptimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,\n                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),\n                                                 'relative_position_bias_table': dict(decay_mult=0.),\n                                                 'norm': dict(decay_mult=0.)}))\n\nevaluation = dict(save_best='auto')\ncheckpoint_config = dict(interval=1, max_keep_ckpts=1, save_last=True)\n\n# # AMP (faster but may meet nan loss) ->\n# fp16 = dict(loss_scale='dynamic')"
  },
  {
    "path": "detection/configs/maskrcnn_overlock/mask_rcnn_overlock_t_in1k_fpn_1x_coco.py",
    "content": "_base_ = [\n    '../_base_/models/mask_rcnn_r50_fpn.py',\n    '../_base_/datasets/coco_instance.py',\n    '../_base_/schedules/schedule_1x.py',\n    '../_base_/default_runtime.py'\n]\ndims = [64, 128, 384, 640]\nmodel = dict(\n    backbone=dict(\n        _delete_=True,\n        type='overlock_t',\n        pretrained=True,\n        drop_path_rate=0.2\n    ),\n    neck=dict(\n        type='FPN',\n        in_channels=dims,\n        out_channels=256,\n        num_outs=5))\n\n###########################################################################################################\n# https://github.com/Sense-X/UniFormer/blob/main/object_detection/exp/mask_rcnn_1x_hybrid_small/config.py\n# We follow uniformer's optimizer and lr schedule\noptimizer = dict(_delete_=True, type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,\n                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),\n                                                 'relative_position_bias_table': dict(decay_mult=0.),\n                                                 'norm': dict(decay_mult=0.)}))\n\nevaluation = dict(save_best='auto')\ncheckpoint_config = dict(interval=1, max_keep_ckpts=1, save_last=True)"
  },
  {
    "path": "detection/configs/maskrcnn_overlock/mask_rcnn_overlock_t_in1k_fpn_3x_coco.py",
    "content": "_base_ = [\n    '../_base_/models/mask_rcnn_r50_fpn.py',\n    '../_base_/datasets/coco_instance.py',\n    '../_base_/schedules/schedule_3x.py',\n    '../_base_/default_runtime.py'\n]\ndims = [64, 128, 384, 640]\nmodel = dict(\n    backbone=dict(\n        _delete_=True,\n        type='overlock_t',\n        pretrained=True,\n        drop_path_rate=0.2\n    ),\n    neck=dict(\n        type='FPN',\n        in_channels=dims,\n        out_channels=256,\n        num_outs=5))\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\n# augmentation strategy originates from DETR / Sparse RCNN\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),\n    dict(type='RandomFlip', flip_ratio=0.5),\n    dict(type='AutoAugment',\n         policies=[\n             [\n                 dict(type='Resize',\n                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),\n                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),\n                                 (736, 1333), (768, 1333), (800, 1333)],\n                      multiscale_mode='value',\n                      keep_ratio=True)\n             ],\n             [\n                 dict(type='Resize',\n                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],\n                      multiscale_mode='value',\n                      keep_ratio=True),\n                 dict(type='RandomCrop',\n                      crop_type='absolute_range',\n                      crop_size=(384, 600),\n                      allow_negative_crop=True),\n                 dict(type='Resize',\n                      img_scale=[(480, 1333), (512, 1333), (544, 1333),\n                                 (576, 1333), (608, 1333), (640, 1333),\n                                 (672, 1333), (704, 1333), (736, 1333),\n                                 (768, 1333), (800, 1333)],\n                      multiscale_mode='value',\n                      override=True,\n                      keep_ratio=True)\n             ]\n         ]),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size_divisor=32),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),\n]\n\n# We use 8 GPUs to train this model so that the total batch size was 16\ndata = dict(samples_per_gpu=2, train=dict(pipeline=train_pipeline))\n\noptimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,\n                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),\n                                                 'relative_position_bias_table': dict(decay_mult=0.),\n                                                 'norm': dict(decay_mult=0.)}))\n\nevaluation = dict(save_best='auto')\ncheckpoint_config = dict(interval=1, max_keep_ckpts=1, save_last=True)\n\n# # AMP (faster but may meet nan loss) ->\n# fp16 = dict(loss_scale='dynamic')"
  },
  {
    "path": "detection/models/__init__.py",
    "content": "from .overlock import *"
  },
  {
    "path": "detection/models/overlock.py",
    "content": "'''\nThis is an official implementation of OverLoCK model proposed in the paper: \nhttps://arxiv.org/abs/2502.20087\n'''\nimport torch\nimport timm\nimport torch.distributed\nimport torch.nn.functional as F\nfrom torch import nn\nfrom einops import rearrange, einsum\nfrom natten.functional import na2d_av\nfrom torch.utils.checkpoint import checkpoint\nfrom timm.models.layers import DropPath, to_2tuple\nfrom timm.models.registry import register_model\nfrom mmdet.models.builder import MODELS\nfrom mmdet.utils import get_root_logger\ntry:\n    from mmcv.runner import load_checkpoint\nexcept:\n    from mmengine.runner import load_checkpoint\n    \ndef get_conv2d(in_channels, \n               out_channels, \n               kernel_size, \n               stride, \n               padding, \n               dilation, \n               groups, \n               bias,\n               attempt_use_lk_impl=True):\n    \n    kernel_size = to_2tuple(kernel_size)\n    if padding is None:\n        padding = (kernel_size[0] // 2, kernel_size[1] // 2)\n    else:\n        padding = to_2tuple(padding)\n    need_large_impl = kernel_size[0] == kernel_size[1] and kernel_size[0] > 5 and padding == (kernel_size[0] // 2, kernel_size[1] // 2)\n\n    if attempt_use_lk_impl and need_large_impl:\n        print('---------------- trying to import iGEMM implementation for large-kernel conv')\n        try:\n            from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM\n            print('---------------- found iGEMM implementation ')\n        except:\n            DepthWiseConv2dImplicitGEMM = None\n            print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')\n        if DepthWiseConv2dImplicitGEMM is not None and need_large_impl and in_channels == out_channels \\\n                and out_channels == groups and stride == 1 and dilation == 1:\n            print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')\n            return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)\n    \n    return nn.Conv2d(in_channels, out_channels, \n                     kernel_size=kernel_size, \n                     stride=stride,\n                     padding=padding, \n                     dilation=dilation, \n                     groups=groups, \n                     bias=bias)\n\n\ndef get_bn(dim, use_sync_bn=False):\n    if use_sync_bn:\n        return nn.SyncBatchNorm(dim)\n    else:\n        return nn.BatchNorm2d(dim)\n\n\ndef fuse_bn(conv, bn):\n    conv_bias = 0 if conv.bias is None else conv.bias\n    std = (bn.running_var + bn.eps).sqrt()\n    return conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1), bn.bias + (conv_bias - bn.running_mean) * bn.weight / std\n\ndef convert_dilated_to_nondilated(kernel, dilate_rate):\n    identity_kernel = torch.ones((1, 1, 1, 1)).to(kernel.device)\n    if kernel.size(1) == 1:\n        #   This is a DW kernel\n        dilated = F.conv_transpose2d(kernel, identity_kernel, stride=dilate_rate)\n        return dilated\n    else:\n        #   This is a dense or group-wise (but not DW) kernel\n        slices = []\n        for i in range(kernel.size(1)):\n            dilated = F.conv_transpose2d(kernel[:,i:i+1,:,:], identity_kernel, stride=dilate_rate)\n            slices.append(dilated)\n        return torch.cat(slices, dim=1)\n\ndef merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):\n    large_k = 
large_kernel.size(2)\n    dilated_k = dilated_kernel.size(2)\n    equivalent_kernel_size = dilated_r * (dilated_k - 1) + 1\n    equivalent_kernel = convert_dilated_to_nondilated(dilated_kernel, dilated_r)\n    rows_to_pad = large_k // 2 - equivalent_kernel_size // 2\n    merged_kernel = large_kernel + F.pad(equivalent_kernel, [rows_to_pad] * 4)\n    return merged_kernel\n\n\ndef stem(in_chans=3, embed_dim=96):\n    return nn.Sequential(\n        nn.Conv2d(in_chans, embed_dim//2, kernel_size=3, stride=2, padding=1, bias=False),\n        nn.BatchNorm2d(embed_dim//2),\n        nn.GELU(),\n        nn.Conv2d(embed_dim//2, embed_dim//2, kernel_size=3, padding=1, bias=False),\n        nn.BatchNorm2d(embed_dim//2),\n        nn.GELU(),\n        nn.Conv2d(embed_dim//2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False),\n        nn.BatchNorm2d(embed_dim),\n        nn.GELU(),\n        nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1, bias=False),\n        nn.BatchNorm2d(embed_dim)\n    )\n\n\ndef downsample(in_dim, out_dim):\n    return nn.Sequential(\n        nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False),\n        nn.BatchNorm2d(out_dim),\n    )        \n\n\nclass SEModule(nn.Module):\n    def __init__(self, dim, red=8, inner_act=nn.GELU, out_act=nn.Sigmoid):\n        super().__init__()\n        inner_dim = max(16, dim // red)\n        self.proj = nn.Sequential(\n            nn.AdaptiveAvgPool2d(1),\n            nn.Conv2d(dim, inner_dim, kernel_size=1),\n            inner_act(),\n            nn.Conv2d(inner_dim, dim, kernel_size=1),\n            out_act(),\n        )\n        \n    def forward(self, x):\n        x = x * self.proj(x)\n        return x\n\n\n\nclass LayerScale(nn.Module):\n    def __init__(self, dim, init_value=1e-5):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(dim, 1, 1, 1)*init_value, \n                                   requires_grad=True)\n        self.bias = nn.Parameter(torch.zeros(dim), requires_grad=True)\n\n    def forward(self, x):\n        x = F.conv2d(x, weight=self.weight, bias=self.bias, groups=x.shape[1])\n        return x\n\n        \nclass LayerNorm2d(nn.LayerNorm):\n    def __init__(self, dim):\n        super().__init__(normalized_shape=dim, eps=1e-6)\n    \n    def forward(self, x):\n        x = rearrange(x, 'b c h w -> b h w c')\n        x = super().forward(x)\n        x = rearrange(x, 'b h w c -> b c h w')\n        return x.contiguous()\n\n\nclass GRN(nn.Module):\n    \"\"\" GRN (Global Response Normalization) layer\n    Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808)\n    This implementation is more efficient than the original (https://github.com/facebookresearch/ConvNeXt-V2)\n    We assume the inputs to this layer are (N, C, H, W)\n    \"\"\"\n    def __init__(self, dim, use_bias=True):\n        super().__init__()\n        self.use_bias = use_bias\n        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))\n        if self.use_bias:\n            self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))\n\n    def forward(self, x):\n        Gx = torch.norm(x, p=2, dim=(-1, -2), keepdim=True)\n        Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)\n        if self.use_bias:\n            return (self.gamma * Nx + 1) * x + self.beta\n        else:\n            return (self.gamma * Nx + 1) * x\n    \n\n\nclass DilatedReparamBlock(nn.Module):\n    \"\"\"\n    Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet)\n    We assume the inputs to 
this block are (N, C, H, W)\n    \"\"\"\n    def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True):\n        super().__init__()\n        self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,\n                                    padding=kernel_size//2, dilation=1, groups=channels, bias=deploy,\n                                    attempt_use_lk_impl=attempt_use_lk_impl)\n        self.attempt_use_lk_impl = attempt_use_lk_impl\n\n        #   Default settings. We did not tune them carefully. Different settings may work better.\n        if kernel_size == 19:\n            self.kernel_sizes = [5, 7, 9, 9, 3, 3, 3]\n            self.dilates = [1, 1, 1, 2, 4, 5, 7]\n        elif kernel_size == 17:\n            self.kernel_sizes = [5, 7, 9, 3, 3, 3]\n            self.dilates = [1, 1, 2, 4, 5, 7]\n        elif kernel_size == 15:\n            self.kernel_sizes = [5, 7, 7, 3, 3, 3]\n            self.dilates = [1, 1, 2, 3, 5, 7]\n        elif kernel_size == 13:\n            self.kernel_sizes = [5, 7, 7, 3, 3, 3]\n            self.dilates = [1, 1, 2, 3, 4, 5]\n        elif kernel_size == 11:\n            self.kernel_sizes = [5, 7, 5, 3, 3, 3]\n            self.dilates = [1, 1, 2, 3, 4, 5]\n        elif kernel_size == 9:\n            self.kernel_sizes = [5, 7, 5, 3, 3]\n            self.dilates = [1, 1, 2, 3, 4]\n        elif kernel_size == 7:\n            self.kernel_sizes = [5, 3, 3, 3]\n            self.dilates = [1, 1, 2, 3]\n        elif kernel_size == 5:\n            self.kernel_sizes = [3, 3]\n            self.dilates = [1, 2]\n        else:\n            raise ValueError('Dilated Reparam Block requires kernel_size >= 5')\n\n        if not deploy:\n            self.origin_bn = get_bn(channels, use_sync_bn)\n            for k, r in zip(self.kernel_sizes, self.dilates):\n                self.__setattr__('dil_conv_k{}_{}'.format(k, r),\n                                 nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,\n                                           padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,\n                                           bias=False))\n                self.__setattr__('dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn))\n\n    def forward(self, x):\n        if not hasattr(self, 'origin_bn'): # deploy mode\n            return self.lk_origin(x)\n        out = self.origin_bn(self.lk_origin(x))\n        for k, r in zip(self.kernel_sizes, self.dilates):\n            conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))\n            bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))\n            out = out + bn(conv(x))\n        return out\n\n    def merge_dilated_branches(self):\n        if hasattr(self, 'origin_bn'):\n            origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn)\n            for k, r in zip(self.kernel_sizes, self.dilates):\n                conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))\n                bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))\n                branch_k, branch_b = fuse_bn(conv, bn)\n                origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r)\n                origin_b += branch_b\n            merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1,\n                                    padding=origin_k.size(2)//2, dilation=1, groups=origin_k.size(0), bias=True,\n                                    
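# merged_conv holds the single equivalent large kernel once all fused branches are merged\n                                    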
attempt_use_lk_impl=self.attempt_use_lk_impl)\n            merged_conv.weight.data = origin_k\n            merged_conv.bias.data = origin_b\n            self.lk_origin = merged_conv\n            self.__delattr__('origin_bn')\n            for k, r in zip(self.kernel_sizes, self.dilates):\n                self.__delattr__('dil_conv_k{}_{}'.format(k, r))\n                self.__delattr__('dil_bn_k{}_{}'.format(k, r))\n       \n\nclass CTXDownsample(nn.Module):\n    def __init__(self, dim, h_dim):\n        super().__init__()\n        \n        self.x_proj = nn.Sequential(\n            nn.Conv2d(dim, h_dim, kernel_size=3, stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(h_dim)\n        )\n        self.h_proj = nn.Sequential(\n            nn.Conv2d(h_dim//4, h_dim//4, kernel_size=3, stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(h_dim//4)\n        )\n\n    def forward(self, x, ctx):\n        x = self.x_proj(x)\n        ctx = self.h_proj(ctx)\n        return (x, ctx)\n\n\nclass ResDWConv(nn.Conv2d):\n    '''\n    Depthwise convolution with residual connection\n    '''\n    def __init__(self, dim, kernel_size=3):\n        super().__init__(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=dim)\n    \n    def forward(self, x):\n        x = x + super().forward(x)\n        return x\n\n\nclass RepConvBlock(nn.Module):\n\n    def __init__(self, \n                 dim=64,\n                 kernel_size=7,\n                 mlp_ratio=4,\n                 ls_init_value=None,\n                 res_scale=False,\n                 drop_path=0,\n                 norm_layer=LayerNorm2d,\n                 use_gemm=False,\n                 deploy=False,\n                 use_checkpoint=False):\n        super().__init__()\n        \n        self.res_scale = res_scale\n        self.use_checkpoint = use_checkpoint\n        \n        mlp_dim = int(dim*mlp_ratio)\n        \n        self.dwconv = ResDWConv(dim, kernel_size=3)\n    \n        self.proj = nn.Sequential(\n            norm_layer(dim),\n            DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),\n            nn.BatchNorm2d(dim),\n            SEModule(dim),\n            nn.Conv2d(dim, mlp_dim, kernel_size=1),\n            nn.GELU(),\n            ResDWConv(mlp_dim, kernel_size=3),\n            GRN(mlp_dim),\n            nn.Conv2d(mlp_dim, dim, kernel_size=1),\n            DropPath(drop_path) if drop_path > 0 else nn.Identity(),\n        )\n\n        self.ls = LayerScale(dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()\n        \n    def forward_features(self, x):\n        \n        x = self.dwconv(x)\n        \n        if self.res_scale:\n            x = self.ls(x) + self.proj(x)\n        else:\n            drop_path = self.proj[-1]\n            x = x + drop_path(self.ls(self.proj[:-1](x)))\n\n        return x\n    \n    def forward(self, x):\n        \n        if self.use_checkpoint and x.requires_grad:\n            x = checkpoint(self.forward_features, x, use_reentrant=False)\n        else:\n            x = self.forward_features(x)\n        \n        return x\n\n\nclass DynamicConvBlock(nn.Module):\n    def __init__(self,\n                 dim=64,\n                 ctx_dim=32,\n                 kernel_size=7,\n                 smk_size=5,\n                 num_heads=2,\n                 mlp_ratio=4,\n                 ls_init_value=None,\n                 res_scale=False,\n                 drop_path=0,\n                 
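# ctx_dim carries the top-level context channels; smk_size / kernel_size set the small and large dynamic-kernel branches\n                 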
norm_layer=LayerNorm2d,\n                 is_first=False,\n                 is_last=False,\n                 use_gemm=False,\n                 deploy=False,\n                 use_checkpoint=False,\n                 **kwargs):\n        \n        super().__init__()\n        \n        ctx_dim = ctx_dim // 4\n        out_dim = dim + ctx_dim\n        mlp_dim = int(dim*mlp_ratio)\n        self.kernel_size = kernel_size\n        self.res_scale = res_scale\n        self.use_gemm = use_gemm\n        self.smk_size = smk_size\n        self.num_heads = num_heads * 2\n        head_dim = dim // self.num_heads\n        self.scale = head_dim ** -0.5\n        self.is_first = is_first\n        self.is_last = is_last\n        self.use_checkpoint = use_checkpoint\n\n        if not is_first:\n            self.x_scale = LayerScale(ctx_dim, init_value=1)\n            self.h_scale = LayerScale(ctx_dim, init_value=1)\n        \n        self.dwconv1 = ResDWConv(out_dim, kernel_size=3)\n        self.norm1 = norm_layer(out_dim)\n        \n        self.fusion = nn.Sequential(\n            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, groups=out_dim),\n            nn.BatchNorm2d(out_dim),\n            nn.GELU(),\n            nn.Conv2d(out_dim, dim, kernel_size=1),\n            GRN(dim),\n        )\n        \n        self.weight_query = nn.Sequential(\n            nn.Conv2d(dim, dim//2, kernel_size=1, bias=False),\n            nn.BatchNorm2d(dim//2),\n        )\n         \n        self.weight_key = nn.Sequential(\n            nn.AdaptiveAvgPool2d(7),\n            nn.Conv2d(ctx_dim, dim//2, kernel_size=1, bias=False),\n            nn.BatchNorm2d(dim//2),\n        )\n        \n        self.weight_proj = nn.Conv2d(49, kernel_size**2 + smk_size**2, kernel_size=1)\n        \n        self.dyconv_proj = nn.Sequential(\n            nn.Conv2d(dim, dim, kernel_size=1, bias=False),\n            nn.BatchNorm2d(dim),\n        )\n        \n        self.lepe = nn.Sequential(\n            DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),\n            nn.BatchNorm2d(dim),\n        )\n        \n        self.se_layer = SEModule(dim)\n        \n        self.gate = nn.Sequential(\n            nn.Conv2d(dim, dim, kernel_size=1, bias=False),\n            nn.BatchNorm2d(dim),\n            nn.SiLU(),\n        )\n\n        self.proj = nn.Sequential(\n            nn.BatchNorm2d(dim),\n            nn.Conv2d(dim, out_dim, kernel_size=1),\n        )\n        \n        self.dwconv2 = ResDWConv(out_dim, kernel_size=3)\n        self.norm2 = norm_layer(out_dim)\n        \n        self.mlp = nn.Sequential(\n            nn.Conv2d(out_dim, mlp_dim, kernel_size=1),\n            nn.GELU(),\n            ResDWConv(mlp_dim, kernel_size=3),\n            GRN(mlp_dim),\n            nn.Conv2d(mlp_dim, out_dim, kernel_size=1),\n        )\n        \n        self.ls1 = LayerScale(out_dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()\n        self.ls2 = LayerScale(out_dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()\n        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()\n        \n        self.get_rpb()\n\n\n    def get_rpb(self):\n        self.rpb_size1 = 2 * self.smk_size - 1\n        self.rpb1 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size1, self.rpb_size1))\n        self.rpb_size2 = 2 * self.kernel_size - 1\n        self.rpb2 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size2, 
self.rpb_size2))\n        nn.init.zeros_(self.rpb1)\n        nn.init.zeros_(self.rpb2)\n    \n        \n    @torch.no_grad()\n    def generate_idx(self, kernel_size):\n        rpb_size = 2 * kernel_size - 1\n        idx_h = torch.arange(0, kernel_size)\n        idx_w = torch.arange(0, kernel_size)\n        idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).view(-1)\n        return (idx_h, idx_w, idx_k)\n    \n\n    def apply_rpb(self, attn, rpb, height, width, kernel_size, idx_h, idx_w, idx_k):\n        \"\"\"\n        RPB implementation directly borrowed from https://tinyurl.com/mrbub4t3\n        \"\"\"\n        num_repeat_h = torch.ones(kernel_size, dtype=torch.long)\n        num_repeat_w = torch.ones(kernel_size, dtype=torch.long)\n        num_repeat_h[kernel_size//2] = height - (kernel_size-1)\n        num_repeat_w[kernel_size//2] = width - (kernel_size-1)\n        bias_hw = (idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2*kernel_size-1)) + idx_w.repeat_interleave(num_repeat_w)\n        bias_idx = bias_hw.unsqueeze(-1) + idx_k\n        bias_idx = bias_idx.reshape(-1, int(kernel_size**2))\n        bias_idx = torch.flip(bias_idx, [0])\n        rpb = torch.flatten(rpb, 1, 2)[:, bias_idx]\n        rpb = rpb.reshape(1, int(self.num_heads), int(height), int(width), int(kernel_size**2))\n        return attn + rpb\n    \n\n    def _forward_inner(self, x, h_x, h_r):\n        input_resolution = x.shape[2:]\n        B, C, H, W = x.shape\n        B, C_h, H_h, W_h = h_x.shape\n        \n        if not self.is_first:\n            h_x = self.x_scale(h_x) + self.h_scale(h_r)\n\n        x_f = torch.cat([x, h_x], dim=1)\n        x_f = self.dwconv1(x_f)\n        identity = x_f\n        x_f = self.norm1(x_f)\n        x = self.fusion(x_f)\n        gate = self.gate(x)\n        lepe = self.lepe(x)\n        \n        is_pad = False\n        if min(H, W) < self.kernel_size:\n            is_pad = True\n            if H < W:\n                size = (self.kernel_size, int(self.kernel_size / H * W))\n            else:\n                size = (int(self.kernel_size / W * H), self.kernel_size)\n                \n            x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)\n            x_f = F.interpolate(x_f, size=size, mode='bilinear', align_corners=False)\n            H, W = size\n\n        query, key = torch.split(x_f, split_size_or_sections=[C, C_h], dim=1)\n        query = self.weight_query(query) * self.scale\n        key = self.weight_key(key)\n        query = rearrange(query, 'b (g c) h w -> b g c (h w)', g=self.num_heads)\n        key = rearrange(key, 'b (g c) h w -> b g c (h w)', g=self.num_heads)\n        weight = einsum(query, key, 'b g c n, b g c l -> b g n l')\n        weight = rearrange(weight, 'b g n l -> b l g n').contiguous()\n        weight = self.weight_proj(weight)\n        weight = rearrange(weight, 'b l g (h w) -> b g h w l', h=H, w=W)\n\n        attn1, attn2 = torch.split(weight, split_size_or_sections=[self.smk_size**2, self.kernel_size**2], dim=-1)\n        rpb1_idx = self.generate_idx(self.smk_size)\n        rpb2_idx = self.generate_idx(self.kernel_size)\n        attn1 = self.apply_rpb(attn1, self.rpb1, H, W, self.smk_size, *rpb1_idx)\n        attn2 = self.apply_rpb(attn2, self.rpb2, H, W, self.kernel_size, *rpb2_idx)\n        attn1 = torch.softmax(attn1, dim=-1)\n        attn2 = torch.softmax(attn2, dim=-1)\n        value = rearrange(x, 'b (m g c) h w -> m b g h w c', m=2, g=self.num_heads)\n\n        x1 = na2d_av(attn1, value[0], kernel_size=self.smk_size)\n
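        # na2d_av from NATTEN applies the softmaxed dynamic weights as neighborhood attention (small kernel above, large kernel below)\n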
        x2 = na2d_av(attn2, value[1], kernel_size=self.kernel_size)\n\n        x = torch.cat([x1, x2], dim=1)\n        x = rearrange(x, 'b g h w c -> b (g c) h w', h=H, w=W)\n        \n        if is_pad:\n            x = F.adaptive_avg_pool2d(x, input_resolution)\n        \n        x = self.dyconv_proj(x)\n\n        x = x + lepe\n        x = self.se_layer(x)\n\n        x = gate * x\n        x = self.proj(x)\n\n        if self.res_scale:\n            x = self.ls1(identity) + self.drop_path(x)\n        else:\n            x = identity + self.drop_path(self.ls1(x))\n        \n        x = self.dwconv2(x)\n         \n        if self.res_scale:\n            x = self.ls2(x) + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))\n\n        if self.is_last:\n            return (x, None)\n        else:\n            l_x, h_x = torch.split(x, split_size_or_sections=[C, C_h], dim=1)\n            return (l_x, h_x)\n    \n    def forward(self, x, h_x, h_r):\n        if self.use_checkpoint and x.requires_grad:\n            x = checkpoint(self._forward_inner, x, h_x, h_r, use_reentrant=False)\n        else:\n            x = self._forward_inner(x, h_x, h_r)\n        return x\n\n\nclass OverLoCK(nn.Module):\n    '''\n    An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels\n    https://arxiv.org/abs/2502.20087\n    '''\n    def __init__(self, \n                 depth=[2, 2, 2, 2],\n                 sub_depth=[4, 2],\n                 in_chans=3, \n                 embed_dim=[96, 192, 384, 768],\n                 kernel_size=[7, 7, 7, 7],\n                 mlp_ratio=[4, 4, 4, 4],\n                 sub_mlp_ratio=[4, 4],\n                 sub_num_heads=[4, 8],\n                 ls_init_value=[None, None, 1, 1],\n                 res_scale=True,\n                 smk_size=5,\n                 deploy=False,\n                 use_gemm=True,\n                 use_ds=True,\n                 drop_rate=0,\n                 drop_path_rate=0,\n                 norm_layer=LayerNorm2d,\n                 projection=1024,\n                 num_classes=1000,\n                 use_checkpoint=[0, 0, 0, 0],\n            ):\n \n        super().__init__()\n        \n        fusion_dim = embed_dim[-1] + embed_dim[-1]//4\n        # self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim\n\n        self.patch_embed1 = stem(in_chans, embed_dim[0])\n        self.patch_embed2 = downsample(embed_dim[0], embed_dim[1])\n        self.patch_embed3 = downsample(embed_dim[1], embed_dim[2])\n        self.patch_embed4 = downsample(embed_dim[2], embed_dim[3])\n        self.high_level_proj = nn.Conv2d(embed_dim[-1], embed_dim[-1]//4, kernel_size=1)\n        self.patch_embedx = CTXDownsample(embed_dim[2], embed_dim[3])\n        \n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth) + sum(sub_depth))]\n\n        self.blocks1 = nn.ModuleList()\n        self.blocks2 = nn.ModuleList()\n        self.blocks3 = nn.ModuleList()\n        self.blocks4 = nn.ModuleList()\n        self.sub_blocks3 = nn.ModuleList()\n        self.sub_blocks4 = nn.ModuleList()\n        \n        for i in range(depth[0]):\n            self.blocks1.append(\n                RepConvBlock(\n                    dim=embed_dim[0],\n                    kernel_size=kernel_size[0],\n                    mlp_ratio=mlp_ratio[0],\n                    ls_init_value=ls_init_value[0],\n                    res_scale=res_scale,\n                    
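# drop-path rates in dpr increase linearly with block depth\n                    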
drop_path=dpr[i],\n                    norm_layer=norm_layer,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[0]),\n                )\n            )\n        \n        for i in range(depth[1]):\n            self.blocks2.append(\n                RepConvBlock(\n                    dim=embed_dim[1],\n                    kernel_size=kernel_size[1],\n                    mlp_ratio=mlp_ratio[1],\n                    ls_init_value=ls_init_value[1],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+depth[0]],\n                    norm_layer=norm_layer,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[1]),\n                )\n            )\n            \n        for i in range(depth[2]):\n            self.blocks3.append(\n                RepConvBlock(\n                    dim=embed_dim[2],\n                    kernel_size=kernel_size[2],\n                    mlp_ratio=mlp_ratio[2],\n                    ls_init_value=ls_init_value[2],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+sum(depth[:2])],\n                    norm_layer=norm_layer,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[2]),\n                )\n            )\n\n        for i in range(depth[3]):\n            self.blocks4.append(\n                RepConvBlock(\n                    dim=embed_dim[3],\n                    kernel_size=kernel_size[3],\n                    mlp_ratio=mlp_ratio[3],\n                    ls_init_value=ls_init_value[3],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+sum(depth[:3])],\n                    norm_layer=norm_layer,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[3]),\n                )\n            )\n            \n        for i in range(sub_depth[0]):\n            self.sub_blocks3.append(\n                DynamicConvBlock(\n                    dim=embed_dim[2],\n                    ctx_dim=embed_dim[-1],\n                    kernel_size=kernel_size[2],\n                    num_heads=sub_num_heads[0],\n                    pool_size=7,\n                    mlp_ratio=sub_mlp_ratio[0],\n                    ls_init_value=ls_init_value[2],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+sum(depth)],\n                    norm_layer=norm_layer,\n                    smk_size=smk_size,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    is_first=(i==0),\n                    use_checkpoint=(i<use_checkpoint[2]),\n                )\n            )\n        \n        for i in range(sub_depth[1]):\n            self.sub_blocks4.append(\n                DynamicConvBlock(\n                    dim=embed_dim[3],\n                    ctx_dim=embed_dim[-1],\n                    kernel_size=kernel_size[-1],\n                    num_heads=sub_num_heads[1],\n                    pool_size=7,\n                    mlp_ratio=sub_mlp_ratio[1],\n                    ls_init_value=ls_init_value[3],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+sum(depth)+sub_depth[0]],\n                    norm_layer=norm_layer,\n                    smk_size=smk_size,\n                    is_first=False,\n                    
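# the final sub-block is flagged is_last and returns (x, None), dropping the context stream\n                    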
is_last=(i==sub_depth[1]-1),\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[3]),\n                )\n            )\n\n        self.h_proj = nn.Sequential(\n            nn.Conv2d(embed_dim[-1], fusion_dim, kernel_size=1),\n            LayerScale(fusion_dim, init_value=1e-5),\n        )\n        \n        # The ImageNet classification heads are omitted in this detection\n        # backbone, which only returns multi-scale features (use_ds,\n        # projection, and num_classes are kept for config compatibility).\n        \n        self.extra_norm = nn.ModuleList()\n        for idx in range(4):\n            dim = embed_dim[idx]\n            if idx >= 2:\n                dim = dim + embed_dim[-1]//4\n            self.extra_norm.append(norm_layer(dim))\n        self.extra_norm.append(norm_layer(embed_dim[-1]))\n        \n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d)):\n            nn.init.trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.BatchNorm1d)):\n            nn.init.constant_(m.weight, 1.0)\n            nn.init.constant_(m.bias, 0)\n    \n    def _convert_sync_batchnorm(self):\n        # convert_sync_batchnorm returns a new module, so the converted model\n        # must be handed back to the caller; rebinding `self` would be a no-op\n        if torch.distributed.is_initialized():\n            return nn.SyncBatchNorm.convert_sync_batchnorm(self)\n        return self\n\n    def forward_pre_features(self, x):\n        \n        outs = []\n        \n        x = self.patch_embed1(x)\n        for blk in self.blocks1:\n            x = blk(x)\n        \n        outs.append(self.extra_norm[0](x))\n            \n        x = self.patch_embed2(x)\n        for blk in self.blocks2:\n            x = blk(x)\n\n        outs.append(self.extra_norm[1](x))\n        \n        return outs\n    \n    \n    def forward_base_features(self, x):\n        \n        x = self.patch_embed3(x)\n        for blk in self.blocks3:\n            x = blk(x)\n            \n        ctx = self.patch_embed4(x)\n        for blk in self.blocks4:\n            ctx = blk(ctx)\n\n        return (x, ctx)\n    \n\n    def forward_sub_features(self, x, ctx):\n        \n        outs = []\n        \n        ctx_cls = ctx\n        ctx_ori = self.high_level_proj(ctx)\n        ctx_up = F.interpolate(ctx_ori, size=x.shape[2:], mode='bilinear', align_corners=False)\n        \n        for idx, blk in enumerate(self.sub_blocks3):\n            if idx == 0:\n                ctx = ctx_up\n            x, ctx = blk(x, ctx, ctx_up)\n        \n        outs.append(self.extra_norm[2](torch.cat([x, ctx], dim=1)))\n        \n        x, ctx = self.patch_embedx(x, ctx)\n        for idx, blk in enumerate(self.sub_blocks4):\n            x, ctx = blk(x, ctx, ctx_ori)\n        \n        ctx = self.extra_norm[-1](ctx_cls)\n        x = self.extra_norm[3](x) + self.h_proj(ctx)\n        \n        outs.append(x)\n        \n        return outs\n\n    
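# Minimal usage sketch (assuming the natten / mmengine dependencies imported above are available):\n    #   model = overlock_t(pretrained=False)\n    #   x0, x1, x2, x3 = model(torch.randn(1, 3, 224, 224))  # four multi-scale feature maps\n\n    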
def forward_features(self, x):\n        \n        x0, x1 = self.forward_pre_features(x)\n        x, ctx = self.forward_base_features(x1)\n        x2, x3 = self.forward_sub_features(x, ctx)\n\n        return (x0, x1, x2, x3)\n\n    def forward(self, x):\n        \n        x = self.forward_features(x)\n        \n        return x\n\n\n@MODELS.register_module()\ndef overlock_xt(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model = OverLoCK(\n        depth=[2, 2, 3, 2],\n        sub_depth=[6, 2],\n        embed_dim=[56, 112, 256, 336],\n        kernel_size=[17, 15, 13, 7],\n        mlp_ratio=[4, 4, 4, 4],\n        sub_num_heads=[4, 6],\n        sub_mlp_ratio=[3, 3],\n        **kwargs\n    )\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_xt_in1k_224.pth'\n        logger = get_root_logger()\n        load_checkpoint(model, pretrained, logger=logger)\n    model = model._convert_sync_batchnorm()\n    return model\n\n\n@MODELS.register_module()\ndef overlock_t(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model = OverLoCK(\n        depth=[4, 4, 6, 2],\n        sub_depth=[12, 2],\n        embed_dim=[64, 128, 256, 512],\n        kernel_size=[17, 15, 13, 7],\n        mlp_ratio=[4, 4, 4, 4],\n        sub_num_heads=[4, 8],\n        sub_mlp_ratio=[3, 3],\n        **kwargs\n    )\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224.pth'\n        logger = get_root_logger()\n        load_checkpoint(model, pretrained, logger=logger)\n    model = model._convert_sync_batchnorm()\n    return model\n\n\n@MODELS.register_module()\ndef overlock_s(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model = OverLoCK(\n        depth=[6, 6, 8, 3],\n        sub_depth=[16, 3],\n        embed_dim=[64, 128, 320, 512],\n        kernel_size=[17, 15, 13, 7],\n        mlp_ratio=[4, 4, 4, 4],\n        sub_num_heads=[8, 16],\n        sub_mlp_ratio=[3, 3],\n        **kwargs\n    )\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224.pth'\n        logger = get_root_logger()\n        load_checkpoint(model, pretrained, logger=logger)\n    model = model._convert_sync_batchnorm()\n    return model\n\n\n@MODELS.register_module()\ndef overlock_b(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model = OverLoCK(\n        depth=[8, 8, 10, 4],\n        sub_depth=[20, 4],\n        embed_dim=[80, 160, 384, 576],\n        kernel_size=[17, 15, 13, 7],\n        mlp_ratio=[4, 4, 4, 4],\n        sub_num_heads=[6, 9],\n        sub_mlp_ratio=[3, 3],\n        **kwargs\n    )\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224.pth'\n        logger = get_root_logger()\n        load_checkpoint(model, pretrained, logger=logger)\n    model = model._convert_sync_batchnorm()\n    return model\n"
  },
  {
    "path": "detection/readme.md",
    "content": "# Applying OverLoCK to Object Detection and Instance Segmentation\n\n## 1. Requirements\n\n```\npip install mmcv-full==1.7.2 --no-cache-dir\npip install mmdet==2.28.2 --no-cache-dir\n```\n💡 To enable torch>=2.1.0 to support mmcv 1.7.2, you need to make the following changes:  \n> 1️⃣ https://goo.su/XhU5vWr     \n> 2️⃣ https://goo.su/ogm4yO\n\n\n## 2. Data Preparation\n\nPrepare COCO 2017 according to the [guidelines](https://github.com/open-mmlab/mmdetection/blob/2.x/docs/en/1_exist_data_model.md). \n\n## 3. Main Results on COCO using Mask R-CNN framework\n\n|    Backbone   |   Pretrain  | Schedule | AP_b | AP_m | Config | Download |\n|:-------------:|:-----------:|:--------:|--------|:-------:|:------:|:----------:|\n| OverLoCK-T | [ImageNet-1K](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224.pth)|    1x    |  48.3  |43.3     |[config](configs/maskrcnn_overlock/mask_rcnn_overlock_t_in1k_fpn_1x_coco.py)        |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/maskrcnn1x_overlock_tiny_coco.pth)          |\n|               |             |    3x    |49.6        |43.9      |[config](configs/maskrcnn_overlock/mask_rcnn_overlock_t_in1k_fpn_3x_coco.py)        |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/maskrcnn3x_overlock_tiny_coco.pth)          |\n| OverLoCK-S | [ImageNet-1K](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224.pth)|    1x    |49.4        |44.0         |[config](configs/maskrcnn_overlock/mask_rcnn_overlock_s_in1k_fpn_1x_coco.py)        |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/maskrcnn1x_overlock_small_coco.pth)           |\n|               |             |    3x    |51.0        |45.0         |[config](configs/maskrcnn_overlock/mask_rcnn_overlock_s_in1k_fpn_3x_coco.py)        |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/maskrcnn3x_overlock_small_coco.pth)          |\n| OverLoCK-B | [ImageNet-1K](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224.pth) |    1x    |49.9       |44.4         |[config](configs/maskrcnn_overlock/mask_rcnn_overlock_b_in1k_fpn_1x_coco.py)        |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/maskrcnn1x_overlock_base_coco.pth)           |\n|               |             |    3x    |51.4       |45.3         |[config](configs/maskrcnn_overlock/mask_rcnn_overlock_b_in1k_fpn_3x_coco.py)        |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/maskrcnn3x_overlock_base_coco.pth)          |\n\n## 4. Train\nTo train ``OverLoCK-T + Mask R-CNN 1x`` model on COCO dataset with 8 GPUs (single node), run:\n```\nNUM_GPUS=8\nCONFIG=configs/maskrcnn_overlock/mask_rcnn_overlock_t_in1k_fpn_1x_coco.py\nbash scripts/dist_train.sh $CONFIG $NUM_GPUS\n```\n\n## 5. Validation\nTo evaluate ``OverLoCK-T + Mask R-CNN 1x`` model on COCO dataset, run:\n```\nNUM_GPUS=8\nCKPT=path-to-checkpoint.pth\nCONFIG=configs/maskrcnn_overlock/mask_rcnn_overlock_t_in1k_fpn_1x_coco.py\nbash scripts/dist_test.sh $CONFIG $CKPT $NUM_GPUS --eval bbox segm\n```\n\n## Citation\nIf you find this project useful for your research, please consider citing:\n```\n@inproceedings{lou2025overlock,\n  title={OverLoCK: An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels},\n  author={Lou, Meng and Yu, Yizhou},\n  booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n  pages={128--138},\n  year={2025}\n}\n```\n"
  },
  {
    "path": "detection/scripts/dist_test.sh",
    "content": "#!/usr/bin/env bash\nCONFIG=$1\nCHECKPOINT=$2\nGPUS=$3\nPORT=$((RANDOM+10000))\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\ntorchrun --nproc_per_node=$GPUS --master_port=$PORT test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}"
  },
  {
    "path": "detection/scripts/dist_train.sh",
    "content": "#!/usr/bin/env bash\nCONFIG=$1\nGPUS=$2\nPORT=$((RANDOM+10000))\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\ntorchrun --nproc_per_node=$GPUS --master_port=$PORT train.py $CONFIG --launcher pytorch ${@:3}"
  },
  {
    "path": "detection/test.py",
    "content": "import argparse\nimport os\nimport os.path as osp\nimport time\nimport warnings\n\nimport mmcv\nimport torch\nfrom mmcv import Config, DictAction\nfrom mmcv.cnn import fuse_conv_bn\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import (get_dist_info, init_dist, load_checkpoint,\n                         wrap_fp16_model)\nfrom mmdet.apis import multi_gpu_test, single_gpu_test\nfrom mmdet.datasets import (build_dataloader, build_dataset,\n                            replace_ImageToTensor)\nfrom mmdet.models import build_detector\n\nimport models\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='MMDet test (and eval) a model')\n    parser.add_argument('config', help='test config file path')\n    parser.add_argument('checkpoint', help='checkpoint file')\n    parser.add_argument(\n        '--work-dir',\n        help='the directory to save the file containing evaluation metrics')\n    parser.add_argument('--out', help='output result file in pickle format')\n    parser.add_argument(\n        '--fuse-conv-bn',\n        action='store_true',\n        help='Whether to fuse conv and bn, this will slightly increase'\n        'the inference speed')\n    parser.add_argument('--gpu-ids',\n                        type=int,\n                        nargs='+',\n                        help='ids of gpus to use '\n                        '(only applicable to non-distributed testing)')\n    parser.add_argument(\n        '--format-only',\n        action='store_true',\n        help='Format the output results without perform evaluation. It is'\n        'useful when you want to format the result to a specific format and '\n        'submit it to the test server')\n    parser.add_argument(\n        '--eval',\n        type=str,\n        nargs='+',\n        help='evaluation metrics, which depends on the dataset, e.g., \"bbox\",'\n        ' \"segm\", \"proposal\" for COCO, and \"mAP\", \"recall\" for PASCAL VOC')\n    parser.add_argument('--show', action='store_true', help='show results')\n    parser.add_argument('--show-dir',\n                        help='directory where painted images will be saved')\n    parser.add_argument('--show-score-thr',\n                        type=float,\n                        default=0.3,\n                        help='score threshold (default: 0.3)')\n    parser.add_argument('--gpu-collect',\n                        action='store_true',\n                        help='whether to use gpu to collect results.')\n    parser.add_argument(\n        '--tmpdir',\n        help='tmp directory used for collecting results from multiple '\n        'workers, available when gpu-collect is not specified')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file. If the value to '\n        'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b '\n        'It also allows nested list/tuple values, e.g. 
key=\"[(a,b),(c,d)]\" '\n        'Note that the quotation marks are necessary and that no white space '\n        'is allowed.')\n    parser.add_argument(\n        '--options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n        'format will be kwargs for dataset.evaluate() function (deprecate), '\n        'change to --eval-options instead.')\n    parser.add_argument(\n        '--eval-options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation, the key-value pair in xxx=yyy '\n        'format will be kwargs for dataset.evaluate() function')\n    parser.add_argument('--launcher',\n                        choices=['none', 'pytorch', 'slurm', 'mpi'],\n                        default='none',\n                        help='job launcher')\n    parser.add_argument('--local_rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.options and args.eval_options:\n        raise ValueError(\n            '--options and --eval-options cannot be both '\n            'specified, --options is deprecated in favor of --eval-options')\n    if args.options:\n        warnings.warn('--options is deprecated in favor of --eval-options')\n        args.eval_options = args.options\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    assert args.out or args.eval or args.format_only or args.show \\\n        or args.show_dir, \\\n        ('Please specify at least one operation (save/eval/format/show the '\n         'results / save the results) with the argument \"--out\", \"--eval\"'\n         ', \"--format-only\", \"--show\" or \"--show-dir\"')\n\n    if args.eval and args.format_only:\n        raise ValueError('--eval and --format_only cannot be both specified')\n\n    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):\n        raise ValueError('The output file must be a pkl file.')\n\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n\n    cfg.model.pretrained = None\n    if cfg.model.get('neck'):\n        if isinstance(cfg.model.neck, list):\n            for neck_cfg in cfg.model.neck:\n                if neck_cfg.get('rfp_backbone'):\n                    if neck_cfg.rfp_backbone.get('pretrained'):\n                        neck_cfg.rfp_backbone.pretrained = None\n        elif cfg.model.neck.get('rfp_backbone'):\n            if cfg.model.neck.rfp_backbone.get('pretrained'):\n                cfg.model.neck.rfp_backbone.pretrained = None\n\n    # in case the test dataset is concatenated\n    samples_per_gpu = 1\n    if isinstance(cfg.data.test, dict):\n        cfg.data.test.test_mode = True\n        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)\n        if samples_per_gpu > 1:\n            # Replace 'ImageToTensor' to 'DefaultFormatBundle'\n            cfg.data.test.pipeline = replace_ImageToTensor(\n                cfg.data.test.pipeline)\n    elif isinstance(cfg.data.test, list):\n        for ds_cfg in cfg.data.test:\n            ds_cfg.test_mode = True\n        samples_per_gpu = max(\n            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])\n        if samples_per_gpu > 1:\n            for ds_cfg in cfg.data.test:\n               
 ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)\n\n    if args.gpu_ids is not None:\n        cfg.gpu_ids = args.gpu_ids\n    else:\n        cfg.gpu_ids = range(1)\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n        if len(cfg.gpu_ids) > 1:\n            warnings.warn(\n                f'We treat {cfg.gpu_ids} as gpu-ids, and reset to '\n                f'{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential error in '\n                'non-distributed testing.')\n            cfg.gpu_ids = cfg.gpu_ids[0:1]\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n\n    rank, _ = get_dist_info()\n    # the work_dir for eval metrics is optional; create it on rank 0 only\n    if args.work_dir is not None and rank == 0:\n        mmcv.mkdir_or_exist(osp.abspath(args.work_dir))\n        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n        json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')\n\n    # build the dataloader\n    dataset = build_dataset(cfg.data.test)\n    data_loader = build_dataloader(dataset,\n                                   samples_per_gpu=samples_per_gpu,\n                                   workers_per_gpu=cfg.data.workers_per_gpu,\n                                   dist=distributed,\n                                   shuffle=False)\n\n    # build the model and load checkpoint\n    cfg.model.train_cfg = None\n    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')\n\n    if args.fuse_conv_bn:\n        model = fuse_conv_bn(model)\n    # old versions did not save class info in checkpoints, this workaround is\n    # for backward compatibility\n    if 'CLASSES' in checkpoint.get('meta', {}):\n        model.CLASSES = checkpoint['meta']['CLASSES']\n    else:\n        model.CLASSES = dataset.CLASSES\n\n    if not distributed:\n        model = MMDataParallel(model, device_ids=cfg.gpu_ids)\n        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,\n                                  args.show_score_thr)\n    else:\n        model = MMDistributedDataParallel(\n            model.cuda(),\n            device_ids=[torch.cuda.current_device()],\n            broadcast_buffers=False)\n        outputs = multi_gpu_test(model, data_loader, args.tmpdir,\n                                 args.gpu_collect)\n\n\n    rank, _ = get_dist_info()\n    if rank == 0:\n        if args.out:\n            print(f'\\nwriting results to {args.out}')\n            mmcv.dump(outputs, args.out)\n        kwargs = {} if args.eval_options is None else args.eval_options\n        if args.format_only:\n            dataset.format_results(outputs, **kwargs)\n        if args.eval:\n            eval_kwargs = cfg.get('evaluation', {}).copy()\n            # hard-code way to remove EvalHook args\n            for key in [\n                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',\n                    'rule', 'dynamic_intervals'\n            ]:\n                eval_kwargs.pop(key, None)\n            eval_kwargs.update(dict(metric=args.eval, **kwargs))\n            metric = dataset.evaluate(outputs, **eval_kwargs)\n            print(metric)\n            metric_dict = dict(config=args.config, metric=metric)\n            if args.work_dir is not None and rank == 0:\n                
mmcv.dump(metric_dict, json_file)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "detection/train.py",
    "content": "import argparse\nimport copy\nimport os\nimport os.path as osp\nimport time\nimport warnings\n\nimport mmcv\nimport torch\nimport torch.distributed as dist\nfrom mmcv import Config, DictAction\nfrom mmcv.runner import get_dist_info, init_dist\nfrom mmcv.utils import get_git_hash\n\nfrom mmdet import __version__\nfrom mmdet.apis import init_random_seed, set_random_seed, train_detector\nfrom mmdet.datasets import build_dataset\nfrom mmdet.models import build_detector\nfrom mmdet.utils import (collect_env, get_device, get_root_logger,\n                         replace_cfg_vals, setup_multi_processes,\n                         update_data_root)\n\nimport models\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Train a detector')\n    parser.add_argument('config', help='train config file path')\n    parser.add_argument('--work-dir', help='the dir to save logs and models')\n    parser.add_argument('--resume-from',\n                        help='the checkpoint file to resume from')\n    parser.add_argument('--auto-resume',\n                        action='store_true',\n                        help='resume from the latest checkpoint automatically')\n    parser.add_argument(\n        '--no-validate',\n        action='store_true',\n        help='whether not to evaluate the checkpoint during training')\n    group_gpus = parser.add_mutually_exclusive_group()\n    group_gpus.add_argument(\n        '--gpus',\n        type=int,\n        help='(Deprecated, please use --gpu-id) number of gpus to use '\n        '(only applicable to non-distributed training)')\n    group_gpus.add_argument(\n        '--gpu-ids',\n        type=int,\n        nargs='+',\n        help='(Deprecated, please use --gpu-id) ids of gpus to use '\n        '(only applicable to non-distributed training)')\n    group_gpus.add_argument('--gpu-id',\n                            type=int,\n                            default=0,\n                            help='id of gpu to use '\n                            '(only applicable to non-distributed training)')\n    parser.add_argument('--seed', type=int, default=None, help='random seed')\n    parser.add_argument(\n        '--diff-seed',\n        action='store_true',\n        help='Whether or not set different seeds for different ranks')\n    parser.add_argument(\n        '--deterministic',\n        action='store_true',\n        help='whether to set deterministic options for CUDNN backend.')\n    parser.add_argument(\n        '--options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file (deprecate), '\n        'change to --cfg-options instead.')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='override some settings in the used config, the key-value pair '\n        'in xxx=yyy format will be merged into config file. If the value to '\n        'be overwritten is a list, it should be like key=\"[a,b]\" or key=a,b '\n        'It also allows nested list/tuple values, e.g. 
key=\"[(a,b),(c,d)]\" '\n        'Note that the quotation marks are necessary and that no white space '\n        'is allowed.')\n    parser.add_argument(\n        '--drop-path',\n        default=-1,\n        type=float,\n        help='drop-path-rate of the backbone network')\n    parser.add_argument(\n        '--freeze-bn',\n        action='store_true',\n        default=False,\n        help='freeze the BN layer of the backbone model during training')\n    parser.add_argument(\n        '--amp',\n        action='store_true',\n        default=False,\n        help='mixed precision training')\n    parser.add_argument('--launcher',\n                        choices=['none', 'pytorch', 'slurm', 'mpi'],\n                        default='none',\n                        help='job launcher')\n    parser.add_argument('--local-rank', type=int, default=0)\n    parser.add_argument('--auto-scale-lr',\n                        action='store_true',\n                        help='enable automatically scaling LR.')\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.options and args.cfg_options:\n        raise ValueError(\n            '--options and --cfg-options cannot be both '\n            'specified, --options is deprecated in favor of --cfg-options')\n    if args.options:\n        warnings.warn('--options is deprecated in favor of --cfg-options')\n        args.cfg_options = args.options\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    cfg = Config.fromfile(args.config)\n\n    # replace the ${key} with the value of cfg.key\n    cfg = replace_cfg_vals(cfg)\n\n    # update data root according to MMDET_DATASETS\n    update_data_root(cfg)\n\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n\n    if args.auto_scale_lr:\n        if 'auto_scale_lr' in cfg and \\\n                'enable' in cfg.auto_scale_lr and \\\n                'base_batch_size' in cfg.auto_scale_lr:\n            cfg.auto_scale_lr.enable = True\n        else:\n            warnings.warn('Can not find \"auto_scale_lr\" or '\n                          '\"auto_scale_lr.enable\" or '\n                          '\"auto_scale_lr.base_batch_size\" in your'\n                          ' configuration file. Please update all the '\n                          'configuration files to mmdet >= 2.24.1.')\n\n    # set multi-process settings\n    setup_multi_processes(cfg)\n\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n\n    # work_dir is determined in this priority: CLI > segment in file > filename\n    if args.work_dir is not None:\n        # update configs according to CLI args if args.work_dir is not None\n        cfg.work_dir = args.work_dir\n    elif cfg.get('work_dir', None) is None:\n        # use config filename as default work_dir if cfg.work_dir is None\n        cfg.work_dir = osp.join('./work_dirs',\n                                osp.splitext(osp.basename(args.config))[0])\n\n    if args.resume_from is not None:\n        cfg.resume_from = args.resume_from\n    cfg.auto_resume = args.auto_resume\n    if args.gpus is not None:\n        cfg.gpu_ids = range(1)\n        warnings.warn('`--gpus` is deprecated because we only support '\n                      'single GPU mode in non-distributed training. 
'\n                      'Use `gpus=1` now.')\n    if args.gpu_ids is not None:\n        cfg.gpu_ids = args.gpu_ids[0:1]\n        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '\n                      'Because we only support single GPU mode in '\n                      'non-distributed training. Use the first GPU '\n                      'in `gpu_ids` now.')\n    if args.gpus is None and args.gpu_ids is None:\n        cfg.gpu_ids = [args.gpu_id]\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n        # re-set gpu_ids with distributed training mode\n        _, world_size = get_dist_info()\n        cfg.gpu_ids = range(world_size)\n\n    # NOTE: the root logger is only created further below, so warn via the\n    # warnings module here instead of referencing `logger` before assignment\n    if args.freeze_bn:\n        try:\n            cfg.model.backbone.freeze_bn = True\n        except (AttributeError, KeyError):\n            warnings.warn('freeze_bn is not defined in the config file')\n\n    if args.drop_path >= 0:\n        try:\n            cfg.model.backbone.drop_path_rate = args.drop_path\n        except (AttributeError, KeyError):\n            warnings.warn('drop_path is not defined in the config file')\n\n    if args.amp:\n        loss_scale = 'dynamic'\n        if cfg.get('fp16', None) is None:\n            cfg.fp16 = dict(loss_scale=loss_scale)\n            # warnings.warn('fp16 is not defined in the config file')\n        else:\n            # cfg.fp16.enabled = True\n            # cfg.fp16.loss_scale = loss_scale\n            warnings.warn('fp16 has been defined in the config file')\n            \n        # cfg.optimizer_config.type = 'Fp16OptimizerHook'\n        # cfg.optimizer_config.loss_scale = loss_scale\n\n    # create work_dir\n    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))\n    # dump config\n    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))\n    # init the logger before other steps\n    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')\n    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)\n\n    # init the meta dict to record some important information such as\n    # environment info and seed, which will be logged\n    meta = dict()\n    # log env info\n    env_info_dict = collect_env()\n    env_info = '\\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])\n    dash_line = '-' * 60 + '\\n'\n    logger.info('Environment info:\\n' + dash_line + env_info + '\\n' +\n                dash_line)\n    meta['env_info'] = env_info\n    meta['config'] = cfg.pretty_text\n    # log some basic info\n    logger.info(f'Distributed training: {distributed}')\n    logger.info(f'Config:\\n{cfg.pretty_text}')\n\n    cfg.device = get_device()\n    # set random seeds\n    seed = init_random_seed(args.seed, device=cfg.device)\n    seed = seed + dist.get_rank() if args.diff_seed else seed\n    logger.info(f'Set random seed to {seed}, '\n                f'deterministic: {args.deterministic}')\n    set_random_seed(seed, deterministic=args.deterministic)\n    cfg.seed = seed\n    meta['seed'] = seed\n    meta['exp_name'] = osp.basename(args.config)\n\n    model = build_detector(cfg.model,\n                           train_cfg=cfg.get('train_cfg'),\n                           test_cfg=cfg.get('test_cfg'))\n    # model.init_weights()\n\n    logger.info(model)\n    \n    datasets = [build_dataset(cfg.data.train)]\n    if len(cfg.workflow) == 2:\n        val_dataset = 
copy.deepcopy(cfg.data.val)\n        val_dataset.pipeline = cfg.data.train.pipeline\n        datasets.append(build_dataset(val_dataset))\n    if cfg.checkpoint_config is not None:\n        # save mmdet version, config file content and class names in\n        # checkpoints as meta data\n        cfg.checkpoint_config.meta = dict(mmdet_version=__version__ +\n                                          get_git_hash()[:7],\n                                          CLASSES=datasets[0].CLASSES)\n    # add an attribute for visualization convenience\n    model.CLASSES = datasets[0].CLASSES\n\n    # torch.backends.cuda.matmul.allow_tf32 = True\n    # torch.backends.cudnn.allow_tf32 = True\n\n    train_detector(model,\n                   datasets,\n                   cfg,\n                   distributed=distributed,\n                   validate=(not args.no_validate),\n                   timestamp=timestamp,\n                   meta=meta)\n\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "models/__init__.py",
    "content": "from .overlock import overlock_xt, overlock_t, overlock_s, overlock_b"
  },
  {
    "path": "models/contmix.py",
    "content": "'''\r\nThis is a plug-and-play implementation of ContMix block in the paper:\r\nhttps://arxiv.org/abs/2502.20087\r\n'''\r\nimport warnings\r\nimport torch\r\nimport torch.nn.functional as F\r\nfrom torch import nn\r\nfrom einops import rearrange, einsum\r\nfrom timm.models.layers import DropPath, to_2tuple\r\nfrom torch.utils.checkpoint import checkpoint\r\n\r\ntry:\r\n    from natten.functional import na2d_av\r\n    has_natten = True\r\nexcept:\r\n    has_natten = False\r\n    warnings.warn(\"The efficiency may be reduced since 'natten' is not installed.\"\r\n                  \" It is recommended to install natten for better performance.\")\r\n\r\n\r\ndef get_conv2d(in_channels,\r\n               out_channels,\r\n               kernel_size,\r\n               stride,\r\n               padding,\r\n               dilation,\r\n               groups,\r\n               bias,\r\n               attempt_use_lk_impl=True):\r\n\r\n    kernel_size = to_2tuple(kernel_size)\r\n    if padding is None:\r\n        padding = (kernel_size[0] // 2, kernel_size[1] // 2)\r\n    else:\r\n        padding = to_2tuple(padding)\r\n    need_large_impl = kernel_size[0] == kernel_size[1] and kernel_size[0] > 5 and padding == (kernel_size[0] // 2, kernel_size[1] // 2)\r\n\r\n    if attempt_use_lk_impl and need_large_impl:\r\n        print('---------------- trying to import iGEMM implementation for large-kernel conv')\r\n        try:\r\n            from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM\r\n            print('---------------- found iGEMM implementation ')\r\n        except:\r\n            DepthWiseConv2dImplicitGEMM = None\r\n            print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')\r\n        if DepthWiseConv2dImplicitGEMM is not None and need_large_impl and in_channels == out_channels \\\r\n                and out_channels == groups and stride == 1 and dilation == 1:\r\n            print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')\r\n            return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)\r\n\r\n    return nn.Conv2d(in_channels, out_channels,\r\n                     kernel_size=kernel_size,\r\n                     stride=stride,\r\n                     padding=padding,\r\n                     dilation=dilation,\r\n                     groups=groups,\r\n                     bias=bias)\r\n\r\n\r\ndef get_bn(dim, use_sync_bn=False):\r\n    if use_sync_bn:\r\n        return nn.SyncBatchNorm(dim)\r\n    else:\r\n        return nn.BatchNorm2d(dim)\r\n\r\n\r\ndef fuse_bn(conv, bn):\r\n    conv_bias = 0 if conv.bias is None else conv.bias\r\n    std = (bn.running_var + bn.eps).sqrt()\r\n    return conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1), bn.bias + (conv_bias - bn.running_mean) * bn.weight / std\r\n\r\ndef convert_dilated_to_nondilated(kernel, dilate_rate):\r\n    identity_kernel = torch.ones((1, 1, 1, 1)).to(kernel.device)\r\n    if kernel.size(1) == 1:\r\n        #   This is a DW kernel\r\n        dilated = F.conv_transpose2d(kernel, identity_kernel, stride=dilate_rate)\r\n        return dilated\r\n    else:\r\n        #   This is a dense or group-wise (but not DW) kernel\r\n        slices = []\r\n        for i in range(kernel.size(1)):\r\n            dilated = F.conv_transpose2d(kernel[:,i:i+1,:,:], identity_kernel, stride=dilate_rate)\r\n            slices.append(dilated)\r\n        return torch.cat(slices, 
dim=1)\r\n\r\ndef merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):\r\n    large_k = large_kernel.size(2)\r\n    dilated_k = dilated_kernel.size(2)\r\n    equivalent_kernel_size = dilated_r * (dilated_k - 1) + 1\r\n    equivalent_kernel = convert_dilated_to_nondilated(dilated_kernel, dilated_r)\r\n    rows_to_pad = large_k // 2 - equivalent_kernel_size // 2\r\n    merged_kernel = large_kernel + F.pad(equivalent_kernel, [rows_to_pad] * 4)\r\n    return merged_kernel\r\n\r\n\r\nclass SEModule(nn.Module):\r\n    def __init__(self, dim, red=8, inner_act=nn.GELU, out_act=nn.Sigmoid):\r\n        super().__init__()\r\n        inner_dim = max(16, dim // red)\r\n        self.proj = nn.Sequential(\r\n            nn.AdaptiveAvgPool2d(1),\r\n            nn.Conv2d(dim, inner_dim, kernel_size=1),\r\n            inner_act(),\r\n            nn.Conv2d(inner_dim, dim, kernel_size=1),\r\n            out_act(),\r\n        )\r\n\r\n    def forward(self, x):\r\n        x = x * self.proj(x)\r\n        return x\r\n\r\n\r\nclass LayerScale(nn.Module):\r\n    def __init__(self, dim, init_value=1e-5):\r\n        super().__init__()\r\n        self.weight = nn.Parameter(torch.ones(dim, 1, 1, 1)*init_value,\r\n                                   requires_grad=True)\r\n        self.bias = nn.Parameter(torch.zeros(dim), requires_grad=True)\r\n\r\n    def forward(self, x):\r\n\r\n        x = F.conv2d(x, weight=self.weight, bias=self.bias, groups=x.shape[1])\r\n\r\n        return x\r\n\r\n\r\nclass LayerNorm2d(nn.LayerNorm):\r\n    def __init__(self, dim):\r\n        super().__init__(normalized_shape=dim, eps=1e-6)\r\n\r\n    def forward(self, x):\r\n        x = rearrange(x, 'b c h w -> b h w c')\r\n        x = super().forward(x)\r\n        x = rearrange(x, 'b h w c -> b c h w')\r\n        return x.contiguous()\r\n\r\n\r\n\r\nclass GRN(nn.Module):\r\n    \"\"\" GRN (Global Response Normalization) layer\r\n    Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808)\r\n    This implementation is more efficient than the original (https://github.com/facebookresearch/ConvNeXt-V2)\r\n    We assume the inputs to this layer are (N, C, H, W)\r\n    \"\"\"\r\n    def __init__(self, dim, use_bias=True):\r\n        super().__init__()\r\n        self.use_bias = use_bias\r\n        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))\r\n        if self.use_bias:\r\n            self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))\r\n\r\n\r\n    def forward(self, x):\r\n        Gx = torch.norm(x, p=2, dim=(-1, -2), keepdim=True)\r\n        Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)\r\n        if self.use_bias:\r\n            return (self.gamma * Nx + 1) * x + self.beta\r\n        else:\r\n            return (self.gamma * Nx + 1) * x\r\n\r\n\r\n\r\nclass DilatedReparamBlock(nn.Module):\r\n    \"\"\"\r\n    Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet)\r\n    We assume the inputs to this block are (N, C, H, W)\r\n    \"\"\"\r\n    def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True):\r\n        super().__init__()\r\n        self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,\r\n                                    padding=kernel_size//2, dilation=1, groups=channels, bias=deploy,\r\n                                    attempt_use_lk_impl=attempt_use_lk_impl)\r\n        self.attempt_use_lk_impl = attempt_use_lk_impl\r\n\r\n        #   Default settings. We did not tune them carefully. 
Different settings may work better.\r\n        if kernel_size == 19:\r\n            self.kernel_sizes = [5, 7, 9, 9, 3, 3, 3]\r\n            self.dilates = [1, 1, 1, 2, 4, 5, 7]\r\n        elif kernel_size == 17:\r\n            self.kernel_sizes = [5, 7, 9, 3, 3, 3]\r\n            self.dilates = [1, 1, 2, 4, 5, 7]\r\n        elif kernel_size == 15:\r\n            self.kernel_sizes = [5, 7, 7, 3, 3, 3]\r\n            self.dilates = [1, 1, 2, 3, 5, 7]\r\n        elif kernel_size == 13:\r\n            self.kernel_sizes = [5, 7, 7, 3, 3, 3]\r\n            self.dilates = [1, 1, 2, 3, 4, 5]\r\n        elif kernel_size == 11:\r\n            self.kernel_sizes = [5, 7, 5, 3, 3, 3]\r\n            self.dilates = [1, 1, 2, 3, 4, 5]\r\n        elif kernel_size == 9:\r\n            self.kernel_sizes = [5, 7, 5, 3, 3]\r\n            self.dilates = [1, 1, 2, 3, 4]\r\n        elif kernel_size == 7:\r\n            self.kernel_sizes = [5, 3, 3, 3]\r\n            self.dilates = [1, 1, 2, 3]\r\n        elif kernel_size == 5:\r\n            self.kernel_sizes = [3, 3]\r\n            self.dilates = [1, 2]\r\n        else:\r\n            raise ValueError('Dilated Reparam Block requires kernel_size >= 5')\r\n\r\n        if not deploy:\r\n            self.origin_bn = get_bn(channels, use_sync_bn)\r\n            for k, r in zip(self.kernel_sizes, self.dilates):\r\n                self.__setattr__('dil_conv_k{}_{}'.format(k, r),\r\n                                 nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,\r\n                                           padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,\r\n                                           bias=False))\r\n                self.__setattr__('dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn))\r\n\r\n    def forward(self, x):\r\n        if not hasattr(self, 'origin_bn'):      # deploy mode\r\n            return self.lk_origin(x)\r\n        out = self.origin_bn(self.lk_origin(x))\r\n        for k, r in zip(self.kernel_sizes, self.dilates):\r\n            conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))\r\n            bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))\r\n            out = out + bn(conv(x))\r\n        return out\r\n\r\n    def merge_dilated_branches(self):\r\n        if hasattr(self, 'origin_bn'):\r\n            origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn)\r\n            for k, r in zip(self.kernel_sizes, self.dilates):\r\n                conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))\r\n                bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))\r\n                branch_k, branch_b = fuse_bn(conv, bn)\r\n                origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r)\r\n                origin_b += branch_b\r\n            merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1,\r\n                                    padding=origin_k.size(2)//2, dilation=1, groups=origin_k.size(0), bias=True,\r\n                                    attempt_use_lk_impl=self.attempt_use_lk_impl)\r\n            merged_conv.weight.data = origin_k\r\n            merged_conv.bias.data = origin_b\r\n            self.lk_origin = merged_conv\r\n            self.__delattr__('origin_bn')\r\n            for k, r in zip(self.kernel_sizes, self.dilates):\r\n                self.__delattr__('dil_conv_k{}_{}'.format(k, r))\r\n                self.__delattr__('dil_bn_k{}_{}'.format(k, r))\r\n\r\n\r\nclass ResDWConv(nn.Conv2d):\r\n 
   '''\r\n    Depthwise conv with residual connection\r\n    '''\r\n    def __init__(self, dim, kernel_size=3):\r\n        super().__init__(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=dim)\r\n\r\n    def forward(self, x):\r\n        x = x + super().forward(x)\r\n        return x\r\n\r\n\r\nclass ContMixBlock(nn.Module):\r\n    '''\r\n    A plug-and-play implementation of ContMix module with FFN layer\r\n    Paper: https://arxiv.org/abs/2502.20087\r\n    '''\r\n    def __init__(self,\r\n                 dim=64,\r\n                 kernel_size=7,\r\n                 smk_size=5,\r\n                 num_heads=2,\r\n                 mlp_ratio=4,\r\n                 res_scale=False,\r\n                 ls_init_value=None,\r\n                 drop_path=0,\r\n                 norm_layer=LayerNorm2d,\r\n                 use_gemm=False,\r\n                 deploy=False,\r\n                 use_checkpoint=False,\r\n                 **kwargs):\r\n\r\n        super().__init__()\r\n        '''\r\n        Args:\r\n        kernel_size: kernel size of the main ContMix branch, default is 7\r\n        smk_size: kernel size of the secondary ContMix branch, default is 5\r\n        num_heads: number of dynamic kernel heads, default is 2\r\n        mlp_ratio: ratio of mlp hidden dim to embedding dim, default is 4\r\n        res_scale: whether to use residual layer scale, default is False\r\n        ls_init_value: layer scale init value, default is None\r\n        drop_path: drop path rate, default is 0\r\n        norm_layer: normalization layer, default is LayerNorm2d\r\n        use_gemm: whether to use iGEMM implementation for large kernel conv, default is False\r\n        deploy: whether to use deploy mode, default is False\r\n        use_checkpoint: whether to use grad checkpointing, default is False\r\n        **kwargs: other arguments\r\n        '''\r\n        mlp_dim = int(dim*mlp_ratio)\r\n        self.kernel_size = kernel_size\r\n        self.res_scale = res_scale\r\n        self.use_gemm = use_gemm\r\n        self.smk_size = smk_size\r\n        self.num_heads = num_heads * 2\r\n        head_dim = dim // self.num_heads\r\n        self.scale = head_dim ** -0.5\r\n        self.use_checkpoint = use_checkpoint\r\n\r\n        self.dwconv1 = ResDWConv(dim, kernel_size=3)\r\n        self.norm1 = norm_layer(dim)\r\n\r\n        self.weight_query = nn.Sequential(\r\n            nn.Conv2d(dim, dim//2, kernel_size=1, bias=False),\r\n            nn.BatchNorm2d(dim//2),\r\n        )\r\n        self.weight_key = nn.Sequential(\r\n            nn.AdaptiveAvgPool2d(7),\r\n            nn.Conv2d(dim, dim//2, kernel_size=1, bias=False),\r\n            nn.BatchNorm2d(dim//2),\r\n        )\r\n        self.weight_value = nn.Sequential(\r\n            nn.Conv2d(dim, dim, kernel_size=1, bias=False),\r\n            nn.BatchNorm2d(dim),\r\n        )\r\n\r\n        self.weight_proj = nn.Conv2d(49, kernel_size**2 + smk_size**2, kernel_size=1)\r\n        self.fusion_proj = nn.Sequential(\r\n            nn.Conv2d(dim, dim, kernel_size=1, bias=False),\r\n            nn.BatchNorm2d(dim),\r\n        )\r\n\r\n        self.lepe = nn.Sequential(\r\n            DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),\r\n            nn.BatchNorm2d(dim),\r\n        )\r\n        self.se_layer = SEModule(dim)\r\n        self.gate = nn.Sequential(\r\n            nn.Conv2d(dim, dim, kernel_size=1, bias=False),\r\n            nn.BatchNorm2d(dim),\r\n            
nn.SiLU(),\r\n        )\r\n\r\n        self.proj = nn.Sequential(\r\n            nn.BatchNorm2d(dim),\r\n            nn.Conv2d(dim, dim, kernel_size=1),\r\n        )\r\n\r\n        self.dwconv2 = ResDWConv(dim, kernel_size=3)\r\n        self.norm2 = norm_layer(dim)\r\n        self.mlp = nn.Sequential(\r\n            nn.Conv2d(dim, mlp_dim, kernel_size=1),\r\n            nn.GELU(),\r\n            ResDWConv(mlp_dim, kernel_size=3),\r\n            GRN(mlp_dim),\r\n            nn.Conv2d(mlp_dim, dim, kernel_size=1),\r\n        )\r\n\r\n        self.ls1 = LayerScale(dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()\r\n        self.ls2 = LayerScale(dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()\r\n        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()\r\n\r\n        self.get_rpb()\r\n\r\n    def get_rpb(self):\r\n        self.rpb_size1 = 2 * self.smk_size - 1\r\n        self.rpb1 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size1, self.rpb_size1))\r\n        self.rpb_size2 = 2 * self.kernel_size - 1\r\n        self.rpb2 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size2, self.rpb_size2))\r\n        nn.init.trunc_normal_(self.rpb1, std=0.02)\r\n        nn.init.trunc_normal_(self.rpb2, std=0.02)\r\n\r\n    @torch.no_grad()\r\n    def generate_idx(self, kernel_size):\r\n        rpb_size = 2 * kernel_size - 1\r\n        idx_h = torch.arange(0, kernel_size)\r\n        idx_w = torch.arange(0, kernel_size)\r\n        idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).view(-1)\r\n        return (idx_h, idx_w, idx_k)\r\n\r\n    def apply_rpb(self, attn, rpb, height, width, kernel_size, idx_h, idx_w, idx_k):\r\n        \"\"\"\r\n        RPB implementation directly borrowed from https://tinyurl.com/mrbub4t3\r\n        \"\"\"\r\n        num_repeat_h = torch.ones(kernel_size, dtype=torch.long)\r\n        num_repeat_w = torch.ones(kernel_size, dtype=torch.long)\r\n        num_repeat_h[kernel_size//2] = height - (kernel_size-1)\r\n        num_repeat_w[kernel_size//2] = width - (kernel_size-1)\r\n        bias_hw = (idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2*kernel_size-1)) + idx_w.repeat_interleave(num_repeat_w)\r\n        bias_idx = bias_hw.unsqueeze(-1) + idx_k\r\n        bias_idx = bias_idx.reshape(-1, int(kernel_size**2))\r\n        bias_idx = torch.flip(bias_idx, [0])\r\n        rpb = torch.flatten(rpb, 1, 2)[:, bias_idx]\r\n        rpb = rpb.reshape(1, int(self.num_heads), int(height), int(width), int(kernel_size**2))\r\n        return attn + rpb\r\n\r\n    def reparm(self):\r\n        for m in self.modules():\r\n            if isinstance(m, DilatedReparamBlock):\r\n                m.merge_dilated_branches()\r\n\r\n    def _forward_inner(self, x):\r\n        input_resolution = x.shape[2:]\r\n        B, C, H, W = x.shape\r\n\r\n        x = self.dwconv1(x)\r\n        identity = x\r\n        x = self.norm1(x)\r\n        gate = self.gate(x)\r\n        lepe = self.lepe(x)\r\n\r\n        is_pad = False\r\n        if min(H, W) < self.kernel_size:\r\n            is_pad = True\r\n            if H < W:\r\n                size = (self.kernel_size, int(self.kernel_size / H * W))\r\n            else:\r\n                size = (int(self.kernel_size / W * H), self.kernel_size)\r\n            x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)\r\n            H, W = size\r\n\r\n        query = self.weight_query(x) * self.scale\r\n        key = self.weight_key(x)\r\n        value = 
self.weight_value(x)\r\n\r\n        query = rearrange(query, 'b (g c) h w -> b g c (h w)', g=self.num_heads)\r\n        key = rearrange(key, 'b (g c) h w -> b g c (h w)', g=self.num_heads)\r\n        weight = einsum(query, key, 'b g c n, b g c l -> b g n l')\r\n        weight = rearrange(weight, 'b g n l -> b l g n').contiguous()\r\n        weight = self.weight_proj(weight)\r\n        weight = rearrange(weight, 'b l g (h w) -> b g h w l', h=H, w=W)\r\n\r\n        attn1, attn2 = torch.split(weight, split_size_or_sections=[self.smk_size**2, self.kernel_size**2], dim=-1)\r\n        rpb1_idx = self.generate_idx(self.smk_size)\r\n        rpb2_idx = self.generate_idx(self.kernel_size)\r\n        attn1 = self.apply_rpb(attn1, self.rpb1, H, W, self.smk_size, *rpb1_idx)\r\n        attn2 = self.apply_rpb(attn2, self.rpb2, H, W, self.kernel_size, *rpb2_idx)\r\n        attn1 = torch.softmax(attn1, dim=-1)\r\n        attn2 = torch.softmax(attn2, dim=-1)\r\n        value = rearrange(value, 'b (m g c) h w -> m b g h w c', m=2, g=self.num_heads)\r\n\r\n        if has_natten:\r\n            x1 = na2d_av(attn1, value[0], kernel_size=self.smk_size)\r\n            x2 = na2d_av(attn2, value[1], kernel_size=self.kernel_size)\r\n        else:\r\n            pad1 = self.smk_size // 2\r\n            pad2 = self.kernel_size // 2\r\n            H_o1 = H - 2 * pad1\r\n            W_o1 = W - 2 * pad1\r\n            H_o2 = H - 2 * pad2\r\n            W_o2 = W - 2 * pad2\r\n\r\n            v1 = rearrange(value[0], 'b g h w c -> b (g c) h w')\r\n            v2 = rearrange(value[1], 'b g h w c -> b (g c) h w')\r\n\r\n            v1 = F.unfold(v1, kernel_size=self.smk_size).reshape(B, -1, H_o1, W_o1)\r\n            v2 = F.unfold(v2, kernel_size=self.kernel_size).reshape(B, -1, H_o2, W_o2)\r\n\r\n            v1 = F.pad(v1, (pad1, pad1, pad1, pad1), mode='replicate')\r\n            v2 = F.pad(v2, (pad2, pad2, pad2, pad2), mode='replicate')\r\n\r\n            v1 = rearrange(v1, 'b (g c k) h w -> b g c h w k', g=self.num_heads, k=self.smk_size**2, h=H, w=W)\r\n            v2 = rearrange(v2, 'b (g c k) h w -> b g c h w k', g=self.num_heads, k=self.kernel_size**2, h=H, w=W)\r\n\r\n            x1 = einsum(attn1, v1, 'b g h w k, b g c h w k -> b g h w c')\r\n            x2 = einsum(attn2, v2, 'b g h w k, b g c h w k -> b g h w c')\r\n\r\n        x = torch.cat([x1, x2], dim=1)\r\n        x = rearrange(x, 'b g h w c -> b (g c) h w', h=H, w=W)\r\n\r\n        if is_pad:\r\n            x = F.adaptive_avg_pool2d(x, input_resolution)\r\n\r\n        x = self.fusion_proj(x)\r\n\r\n        x = x + lepe\r\n        x = self.se_layer(x)\r\n\r\n        x = gate * x\r\n        x = self.proj(x)\r\n\r\n        if self.res_scale:\r\n            x = self.ls1(identity) + self.drop_path(x)\r\n        else:\r\n            x = identity + self.drop_path(self.ls1(x))\r\n\r\n        x = self.dwconv2(x)\r\n\r\n        if self.res_scale:\r\n            x = self.ls2(x) + self.drop_path(self.mlp(self.norm2(x)))\r\n        else:\r\n            x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))\r\n\r\n        return x\r\n\r\n    def forward(self, x):\r\n        if self.use_checkpoint and x.requires_grad:\r\n            x = checkpoint(self._forward_inner, x, use_reentrant=False)\r\n        else:\r\n            x = self._forward_inner(x)\r\n        return x\r\n\r\n\r\nif __name__ == '__main__':\r\n\r\n    from timm.utils import random_seed\r\n    random_seed(6)\r\n\r\n    x = torch.randn(1, 64, 32, 32).cuda()\r\n    model = ContMixBlock(dim=64,\r\n          
               num_heads=2,\r\n                         kernel_size=13,\r\n                         smk_size=5,\r\n                         mlp_ratio=4,\r\n                         res_scale=True,\r\n                         ls_init_value=1,\r\n                         drop_path=0,\r\n                         norm_layer=LayerNorm2d,\r\n                         use_gemm=True,\r\n                         deploy=False,\r\n                         use_checkpoint=False)\r\n    print(model)\r\n    model.cuda()\r\n    model.eval()\r\n    y = model(x)\r\n    print(y.shape)\r\n\r\n    # Reparametrize model, more details can be found at:\r\n    # https://github.com/AILab-CVC/UniRepLKNet/tree/main\r\n    model.reparm()\r\n    z = model(x)\r\n\r\n    # Showing difference between original and reparametrized model\r\n    print((z - y).abs().sum() / y.abs().sum())"
  },
  {
    "path": "models/overlock.py",
    "content": "'''\nThis is an official implementation of OverLoCK model proposed in the paper: \nhttps://arxiv.org/abs/2502.20087\n'''\nimport torch\nimport timm\nimport torch.distributed\nimport torch.nn.functional as F\nfrom torch import nn\nfrom einops import rearrange, einsum\nfrom natten.functional import na2d_av\nfrom mmengine.runner import load_checkpoint\nfrom torch.utils.checkpoint import checkpoint\nfrom timm.models.layers import DropPath, to_2tuple\nfrom timm.models.registry import register_model\n\n\ndef get_conv2d(in_channels, \n               out_channels, \n               kernel_size, \n               stride, \n               padding, \n               dilation, \n               groups, \n               bias,\n               attempt_use_lk_impl=True):\n    \n    kernel_size = to_2tuple(kernel_size)\n    if padding is None:\n        padding = (kernel_size[0] // 2, kernel_size[1] // 2)\n    else:\n        padding = to_2tuple(padding)\n    need_large_impl = kernel_size[0] == kernel_size[1] and kernel_size[0] > 5 and padding == (kernel_size[0] // 2, kernel_size[1] // 2)\n\n    if attempt_use_lk_impl and need_large_impl:\n        print('---------------- trying to import iGEMM implementation for large-kernel conv')\n        try:\n            from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM\n            print('---------------- found iGEMM implementation ')\n        except:\n            DepthWiseConv2dImplicitGEMM = None\n            print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')\n        if DepthWiseConv2dImplicitGEMM is not None and need_large_impl and in_channels == out_channels \\\n                and out_channels == groups and stride == 1 and dilation == 1:\n            print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')\n            return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)\n    \n    return nn.Conv2d(in_channels, out_channels, \n                     kernel_size=kernel_size, \n                     stride=stride,\n                     padding=padding, \n                     dilation=dilation, \n                     groups=groups, \n                     bias=bias)\n\n\ndef get_bn(dim, use_sync_bn=False):\n    if use_sync_bn:\n        return nn.SyncBatchNorm(dim)\n    else:\n        return nn.BatchNorm2d(dim)\n\n\ndef fuse_bn(conv, bn):\n    conv_bias = 0 if conv.bias is None else conv.bias\n    std = (bn.running_var + bn.eps).sqrt()\n    return conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1), bn.bias + (conv_bias - bn.running_mean) * bn.weight / std\n\ndef convert_dilated_to_nondilated(kernel, dilate_rate):\n    identity_kernel = torch.ones((1, 1, 1, 1)).to(kernel.device)\n    if kernel.size(1) == 1:\n        #   This is a DW kernel\n        dilated = F.conv_transpose2d(kernel, identity_kernel, stride=dilate_rate)\n        return dilated\n    else:\n        #   This is a dense or group-wise (but not DW) kernel\n        slices = []\n        for i in range(kernel.size(1)):\n            dilated = F.conv_transpose2d(kernel[:,i:i+1,:,:], identity_kernel, stride=dilate_rate)\n            slices.append(dilated)\n        return torch.cat(slices, dim=1)\n\ndef merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):\n    large_k = large_kernel.size(2)\n    dilated_k = dilated_kernel.size(2)\n    equivalent_kernel_size = dilated_r * (dilated_k - 1) + 1\n    equivalent_kernel = 
convert_dilated_to_nondilated(dilated_kernel, dilated_r)\n    rows_to_pad = large_k // 2 - equivalent_kernel_size // 2\n    merged_kernel = large_kernel + F.pad(equivalent_kernel, [rows_to_pad] * 4)\n    return merged_kernel\n\n\ndef stem(in_chans=3, embed_dim=96):\n    return nn.Sequential(\n        nn.Conv2d(in_chans, embed_dim//2, kernel_size=3, stride=2, padding=1, bias=False),\n        nn.BatchNorm2d(embed_dim//2),\n        nn.GELU(),\n        nn.Conv2d(embed_dim//2, embed_dim//2, kernel_size=3, padding=1, bias=False),\n        nn.BatchNorm2d(embed_dim//2),\n        nn.GELU(),\n        nn.Conv2d(embed_dim//2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False),\n        nn.BatchNorm2d(embed_dim),\n        nn.GELU(),\n        nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1, bias=False),\n        nn.BatchNorm2d(embed_dim)\n    )\n\n\ndef downsample(in_dim, out_dim):\n    return nn.Sequential(\n        nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False),\n        nn.BatchNorm2d(out_dim),\n    )        \n\n\nclass SEModule(nn.Module):\n    def __init__(self, dim, red=8, inner_act=nn.GELU, out_act=nn.Sigmoid):\n        super().__init__()\n        inner_dim = max(16, dim // red)\n        self.proj = nn.Sequential(\n            nn.AdaptiveAvgPool2d(1),\n            nn.Conv2d(dim, inner_dim, kernel_size=1),\n            inner_act(),\n            nn.Conv2d(inner_dim, dim, kernel_size=1),\n            out_act(),\n        )\n        \n    def forward(self, x):\n        x = x * self.proj(x)\n        return x\n\n\n\nclass LayerScale(nn.Module):\n    def __init__(self, dim, init_value=1e-5):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(dim, 1, 1, 1)*init_value, \n                                   requires_grad=True)\n        self.bias = nn.Parameter(torch.zeros(dim), requires_grad=True)\n\n    def forward(self, x):\n        x = F.conv2d(x, weight=self.weight, bias=self.bias, groups=x.shape[1])\n        return x\n\n        \nclass LayerNorm2d(nn.LayerNorm):\n    def __init__(self, dim):\n        super().__init__(normalized_shape=dim, eps=1e-6)\n    \n    def forward(self, x):\n        x = rearrange(x, 'b c h w -> b h w c')\n        x = super().forward(x)\n        x = rearrange(x, 'b h w c -> b c h w')\n        return x.contiguous()\n\n\nclass GRN(nn.Module):\n    \"\"\" GRN (Global Response Normalization) layer\n    Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808)\n    This implementation is more efficient than the original (https://github.com/facebookresearch/ConvNeXt-V2)\n    We assume the inputs to this layer are (N, C, H, W)\n    \"\"\"\n    def __init__(self, dim, use_bias=True):\n        super().__init__()\n        self.use_bias = use_bias\n        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))\n        if self.use_bias:\n            self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))\n\n    def forward(self, x):\n        Gx = torch.norm(x, p=2, dim=(-1, -2), keepdim=True)\n        Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)\n        if self.use_bias:\n            return (self.gamma * Nx + 1) * x + self.beta\n        else:\n            return (self.gamma * Nx + 1) * x\n    \n\n\nclass DilatedReparamBlock(nn.Module):\n    \"\"\"\n    Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet)\n    We assume the inputs to this block are (N, C, H, W)\n    \"\"\"\n    def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True):\n       
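 # A depthwise k x k conv with dilation r is equivalent to a non-dilated conv\n        # whose kernel has size r * (k - 1) + 1 (zeros between taps, see\n        # merge_dilated_into_large_kernel above), so every branch built below\n        # can be folded into self.lk_origin by merge_dilated_branches().\n       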
 super().__init__()\n        self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,\n                                    padding=kernel_size//2, dilation=1, groups=channels, bias=deploy,\n                                    attempt_use_lk_impl=attempt_use_lk_impl)\n        self.attempt_use_lk_impl = attempt_use_lk_impl\n\n        #   Default settings. We did not tune them carefully. Different settings may work better.\n        if kernel_size == 19:\n            self.kernel_sizes = [5, 7, 9, 9, 3, 3, 3]\n            self.dilates = [1, 1, 1, 2, 4, 5, 7]\n        elif kernel_size == 17:\n            self.kernel_sizes = [5, 7, 9, 3, 3, 3]\n            self.dilates = [1, 1, 2, 4, 5, 7]\n        elif kernel_size == 15:\n            self.kernel_sizes = [5, 7, 7, 3, 3, 3]\n            self.dilates = [1, 1, 2, 3, 5, 7]\n        elif kernel_size == 13:\n            self.kernel_sizes = [5, 7, 7, 3, 3, 3]\n            self.dilates = [1, 1, 2, 3, 4, 5]\n        elif kernel_size == 11:\n            self.kernel_sizes = [5, 7, 5, 3, 3, 3]\n            self.dilates = [1, 1, 2, 3, 4, 5]\n        elif kernel_size == 9:\n            self.kernel_sizes = [5, 7, 5, 3, 3]\n            self.dilates = [1, 1, 2, 3, 4]\n        elif kernel_size == 7:\n            self.kernel_sizes = [5, 3, 3, 3]\n            self.dilates = [1, 1, 2, 3]\n        elif kernel_size == 5:\n            self.kernel_sizes = [3, 3]\n            self.dilates = [1, 2]\n        else:\n            raise ValueError('Dilated Reparam Block requires kernel_size >= 5')\n\n        if not deploy:\n            self.origin_bn = get_bn(channels, use_sync_bn)\n            for k, r in zip(self.kernel_sizes, self.dilates):\n                self.__setattr__('dil_conv_k{}_{}'.format(k, r),\n                                 nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,\n                                           padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,\n                                           bias=False))\n                self.__setattr__('dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn))\n\n    def forward(self, x):\n        if not hasattr(self, 'origin_bn'): # deploy mode\n            return self.lk_origin(x)\n        out = self.origin_bn(self.lk_origin(x))\n        for k, r in zip(self.kernel_sizes, self.dilates):\n            conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))\n            bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))\n            out = out + bn(conv(x))\n        return out\n\n    def merge_dilated_branches(self):\n        if hasattr(self, 'origin_bn'):\n            origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn)\n            for k, r in zip(self.kernel_sizes, self.dilates):\n                conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))\n                bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))\n                branch_k, branch_b = fuse_bn(conv, bn)\n                origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r)\n                origin_b += branch_b\n            merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1,\n                                    padding=origin_k.size(2)//2, dilation=1, groups=origin_k.size(0), bias=True,\n                                    attempt_use_lk_impl=self.attempt_use_lk_impl)\n            merged_conv.weight.data = origin_k\n            merged_conv.bias.data = origin_b\n            self.lk_origin = merged_conv\n           
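 # the merged kernel now reproduces the sum of all branches, so the\n            # training-time BN and dilated branches below are safe to delete\n           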
 self.__delattr__('origin_bn')\n            for k, r in zip(self.kernel_sizes, self.dilates):\n                self.__delattr__('dil_conv_k{}_{}'.format(k, r))\n                self.__delattr__('dil_bn_k{}_{}'.format(k, r))\n       \n\nclass CTXDownsample(nn.Module):\n    def __init__(self, dim, h_dim):\n        super().__init__()\n        \n        self.x_proj = nn.Sequential(\n            nn.Conv2d(dim, h_dim, kernel_size=3, stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(h_dim)\n        )\n        self.h_proj = nn.Sequential(\n            nn.Conv2d(h_dim//4, h_dim//4, kernel_size=3, stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(h_dim//4)\n        )\n\n    def forward(self, x, ctx):\n        x = self.x_proj(x)\n        ctx = self.h_proj(ctx)\n        return (x, ctx)\n\n\nclass ResDWConv(nn.Conv2d):\n    '''\n    Depthwise convolution with residual connection\n    '''\n    def __init__(self, dim, kernel_size=3):\n        super().__init__(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=dim)\n    \n    def forward(self, x):\n        x = x + super().forward(x)\n        return x\n\n\nclass RepConvBlock(nn.Module):\n\n    def __init__(self, \n                 dim=64,\n                 kernel_size=7,\n                 mlp_ratio=4,\n                 ls_init_value=None,\n                 res_scale=False,\n                 drop_path=0,\n                 norm_layer=LayerNorm2d,\n                 use_gemm=False,\n                 deploy=False,\n                 use_checkpoint=False):\n        super().__init__()\n        \n        self.res_scale = res_scale\n        self.use_checkpoint = use_checkpoint\n        \n        mlp_dim = int(dim*mlp_ratio)\n        \n        self.dwconv = ResDWConv(dim, kernel_size=3)\n    \n        self.proj = nn.Sequential(\n            norm_layer(dim),\n            DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),\n            nn.BatchNorm2d(dim),\n            SEModule(dim),\n            nn.Conv2d(dim, mlp_dim, kernel_size=1),\n            nn.GELU(),\n            ResDWConv(mlp_dim, kernel_size=3),\n            GRN(mlp_dim),\n            nn.Conv2d(mlp_dim, dim, kernel_size=1),\n            DropPath(drop_path) if drop_path > 0 else nn.Identity(),\n        )\n\n        self.ls = LayerScale(dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()\n        \n    def forward_features(self, x):\n        \n        x = self.dwconv(x)\n        \n        if self.res_scale:\n            x = self.ls(x) + self.proj(x)\n        else:\n            drop_path = self.proj[-1]\n            x = x + drop_path(self.ls(self.proj[:-1](x)))\n\n        return x\n    \n    def forward(self, x):\n        \n        if self.use_checkpoint and x.requires_grad:\n            x = checkpoint(self.forward_features, x, use_reentrant=False)\n        else:\n            x = self.forward_features(x)\n        \n        return x\n\n\nclass DynamicConvBlock(nn.Module):\n    def __init__(self,\n                 dim=64,\n                 ctx_dim=32,\n                 kernel_size=7,\n                 smk_size=5,\n                 num_heads=2,\n                 mlp_ratio=4,\n                 ls_init_value=None,\n                 res_scale=False,\n                 drop_path=0,\n                 norm_layer=LayerNorm2d,\n                 is_first=False,\n                 is_last=False,\n                 use_gemm=False,\n                 deploy=False,\n                 
use_checkpoint=False,\n                 **kwargs):\n        \n        super().__init__()\n        \n        ctx_dim = ctx_dim // 4\n        out_dim = dim + ctx_dim\n        mlp_dim = int(dim*mlp_ratio)\n        self.kernel_size = kernel_size\n        self.res_scale = res_scale\n        self.use_gemm = use_gemm\n        self.smk_size = smk_size\n        self.num_heads = num_heads * 2\n        head_dim = dim // self.num_heads\n        self.scale = head_dim ** -0.5\n        self.is_first = is_first\n        self.is_last = is_last\n        self.use_checkpoint = use_checkpoint\n\n        if not is_first:\n            self.x_scale = LayerScale(ctx_dim, init_value=1)\n            self.h_scale = LayerScale(ctx_dim, init_value=1)\n        \n        self.dwconv1 = ResDWConv(out_dim, kernel_size=3)\n        self.norm1 = norm_layer(out_dim)\n        \n        self.fusion = nn.Sequential(\n            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, groups=out_dim),\n            nn.BatchNorm2d(out_dim),\n            nn.GELU(),\n            nn.Conv2d(out_dim, dim, kernel_size=1),\n            GRN(dim),\n        )\n        \n        self.weight_query = nn.Sequential(\n            nn.Conv2d(dim, dim//2, kernel_size=1, bias=False),\n            nn.BatchNorm2d(dim//2),\n        )\n         \n        self.weight_key = nn.Sequential(\n            nn.AdaptiveAvgPool2d(7),\n            nn.Conv2d(ctx_dim, dim//2, kernel_size=1, bias=False),\n            nn.BatchNorm2d(dim//2),\n        )\n        \n        self.weight_proj = nn.Conv2d(49, kernel_size**2 + smk_size**2, kernel_size=1)\n        \n        self.dyconv_proj = nn.Sequential(\n            nn.Conv2d(dim, dim, kernel_size=1, bias=False),\n            nn.BatchNorm2d(dim),\n        )\n        \n        self.lepe = nn.Sequential(\n            DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),\n            nn.BatchNorm2d(dim),\n        )\n        \n        self.se_layer = SEModule(dim)\n        \n        self.gate = nn.Sequential(\n            nn.Conv2d(dim, dim, kernel_size=1, bias=False),\n            nn.BatchNorm2d(dim),\n            nn.SiLU(),\n        )\n\n        self.proj = nn.Sequential(\n            nn.BatchNorm2d(dim),\n            nn.Conv2d(dim, out_dim, kernel_size=1),\n        )\n        \n        self.dwconv2 = ResDWConv(out_dim, kernel_size=3)\n        self.norm2 = norm_layer(out_dim)\n        \n        self.mlp = nn.Sequential(\n            nn.Conv2d(out_dim, mlp_dim, kernel_size=1),\n            nn.GELU(),\n            ResDWConv(mlp_dim, kernel_size=3),\n            GRN(mlp_dim),\n            nn.Conv2d(mlp_dim, out_dim, kernel_size=1),\n        )\n        \n        self.ls1 = LayerScale(out_dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()\n        self.ls2 = LayerScale(out_dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()\n        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()\n        \n        self.get_rpb()\n\n\n    def get_rpb(self):\n        self.rpb_size1 = 2 * self.smk_size - 1\n        self.rpb1 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size1, self.rpb_size1))\n        self.rpb_size2 = 2 * self.kernel_size - 1\n        self.rpb2 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size2, self.rpb_size2))\n        nn.init.zeros_(self.rpb1)\n        nn.init.zeros_(self.rpb2)\n    \n        \n    @torch.no_grad()\n    def generate_idx(self, kernel_size):\n        rpb_size = 2 * 
kernel_size - 1\n        idx_h = torch.arange(0, kernel_size)\n        idx_w = torch.arange(0, kernel_size)\n        idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).view(-1)\n        return (idx_h, idx_w, idx_k)\n    \n\n    def apply_rpb(self, attn, rpb, height, width, kernel_size, idx_h, idx_w, idx_k):\n        \"\"\"\n        RPB implementation directly borrowed from https://tinyurl.com/mrbub4t3\n        \"\"\"\n        num_repeat_h = torch.ones(kernel_size, dtype=torch.long)\n        num_repeat_w = torch.ones(kernel_size, dtype=torch.long)\n        num_repeat_h[kernel_size//2] = height - (kernel_size-1)\n        num_repeat_w[kernel_size//2] = width - (kernel_size-1)\n        bias_hw = (idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2*kernel_size-1)) + idx_w.repeat_interleave(num_repeat_w)\n        bias_idx = bias_hw.unsqueeze(-1) + idx_k\n        bias_idx = bias_idx.reshape(-1, int(kernel_size**2))\n        bias_idx = torch.flip(bias_idx, [0])\n        rpb = torch.flatten(rpb, 1, 2)[:, bias_idx]\n        rpb = rpb.reshape(1, int(self.num_heads), int(height), int(width), int(kernel_size**2))\n        return attn + rpb\n    \n\n    def _forward_inner(self, x, h_x, h_r):\n        input_resolution = x.shape[2:]\n        B, C, H, W = x.shape\n        B, C_h, H_h, W_h = h_x.shape\n        \n        if not self.is_first:\n            h_x = self.x_scale(h_x) + self.h_scale(h_r)\n\n        x_f = torch.cat([x, h_x], dim=1)\n        x_f = self.dwconv1(x_f)\n        identity = x_f\n        x_f = self.norm1(x_f)\n        x = self.fusion(x_f)\n        gate = self.gate(x)\n        lepe = self.lepe(x)\n\n        is_pad = False\n        if min(H, W) < self.kernel_size:\n            is_pad = True\n            if H < W:\n                size = (self.kernel_size, int(self.kernel_size / H * W))\n            else:\n                size = (int(self.kernel_size / W * H), self.kernel_size)\n            x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)\n            x_f = F.interpolate(x_f, size=size, mode='bilinear', align_corners=False)\n            H, W = size\n\n        query, key = torch.split(x_f, split_size_or_sections=[C, C_h], dim=1)\n        query = self.weight_query(query) * self.scale\n        key = self.weight_key(key)\n        query = rearrange(query, 'b (g c) h w -> b g c (h w)', g=self.num_heads)\n        key = rearrange(key, 'b (g c) h w -> b g c (h w)', g=self.num_heads)\n        weight = einsum(query, key, 'b g c n, b g c l -> b g n l')\n        weight = rearrange(weight, 'b g n l -> b l g n').contiguous()\n        weight = self.weight_proj(weight)\n        weight = rearrange(weight, 'b l g (h w) -> b g h w l', h=H, w=W)\n\n        # split the predicted weights into small-kernel and large-kernel attention maps\n        attn1, attn2 = torch.split(weight, split_size_or_sections=[self.smk_size**2, self.kernel_size**2], dim=-1)\n        rpb1_idx = self.generate_idx(self.smk_size)\n        rpb2_idx = self.generate_idx(self.kernel_size)\n        attn1 = self.apply_rpb(attn1, self.rpb1, H, W, self.smk_size, *rpb1_idx)\n        attn2 = self.apply_rpb(attn2, self.rpb2, H, W, self.kernel_size, *rpb2_idx)\n        attn1 = torch.softmax(attn1, dim=-1)\n        attn2 = torch.softmax(attn2, dim=-1)\n        # split channels into two value groups, one per kernel branch\n        value = rearrange(x, 'b (m g c) h w -> m b g h w c', m=2, g=self.num_heads)\n\n        x1 = na2d_av(attn1, value[0], kernel_size=self.smk_size)\n        x2 = na2d_av(attn2, value[1], kernel_size=self.kernel_size)\n\n        x = torch.cat([x1, x2], dim=1)\n        x = rearrange(x, 'b g h w c -> b (g c) h w', h=H, w=W)\n\n        if is_pad:\n            x = 
F.adaptive_avg_pool2d(x, input_resolution)\n\n        x = self.dyconv_proj(x)\n\n        x = x + lepe\n        x = self.se_layer(x)\n\n        x = gate * x\n        x = self.proj(x)\n\n        if self.res_scale:\n            x = self.ls1(identity) + self.drop_path(x)\n        else:\n            x = identity + self.drop_path(self.ls1(x))\n        \n        x = self.dwconv2(x)\n         \n        if self.res_scale:\n            x = self.ls2(x) + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))\n\n        if self.is_last:\n            return (x, None)\n        else:\n            l_x, h_x = torch.split(x, split_size_or_sections=[C, C_h], dim=1)\n            return (l_x, h_x)\n    \n    def forward(self, x, h_x, h_r):\n        if self.use_checkpoint and x.requires_grad:\n            x = checkpoint(self._forward_inner, x, h_x, h_r, use_reentrant=False)\n        else:\n            x = self._forward_inner(x, h_x, h_r)\n        return x\n\n\nclass OverLoCK(nn.Module):\n    '''\n    An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels\n    https://arxiv.org/abs/2502.20087\n    '''\n    def __init__(self, \n                 depth=[2, 2, 2, 2],\n                 sub_depth=[4, 2],\n                 in_chans=3, \n                 embed_dim=[96, 192, 384, 768],\n                 kernel_size=[7, 7, 7, 7],\n                 mlp_ratio=[4, 4, 4, 4],\n                 sub_mlp_ratio=[4, 4],\n                 sub_num_heads=[4, 8],\n                 ls_init_value=[None, None, 1, 1],\n                 res_scale=True,\n                 smk_size=5,\n                 deploy=False,\n                 use_gemm=True,\n                 use_ds=True,\n                 drop_rate=0,\n                 drop_path_rate=0,\n                 norm_layer=LayerNorm2d,\n                 projection=1024,\n                 num_classes=1000,\n                 use_checkpoint=[0, 0, 0, 0],\n            ):\n \n        super().__init__()\n        \n        fusion_dim = embed_dim[-1] + embed_dim[-1]//4\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim\n\n        self.patch_embed1 = stem(in_chans, embed_dim[0])\n        self.patch_embed2 = downsample(embed_dim[0], embed_dim[1])\n        self.patch_embed3 = downsample(embed_dim[1], embed_dim[2])\n        self.patch_embed4 = downsample(embed_dim[2], embed_dim[3])\n        self.high_level_proj = nn.Conv2d(embed_dim[-1], embed_dim[-1]//4, kernel_size=1)\n        self.patch_embedx = CTXDownsample(embed_dim[2], embed_dim[3])\n        \n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth) + sum(sub_depth))]\n\n        self.blocks1 = nn.ModuleList()\n        self.blocks2 = nn.ModuleList()\n        self.blocks3 = nn.ModuleList()\n        self.blocks4 = nn.ModuleList()\n        self.sub_blocks3 = nn.ModuleList()\n        self.sub_blocks4 = nn.ModuleList()\n        \n        for i in range(depth[0]):\n            self.blocks1.append(\n                RepConvBlock(\n                    dim=embed_dim[0],\n                    kernel_size=kernel_size[0],\n                    mlp_ratio=mlp_ratio[0],\n                    ls_init_value=ls_init_value[0],\n                    res_scale=res_scale,\n                    drop_path=dpr[i],\n                    norm_layer=norm_layer,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[0]),\n                )\n            
)\n        \n        for i in range(depth[1]):\n            self.blocks2.append(\n                RepConvBlock(\n                    dim=embed_dim[1],\n                    kernel_size=kernel_size[1],\n                    mlp_ratio=mlp_ratio[1],\n                    ls_init_value=ls_init_value[1],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+depth[0]],\n                    norm_layer=norm_layer,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[1]),\n                )\n            )\n            \n        for i in range(depth[2]):\n            self.blocks3.append(\n                RepConvBlock(\n                    dim=embed_dim[2],\n                    kernel_size=kernel_size[2],\n                    mlp_ratio=mlp_ratio[2],\n                    ls_init_value=ls_init_value[2],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+sum(depth[:2])],\n                    norm_layer=norm_layer,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[2]),\n                )\n            )\n\n        for i in range(depth[3]):\n            self.blocks4.append(\n                RepConvBlock(\n                    dim=embed_dim[3],\n                    kernel_size=kernel_size[3],\n                    mlp_ratio=mlp_ratio[3],\n                    ls_init_value=ls_init_value[3],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+sum(depth[:3])],\n                    norm_layer=norm_layer,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[3]),\n                )\n            )\n            \n        for i in range(sub_depth[0]):\n            self.sub_blocks3.append(\n                DynamicConvBlock(\n                    dim=embed_dim[2],\n                    ctx_dim=embed_dim[-1],\n                    kernel_size=kernel_size[2],\n                    num_heads=sub_num_heads[0],\n                    pool_size=7,\n                    mlp_ratio=sub_mlp_ratio[0],\n                    ls_init_value=ls_init_value[2],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+sum(depth)],\n                    norm_layer=norm_layer,\n                    smk_size=smk_size,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    is_first=(i==0),\n                    use_checkpoint=(i<use_checkpoint[2]),\n                )\n            )\n        \n        for i in range(sub_depth[1]):\n            self.sub_blocks4.append(\n                DynamicConvBlock(\n                    dim=embed_dim[3],\n                    ctx_dim=embed_dim[-1],\n                    kernel_size=kernel_size[-1],\n                    num_heads=sub_num_heads[1],\n                    pool_size=7,\n                    mlp_ratio=sub_mlp_ratio[1],\n                    ls_init_value=ls_init_value[3],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+sum(depth)+sub_depth[0]],\n                    norm_layer=norm_layer,\n                    smk_size=smk_size,\n                    is_first=False,\n                    is_last=(i==sub_depth[1]-1),\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[3]),\n                )\n            )\n\n        # Aux Cls Head\n        
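# (deep supervision: during training, forward() additionally classifies the\n        # overview context features through this auxiliary head)\n        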
if use_ds:\n            self.aux_head = nn.Sequential(\n                nn.BatchNorm2d(embed_dim[-1]),\n                nn.AdaptiveAvgPool2d(1),\n                nn.Conv2d(embed_dim[-1], num_classes, kernel_size=1) if num_classes > 0 else nn.Identity()\n            )\n        \n        # Main Cls Head\n        self.head = nn.Sequential(\n            nn.Conv2d(fusion_dim, projection, kernel_size=1, bias=False),\n            nn.BatchNorm2d(projection),\n            nn.SiLU(),\n            nn.AdaptiveAvgPool2d(1),\n            nn.Conv2d(projection, num_classes, kernel_size=1) if num_classes > 0 else nn.Identity()\n        )\n        \n        self.apply(self._init_weights)\n        \n        if torch.distributed.is_initialized():\n            self = nn.SyncBatchNorm.convert_sync_batchnorm(self)\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d)):\n            nn.init.trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.BatchNorm1d)):\n            nn.init.constant_(m.weight, 1.0)\n            nn.init.constant_(m.bias, 0)\n    \n    def reparam(self):\n        for m in self.modules():\n            if isinstance(m, DilatedReparamBlock):\n                m.merge_dilated_branches()\n            \n    def forward_pre_features(self, x):\n        \n        x = self.patch_embed1(x)\n        for blk in self.blocks1:\n            x = blk(x)\n            \n        x = self.patch_embed2(x)\n        for blk in self.blocks2:\n            x = blk(x)\n\n        return x\n    \n    \n    def forward_base_features(self, x):\n        \n        x = self.patch_embed3(x)\n        for blk in self.blocks3:\n            x = blk(x)\n            \n        ctx = self.patch_embed4(x)\n        for blk in self.blocks4:\n            ctx = blk(ctx)\n\n        return (x, ctx)\n    \n\n    def forward_sub_features(self, x, ctx):\n\n        ctx_cls = ctx\n        ctx_ori = self.high_level_proj(ctx)\n        ctx_up = F.interpolate(ctx_ori, size=x.shape[2:], mode='bilinear', align_corners=False)\n        \n        for idx, blk in enumerate(self.sub_blocks3):\n            if idx == 0:\n                ctx = ctx_up\n            x, ctx = blk(x, ctx, ctx_up)\n\n        x, ctx = self.patch_embedx(x, ctx)\n        for idx, blk in enumerate(self.sub_blocks4):\n            x, ctx = blk(x, ctx, ctx_ori)\n        \n        return (x, ctx_cls)\n\n    def forward_features(self, x):\n        \n        x = self.forward_pre_features(x)\n        x, ctx = self.forward_base_features(x)\n        x, ctx_cls = self.forward_sub_features(x, ctx)\n\n        return (x, ctx_cls)\n\n    def forward(self, x):\n        \n        x, ctx = self.forward_features(x)\n        x = self.head(x).flatten(1)\n\n        if hasattr(self, 'aux_head') and self.training:\n            ctx = self.aux_head(ctx).flatten(1)\n            return dict(main=x, aux=ctx)\n        \n        return x\n\n\ndef _cfg(url=None, **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000,\n        'input_size': (3, 224, 224),\n        'crop_pct': 0.9,\n        'interpolation': 'bicubic',  # 'bilinear' or 'bicubic'\n        'mean': timm.data.IMAGENET_DEFAULT_MEAN,\n        'std': timm.data.IMAGENET_DEFAULT_STD,\n        'classifier': 'classifier',\n        **kwargs,\n    }\n\n\n@register_model\ndef overlock_xt(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model = OverLoCK(\n        depth=[2, 2, 3, 
2],\n        sub_depth=[6, 2],\n        embed_dim=[56, 112, 256, 336],\n        kernel_size=[17, 15, 13, 7],\n        mlp_ratio=[4, 4, 4, 4],\n        sub_num_heads=[4, 6],\n        sub_mlp_ratio=[3, 3],\n        **kwargs\n    )\n\n    model.default_cfg = _cfg(crop_pct=0.925)\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_xt_in1k_224.pth'\n        load_checkpoint(model, pretrained)\n\n    return model\n\n\n@register_model\ndef overlock_t(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model = OverLoCK(\n        depth=[4, 4, 6, 2],\n        sub_depth=[12, 2],\n        embed_dim=[64, 128, 256, 512],\n        kernel_size=[17, 15, 13, 7],\n        mlp_ratio=[4, 4, 4, 4],\n        sub_num_heads=[4, 8],\n        sub_mlp_ratio=[3, 3],\n        **kwargs\n    )\n    \n    model.default_cfg = _cfg(crop_pct=0.95)\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224.pth'\n        load_checkpoint(model, pretrained)\n\n    return model\n\n\n@register_model\ndef overlock_s(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model = OverLoCK(\n        depth=[6, 6, 8, 3],\n        sub_depth=[16, 3],\n        embed_dim=[64, 128, 320, 512],\n        kernel_size=[17, 15, 13, 7],\n        mlp_ratio=[4, 4, 4, 4],\n        sub_num_heads=[8, 16],\n        sub_mlp_ratio=[3, 3],\n        **kwargs\n    )\n\n    model.default_cfg = _cfg(crop_pct=0.95)\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224.pth'\n        load_checkpoint(model, pretrained)\n\n    return model\n\n\n@register_model\ndef overlock_b(pretrained=None, pretrained_cfg=None, **kwargs):\n    \n    model = OverLoCK(\n        depth=[8, 8, 10, 4],\n        sub_depth=[20, 4],\n        embed_dim=[80, 160, 384, 576],\n        kernel_size=[17, 15, 13, 7],\n        mlp_ratio=[4, 4, 4, 4],\n        sub_num_heads=[6, 9],\n        sub_mlp_ratio=[3, 3],\n        **kwargs\n    )\n    \n    model.default_cfg = _cfg(crop_pct=0.975)\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224.pth'\n        load_checkpoint(model, pretrained)\n\n    return model\n\n\n'''\nReparameterized versions of OverLoCK models are given, \nwhich offer improved efficiency in terms of inference speed and memory consumption.\n\nNote: these variants may come at the cost of lower accuracy *during fine-tuning*, \nwhen compared to their original counterparts.\n\nMore details about model reparameterization can be found at:\nhttps://arxiv.org/abs/2311.15599\n'''\n@register_model\ndef overlock_xt_reparam(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model = overlock_xt(deploy=True)\n    model.default_cfg = _cfg(crop_pct=0.925)\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_xt_in1k_224_reparam.pth'\n        load_checkpoint(model, pretrained)\n\n    return model\n\n\n@register_model\ndef overlock_t_reparam(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model = overlock_t(deploy=True)\n    model.default_cfg = _cfg(crop_pct=0.95)\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224_reparam.pth'\n        load_checkpoint(model, pretrained)\n\n    return model\n\n\n@register_model\ndef overlock_s_reparam(pretrained=False, pretrained_cfg=None, **kwargs):\n  
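  # same recipe as overlock_s, but built directly in deploy (merged-kernel) form\n  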
  \n    model = overlock_s(deploy=True)\n    model.default_cfg = _cfg(crop_pct=0.95)\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224_reparam.pth'\n        load_checkpoint(model, pretrained)\n    \n    return model\n\n\n@register_model\ndef overlock_b_reparam(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model = overlock_b(deploy=True)\n    model.default_cfg = _cfg(crop_pct=0.975)\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224_reparam.pth'\n        load_checkpoint(model, pretrained)\n    \n    return model\n\n\nif __name__ == '__main__':\n\n    device = torch.device('cuda')\n    model = overlock_xt(pretrained=True).to(device) # load pretrained weights\n    model.eval()\n\n    x = torch.randn(1, 3, 224, 224).to(device)\n    y = model(x)\n    print(y.shape)\n\n    # Reparametrized model, more details can be found at: \n    # https://github.com/AILab-CVC/UniRepLKNet/tree/main\n    model = overlock_xt_reparam(pretrained=True).to(device) # load pretrained weights\n    model.eval()\n    z = model(x)\n    print(z.shape)\n\n    # Showing difference between original and reparametrized model\n    print((y-z).mean())\n"
  },
  {
    "path": "scripts/train_b_model.sh",
    "content": "#!/usr/bin/env bash\npython3 -m torch.distributed.launch \\\n--master_port=$((RANDOM+8888)) \\\n--nproc_per_node=8 \\\ntrain.py \\\n--data-dir /data/dataset/imagenet/ \\\n--batch-size 256 \\\n--model overlock_b \\\n--lr 1e-3 \\\n--auto-lr \\\n--drop-path 0.5 \\\n--epochs 300 \\\n--warmup-epochs 5 \\\n--workers 10 \\\n--model-ema \\\n--model-ema-decay 0.9999 \\\n--output output/overlock_b/ \\\n--native-amp \\\n--clip-grad 5"
  },
  {
    "path": "scripts/train_s_model.sh",
    "content": "#!/usr/bin/env bash\npython3 -m torch.distributed.launch \\\n--master_port=$((RANDOM+8888)) \\\n--nproc_per_node=8 \\\ntrain.py \\\n--data-dir /data/dataset/imagenet/ \\\n--batch-size 256 \\\n--model overlock_s \\\n--lr 1e-3 \\\n--auto-lr \\\n--drop-path 0.4 \\\n--epochs 300 \\\n--warmup-epochs 5 \\\n--workers 10 \\\n--model-ema \\\n--model-ema-decay 0.9999 \\\n--output output/overlock_s/ \\\n--native-amp \\\n--clip-grad 5"
  },
  {
    "path": "scripts/train_t_model.sh",
    "content": "#!/usr/bin/env bash\npython3 -m torch.distributed.launch \\\n--master_port=$((RANDOM+8888)) \\\n--nproc_per_node=8 \\\ntrain.py \\\n--data-dir /data/dataset/imagenet/ \\\n--batch-size 256 \\\n--model overlock_t \\\n--lr 1e-3 \\\n--auto-lr \\\n--drop-path 0.15 \\\n--epochs 300 \\\n--warmup-epochs 5 \\\n--workers 10 \\\n--model-ema \\\n--model-ema-decay 0.9999 \\\n--output output/overlock_t/ \\\n--native-amp \\\n--clip-grad 5"
  },
  {
    "path": "scripts/train_xt_model.sh",
    "content": "#!/usr/bin/env bash\npython3 -m torch.distributed.launch \\\n--master_port=$((RANDOM+8888)) \\\n--nproc_per_node=8 \\\ntrain.py \\\n--data-dir /data/dataset/imagenet/ \\\n--batch-size 256 \\\n--model overlock_xt \\\n--lr 1e-3 \\\n--auto-lr \\\n--drop-path 0.1 \\\n--epochs 300 \\\n--warmup-epochs 5 \\\n--workers 10 \\\n--model-ema \\\n--model-ema-decay 0.9999 \\\n--output output/overlock_xt/ \\\n--native-amp \\\n--clip-grad 5"
  },
  {
    "path": "segmentation/configs/_base_/datasets/ade20k.py",
    "content": "# dataset settings\ndataset_type = 'ADE20KDataset'\ndata_root = '/grp01/cs_yzyu/dataset/ADEChallengeData2016/'\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ncrop_size = (512, 512)\ntrain_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='LoadAnnotations', reduce_zero_label=True),\n    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),\n    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),\n    dict(type='RandomFlip', prob=0.5),\n    dict(type='PhotoMetricDistortion'),\n    dict(type='Normalize', **img_norm_cfg),\n    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),\n    dict(type='DefaultFormatBundle'),\n    dict(type='Collect', keys=['img', 'gt_semantic_seg']),\n]\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(2048, 512),\n        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],\n        flip=False,\n        transforms=[\n            # dict(type='AlignResize', keep_ratio=True, size_divisor=32),\n            dict(type='Resize', keep_ratio=True),\n            dict(type='ResizeToMultiple', size_divisor=32, interpolation='bicubic'),\n            dict(type='RandomFlip'),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\ndata = dict(\n    samples_per_gpu=4,\n    workers_per_gpu=4,\n    train=dict(\n        type='RepeatDataset',\n        times=50,\n        dataset=dict(\n            type=dataset_type,\n            data_root=data_root,\n            img_dir='images/training',\n            ann_dir='annotations/training',\n            pipeline=train_pipeline),\n    ),\n    val=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline),\n    test=dict(\n        type=dataset_type,\n        data_root=data_root,\n        img_dir='images/validation',\n        ann_dir='annotations/validation',\n        pipeline=test_pipeline))"
  },
  {
    "path": "segmentation/configs/_base_/default_runtime.py",
    "content": "# yapf:disable\nlog_config = dict(\n    interval=50,\n    hooks=[\n        dict(type='TextLoggerHook', by_epoch=False),\n        # dict(type='TensorboardLoggerHook')\n    ])\n# yapf:enable\ndist_params = dict(backend='nccl')\nlog_level = 'INFO'\nload_from = None\nresume_from = None\nworkflow = [('train', 1)]\ncudnn_benchmark = True\n"
  },
  {
    "path": "segmentation/configs/_base_/models/fpn_r50.py",
    "content": "# copied from mmsegmentaion official config\n# https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/models/fpn_r50.py\n\n\n# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 1, 1),\n        strides=(1, 2, 2, 2),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=256,\n        num_outs=4),\n    decode_head=dict(\n        type='FPNHead',\n        in_channels=[256, 256, 256, 256],\n        in_index=[0, 1, 2, 3],\n        feature_strides=[4, 8, 16, 32],\n        channels=128,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))\n    "
  },
  {
    "path": "segmentation/configs/_base_/models/upernet_r50.py",
    "content": "# copied from mmsegmentation official config\n# https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/models/upernet_r50.py\n\n# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained='open-mmlab://resnet50_v1c',\n    backbone=dict(\n        type='ResNetV1c',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        dilations=(1, 1, 1, 1),\n        strides=(1, 2, 2, 2),\n        norm_cfg=norm_cfg,\n        norm_eval=False,\n        style='pytorch',\n        contract_dilation=True),\n    decode_head=dict(\n        type='UPerHead',\n        in_channels=[256, 512, 1024, 2048],\n        in_index=[0, 1, 2, 3],\n        pool_scales=(1, 2, 3, 6),\n        channels=512,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=1024,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)\n    ),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))"
  },
  {
    "path": "segmentation/configs/_base_/models/upernet_transnext.py",
    "content": "# model settings\nnorm_cfg = dict(type='SyncBN', requires_grad=True)\nmodel = dict(\n    type='EncoderDecoder',\n    pretrained=None,\n    decode_head=dict(\n        type='UPerHead',\n        in_channels=[96, 192, 384, 768],\n        in_index=[0, 1, 2, 3],\n        pool_scales=(1, 2, 3, 6),\n        channels=512,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),\n    auxiliary_head=dict(\n        type='FCNHead',\n        in_channels=384,\n        in_index=2,\n        channels=256,\n        num_convs=1,\n        concat_input=False,\n        dropout_ratio=0.1,\n        num_classes=19,\n        norm_cfg=norm_cfg,\n        align_corners=False,\n        loss_decode=dict(\n            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),\n    # model training and testing settings\n    train_cfg=dict(),\n    test_cfg=dict(mode='whole'))"
  },
  {
    "path": "segmentation/configs/_base_/schedules/schedule_160k.py",
    "content": "# optimizer\noptimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)\noptimizer_config = dict()\n# learning policy\nlr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)\n# runtime settings\nrunner = dict(type='IterBasedRunner', max_iters=160000)\ncheckpoint_config = dict(by_epoch=False, interval=16000)\nevaluation = dict(interval=16000, metric='mIoU')\n"
  },
  {
    "path": "segmentation/configs/_base_/schedules/schedule_20k.py",
    "content": "# optimizer\noptimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)\noptimizer_config = dict()\n# learning policy\nlr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)\n# runtime settings\nrunner = dict(type='IterBasedRunner', max_iters=20000)\ncheckpoint_config = dict(by_epoch=False, interval=2000)\nevaluation = dict(interval=2000, metric='mIoU')\n"
  },
  {
    "path": "segmentation/configs/_base_/schedules/schedule_40k.py",
    "content": "# optimizer\noptimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)\noptimizer_config = dict()\n# learning policy\nlr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)\n# runtime settings\nrunner = dict(type='IterBasedRunner', max_iters=40000)\ncheckpoint_config = dict(by_epoch=False, interval=4000)\nevaluation = dict(interval=4000, metric='mIoU')\n"
  },
  {
    "path": "segmentation/configs/_base_/schedules/schedule_80k.py",
    "content": "# optimizer\noptimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)\noptimizer_config = dict()\n# learning policy\nlr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)\n# runtime settings\nrunner = dict(type='IterBasedRunner', max_iters=80000)\ncheckpoint_config = dict(by_epoch=False, interval=8000)\nevaluation = dict(interval=8000, metric='mIoU')\n"
  },
  {
    "path": "segmentation/configs/overlock/upernet_overlock_base_ade20k_8xb2.py",
    "content": "_base_ = [\n    '../_base_/models/upernet_r50.py',\n    '../_base_/datasets/ade20k.py',\n    '../_base_/default_runtime.py',\n    '../_base_/schedules/schedule_160k.py'\n]\n\n\n\nmodel = dict(\n    pretrained=None,\n    backbone=dict(\n        _delete_=True,\n        type='overlock_b',\n        pretrained=True,\n        drop_path_rate=0.5,\n    ),\n    decode_head=dict(\n        in_index=[0, 1, 2, 3],\n        in_channels=[80, 160, 528, 720],\n        num_classes=150,\n    ),\n    auxiliary_head=dict(\n        in_index=2,\n        in_channels=528,\n        num_classes=150\n    ),\n)\n\n# optimizer\noptimizer = dict(_delete_=True, type='AdamW', lr=6e-5, betas=(0.9, 0.999), weight_decay=0.01,\n                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),\n                                                 'relative_position_bias_table': dict(decay_mult=0.),\n                                                 'norm': dict(decay_mult=0.)}))\n\nlr_config = dict(_delete_=True,\n                 policy='poly',\n                 warmup='linear',\n                 warmup_iters=1500,\n                 warmup_ratio=1e-6,\n                 power=1.0,\n                 min_lr=0.0,\n                 by_epoch=False)\n\n\ndata = dict(samples_per_gpu=2) # as gpus = 8\ncheckpoint_config = dict(interval=8000, max_keep_ckpts=1)\nevaluation = dict(interval=8000, save_best='mIoU')\n\n# place holder for new verison mmseg compatiability\nresume_from = None\ndevice = 'cuda'\n\n# # AMP (faster but may meet nan loss) ->\n# optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')\n# fp16 = dict()"
  },
  {
    "path": "segmentation/configs/overlock/upernet_overlock_small_ade20k_8xb2.py",
    "content": "_base_ = [\n    '../_base_/models/upernet_r50.py',\n    '../_base_/datasets/ade20k.py',\n    '../_base_/default_runtime.py',\n    '../_base_/schedules/schedule_160k.py'\n]\n\n\nmodel = dict(\n    pretrained=None,\n    backbone=dict(\n        _delete_=True,\n        type='overlock_s',\n        pretrained=True,\n        drop_path_rate=0.3,\n    ),\n    decode_head=dict(\n        in_index=[0, 1, 2, 3],\n        in_channels=[64, 128, 448, 640],\n        num_classes=150,\n    ),\n    auxiliary_head=dict(\n        in_index=2,\n        in_channels=448,\n        num_classes=150\n    ),\n)\n\n# optimizer\noptimizer = dict(_delete_=True, type='AdamW', lr=6e-5, betas=(0.9, 0.999), weight_decay=0.01,\n                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),\n                                                 'relative_position_bias_table': dict(decay_mult=0.),\n                                                 'norm': dict(decay_mult=0.)}))\n\nlr_config = dict(_delete_=True,\n                 policy='poly',\n                 warmup='linear',\n                 warmup_iters=1500,\n                 warmup_ratio=1e-6,\n                 power=1.0,\n                 min_lr=0.0,\n                 by_epoch=False)\n\n\ndata = dict(samples_per_gpu=2) # as gpus = 8\ncheckpoint_config = dict(interval=8000, max_keep_ckpts=1)\nevaluation = dict(interval=8000, save_best='mIoU')\n\n# place holder for new verison mmseg compatiability\nresume_from = None\ndevice = 'cuda'\n\n# # AMP (faster but may meet nan loss) ->\n# optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')\n# fp16 = dict()"
  },
  {
    "path": "segmentation/configs/overlock/upernet_overlock_tiny_ade20k_8xb2.py",
    "content": "_base_ = [\n    '../_base_/models/upernet_r50.py',\n    '../_base_/datasets/ade20k.py',\n    '../_base_/default_runtime.py',\n    '../_base_/schedules/schedule_160k.py'\n]\n\n\nmodel = dict(\n    pretrained=None,\n    backbone=dict(\n        _delete_=True,\n        type='overlock_t',\n        pretrained=True,\n        drop_path_rate=0.2\n    ),\n    decode_head=dict(\n        in_index=[0, 1, 2, 3],\n        in_channels=[64, 128, 384, 640],\n        num_classes=150,\n    ),\n    auxiliary_head=dict(\n        in_index=2,\n        in_channels=384,\n        num_classes=150\n    ),\n)\n\n# optimizer\noptimizer = dict(_delete_=True, type='AdamW', lr=6e-5, betas=(0.9, 0.999), weight_decay=0.01,\n                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),\n                                                 'relative_position_bias_table': dict(decay_mult=0.),\n                                                 'norm': dict(decay_mult=0.)}))\n\nlr_config = dict(_delete_=True, \n                 policy='poly',\n                 warmup='linear',\n                 warmup_iters=1500,\n                 warmup_ratio=1e-6,\n                 power=1.0, \n                 min_lr=0.0, \n                 by_epoch=False)\n\n\ndata = dict(samples_per_gpu=2) # as gpus = 8\ncheckpoint_config = dict(interval=8000, max_keep_ckpts=1)\nevaluation = dict(interval=8000, save_best='mIoU')\n\n# place holder for new verison mmseg compatiability\nresume_from = None\ndevice = 'cuda'\n\n# # AMP (faster but may meet nan loss) ->\n# optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')\n# fp16 = dict()"
  },
  {
    "path": "segmentation/mmseg_custom/__init__.py",
    "content": "from .align_resize import AlignResize"
  },
  {
    "path": "segmentation/mmseg_custom/align_resize.py",
    "content": "import mmcv\nimport numpy as np\nfrom mmseg.datasets.builder import PIPELINES\n\n# from IPython import embed\n# from numpy import random\n# from mmcv.utils import deprecated_api_warning, is_tuple_of\n\n@PIPELINES.register_module()\nclass AlignResize(object):\n    \"\"\"Resize images & seg. Align\"\"\"\n\n    def __init__(self,\n                 img_scale=None,\n                 multiscale_mode='range',\n                 ratio_range=None,\n                 keep_ratio=True,\n                 size_divisor=32,\n                 interpolation=None):\n        if img_scale is None:\n            self.img_scale = None\n        else:\n            if isinstance(img_scale, list):\n                self.img_scale = img_scale\n            else:\n                self.img_scale = [img_scale]\n            assert mmcv.is_list_of(self.img_scale, tuple)\n\n        if ratio_range is not None:\n            # mode 1: given img_scale=None and a range of image ratio\n            # mode 2: given a scale and a range of image ratio\n            assert self.img_scale is None or len(self.img_scale) == 1\n        else:\n            # mode 3 and 4: given multiple scales or a range of scales\n            assert multiscale_mode in ['value', 'range']\n\n        self.multiscale_mode = multiscale_mode\n        self.ratio_range = ratio_range\n        self.keep_ratio = keep_ratio\n        self.size_divisor = size_divisor\n        self.interpolation = interpolation\n\n    @staticmethod\n    def random_select(img_scales):\n        \"\"\"Randomly select an img_scale from given candidates.\n\n        Args:\n            img_scales (list[tuple]): Images scales for selection.\n\n        Returns:\n            (tuple, int): Returns a tuple ``(img_scale, scale_dix)``,\n                where ``img_scale`` is the selected image scale and\n                ``scale_idx`` is the selected index in the given candidates.\n        \"\"\"\n\n        assert mmcv.is_list_of(img_scales, tuple)\n        scale_idx = np.random.randint(len(img_scales))\n        img_scale = img_scales[scale_idx]\n        return img_scale, scale_idx\n\n    @staticmethod\n    def random_sample(img_scales):\n        \"\"\"Randomly sample an img_scale when ``multiscale_mode=='range'``.\n\n        Args:\n            img_scales (list[tuple]): Images scale range for sampling.\n                There must be two tuples in img_scales, which specify the lower\n                and uper bound of image scales.\n\n        Returns:\n            (tuple, None): Returns a tuple ``(img_scale, None)``, where\n                ``img_scale`` is sampled scale and None is just a placeholder\n                to be consistent with :func:`random_select`.\n        \"\"\"\n\n        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2\n        img_scale_long = [max(s) for s in img_scales]\n        img_scale_short = [min(s) for s in img_scales]\n        long_edge = np.random.randint(\n            min(img_scale_long),\n            max(img_scale_long) + 1)\n        short_edge = np.random.randint(\n            min(img_scale_short),\n            max(img_scale_short) + 1)\n        img_scale = (long_edge, short_edge)\n        return img_scale, None\n\n    @staticmethod\n    def random_sample_ratio(img_scale, ratio_range):\n        \"\"\"Randomly sample an img_scale when ``ratio_range`` is specified.\n\n        A ratio will be randomly sampled from the range specified by\n        ``ratio_range``. 
It is then multiplied with ``img_scale`` to\n        generate the sampled scale.\n\n        Args:\n            img_scale (tuple): Image scale base to multiply with ratio.\n            ratio_range (tuple[float]): The minimum and maximum ratio to scale\n                the ``img_scale``.\n\n        Returns:\n            (tuple, None): Returns a tuple ``(scale, None)``, where\n                ``scale`` is the sampled ratio multiplied with ``img_scale`` and\n                None is just a placeholder to be consistent with\n                :func:`random_select`.\n        \"\"\"\n\n        assert isinstance(img_scale, tuple) and len(img_scale) == 2\n        min_ratio, max_ratio = ratio_range\n        assert min_ratio <= max_ratio\n        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio\n        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)\n        return scale, None\n\n    def _random_scale(self, results):\n        \"\"\"Randomly sample an img_scale according to ``ratio_range`` and\n        ``multiscale_mode``.\n\n        If ``ratio_range`` is specified, a ratio will be sampled and be\n        multiplied with ``img_scale``.\n        If multiple scales are specified by ``img_scale``, a scale will be\n        sampled according to ``multiscale_mode``.\n        Otherwise, a single scale will be used.\n\n        Args:\n            results (dict): Result dict from :obj:`dataset`.\n\n        Returns:\n            dict: Two new keys ``scale`` and ``scale_idx`` are added into\n                ``results``, which would be used by subsequent pipelines.\n        \"\"\"\n\n        if self.ratio_range is not None:\n            if self.img_scale is None:\n                h, w = results['img'].shape[:2]\n                scale, scale_idx = self.random_sample_ratio((w, h),\n                                                            self.ratio_range)\n            else:\n                scale, scale_idx = self.random_sample_ratio(\n                    self.img_scale[0], self.ratio_range)\n        elif len(self.img_scale) == 1:\n            scale, scale_idx = self.img_scale[0], 0\n        elif self.multiscale_mode == 'range':\n            scale, scale_idx = self.random_sample(self.img_scale)\n        elif self.multiscale_mode == 'value':\n            scale, scale_idx = self.random_select(self.img_scale)\n        else:\n            raise NotImplementedError\n\n        results['scale'] = scale\n        results['scale_idx'] = scale_idx\n\n    def _align(self, img, size_divisor, interpolation=None):\n        align_h = int(np.ceil(img.shape[0] / size_divisor)) * size_divisor\n        align_w = int(np.ceil(img.shape[1] / size_divisor)) * size_divisor\n        if interpolation is None:\n            img = mmcv.imresize(img, (align_w, align_h))\n        else:\n            img = mmcv.imresize(img, (align_w, align_h), interpolation=interpolation)\n        return img\n\n    def _resize_img(self, results):\n        \"\"\"Resize images with ``results['scale']``.\"\"\"\n        if self.keep_ratio:\n            img, scale_factor = mmcv.imrescale(\n                results['img'], results['scale'], return_scale=True)\n            #### align ####\n            img = self._align(img, self.size_divisor, interpolation='bicubic')\n            # the w_scale and h_scale have a minor difference\n            # a real fix should be done in mmcv.imrescale in the future\n            new_h, new_w = img.shape[:2]\n            h, w = results['img'].shape[:2]\n            w_scale = new_w / w\n            h_scale = 
new_h / h\n        else:\n            img, w_scale, h_scale = mmcv.imresize(\n                results['img'], results['scale'], return_scale=True)\n\n            h, w = img.shape[:2]\n            assert int(np.ceil(h / self.size_divisor)) * self.size_divisor == h and \\\n                   int(np.ceil(w / self.size_divisor)) * self.size_divisor == w, \\\n                   \"img size not aligned. h:{} w:{}\".format(h, w)\n        scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],\n                                dtype=np.float32)\n        results['img'] = img\n        results['img_shape'] = img.shape\n        results['pad_shape'] = img.shape  # in case that there is no padding\n        results['scale_factor'] = scale_factor\n        results['keep_ratio'] = self.keep_ratio\n\n    def _resize_seg(self, results):\n        \"\"\"Resize semantic segmentation map with ``results['scale']``.\"\"\"\n        for key in results.get('seg_fields', []):\n            if self.keep_ratio:\n                gt_seg = mmcv.imrescale(\n                    results[key], results['scale'], interpolation='nearest')\n                gt_seg = self._align(gt_seg, self.size_divisor, interpolation='nearest')\n            else:\n                gt_seg = mmcv.imresize(\n                    results[key], results['scale'], interpolation='nearest')\n                h, w = gt_seg.shape[:2]\n                assert int(np.ceil(h / self.size_divisor)) * self.size_divisor == h and \\\n                       int(np.ceil(w / self.size_divisor)) * self.size_divisor == w, \\\n                    \"gt_seg size not aligned. h:{} w:{}\".format(h, w)\n            results[key] = gt_seg\n\n    def __call__(self, results):\n        \"\"\"Call function to resize images, bounding boxes, masks, and semantic\n        segmentation maps.\n\n        Args:\n            results (dict): Result dict from loading pipeline.\n\n        Returns:\n            dict: Resized results; 'img_shape', 'pad_shape', 'scale_factor',\n                'keep_ratio' keys are added into the result dict.\n        \"\"\"\n\n        if 'scale' not in results:\n            self._random_scale(results)\n        self._resize_img(results)\n        self._resize_seg(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += (f'(img_scale={self.img_scale}, '\n                     f'multiscale_mode={self.multiscale_mode}, '\n                     f'ratio_range={self.ratio_range}, '\n                     f'keep_ratio={self.keep_ratio})')\n        return repr_str"
  },
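  {
    "path": "segmentation/mmseg_custom/align_resize_example.py",
    "content": "# Hypothetical usage sketch, not a file from the original repo: it shows how\n# the AlignResize transform defined in align_resize.py might replace the stock\n# Resize in an mmseg 0.x test pipeline, so that test-time inputs stay divisible\n# by ``size_divisor``. All field values below are illustrative assumptions.\nimg_norm_cfg = dict(\n    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(\n        type='MultiScaleFlipAug',\n        img_scale=(2048, 512),\n        flip=False,\n        transforms=[\n            # AlignResize rescales, then snaps H and W up to multiples of 32\n            dict(type='AlignResize', keep_ratio=True, size_divisor=32),\n            dict(type='Normalize', **img_norm_cfg),\n            dict(type='ImageToTensor', keys=['img']),\n            dict(type='Collect', keys=['img']),\n        ])\n]\n"
  },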
  {
    "path": "segmentation/models/__init__.py",
    "content": "from .overlock import *"
  },
  {
    "path": "segmentation/models/overlock.py",
    "content": "'''\nThis is an official implementation of OverLoCK model proposed in the paper: \nhttps://arxiv.org/abs/2502.20087\n'''\nimport torch\nimport timm\nimport torch.distributed\nimport torch.nn.functional as F\nfrom torch import nn\nfrom einops import rearrange, einsum\nfrom natten.functional import na2d_av\nfrom torch.utils.checkpoint import checkpoint\nfrom timm.models.layers import DropPath, to_2tuple\nfrom timm.models.registry import register_model\nfrom mmseg.models.builder import MODELS\nfrom mmseg.utils import get_root_logger\nfrom mmcv.runner import load_checkpoint\n\ndef get_conv2d(in_channels, \n               out_channels, \n               kernel_size, \n               stride, \n               padding, \n               dilation, \n               groups, \n               bias,\n               attempt_use_lk_impl=True):\n    \n    kernel_size = to_2tuple(kernel_size)\n    if padding is None:\n        padding = (kernel_size[0] // 2, kernel_size[1] // 2)\n    else:\n        padding = to_2tuple(padding)\n    need_large_impl = kernel_size[0] == kernel_size[1] and kernel_size[0] > 5 and padding == (kernel_size[0] // 2, kernel_size[1] // 2)\n\n    if attempt_use_lk_impl and need_large_impl:\n        print('---------------- trying to import iGEMM implementation for large-kernel conv')\n        try:\n            from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM\n            print('---------------- found iGEMM implementation ')\n        except:\n            DepthWiseConv2dImplicitGEMM = None\n            print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')\n        if DepthWiseConv2dImplicitGEMM is not None and need_large_impl and in_channels == out_channels \\\n                and out_channels == groups and stride == 1 and dilation == 1:\n            print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')\n            return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)\n    \n    return nn.Conv2d(in_channels, out_channels, \n                     kernel_size=kernel_size, \n                     stride=stride,\n                     padding=padding, \n                     dilation=dilation, \n                     groups=groups, \n                     bias=bias)\n\n\ndef get_bn(dim, use_sync_bn=False):\n    if use_sync_bn:\n        return nn.SyncBatchNorm(dim)\n    else:\n        return nn.BatchNorm2d(dim)\n\n\ndef fuse_bn(conv, bn):\n    conv_bias = 0 if conv.bias is None else conv.bias\n    std = (bn.running_var + bn.eps).sqrt()\n    return conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1), bn.bias + (conv_bias - bn.running_mean) * bn.weight / std\n\ndef convert_dilated_to_nondilated(kernel, dilate_rate):\n    identity_kernel = torch.ones((1, 1, 1, 1)).to(kernel.device)\n    if kernel.size(1) == 1:\n        #   This is a DW kernel\n        dilated = F.conv_transpose2d(kernel, identity_kernel, stride=dilate_rate)\n        return dilated\n    else:\n        #   This is a dense or group-wise (but not DW) kernel\n        slices = []\n        for i in range(kernel.size(1)):\n            dilated = F.conv_transpose2d(kernel[:,i:i+1,:,:], identity_kernel, stride=dilate_rate)\n            slices.append(dilated)\n        return torch.cat(slices, dim=1)\n\ndef merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):\n    large_k = large_kernel.size(2)\n    dilated_k = dilated_kernel.size(2)\n    equivalent_kernel_size 
= dilated_r * (dilated_k - 1) + 1\n    equivalent_kernel = convert_dilated_to_nondilated(dilated_kernel, dilated_r)\n    rows_to_pad = large_k // 2 - equivalent_kernel_size // 2\n    merged_kernel = large_kernel + F.pad(equivalent_kernel, [rows_to_pad] * 4)\n    return merged_kernel\n\n\ndef stem(in_chans=3, embed_dim=96):\n    return nn.Sequential(\n        nn.Conv2d(in_chans, embed_dim//2, kernel_size=3, stride=2, padding=1, bias=False),\n        nn.BatchNorm2d(embed_dim//2),\n        nn.GELU(),\n        nn.Conv2d(embed_dim//2, embed_dim//2, kernel_size=3, padding=1, bias=False),\n        nn.BatchNorm2d(embed_dim//2),\n        nn.GELU(),\n        nn.Conv2d(embed_dim//2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False),\n        nn.BatchNorm2d(embed_dim),\n        nn.GELU(),\n        nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1, bias=False),\n        nn.BatchNorm2d(embed_dim)\n    )\n\n\ndef downsample(in_dim, out_dim):\n    return nn.Sequential(\n        nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False),\n        nn.BatchNorm2d(out_dim),\n    )        \n\n\nclass SEModule(nn.Module):\n    def __init__(self, dim, red=8, inner_act=nn.GELU, out_act=nn.Sigmoid):\n        super().__init__()\n        inner_dim = max(16, dim // red)\n        self.proj = nn.Sequential(\n            nn.AdaptiveAvgPool2d(1),\n            nn.Conv2d(dim, inner_dim, kernel_size=1),\n            inner_act(),\n            nn.Conv2d(inner_dim, dim, kernel_size=1),\n            out_act(),\n        )\n        \n    def forward(self, x):\n        x = x * self.proj(x)\n        return x\n\n\n\nclass LayerScale(nn.Module):\n    def __init__(self, dim, init_value=1e-5):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(dim, 1, 1, 1)*init_value, \n                                   requires_grad=True)\n        self.bias = nn.Parameter(torch.zeros(dim), requires_grad=True)\n\n    def forward(self, x):\n        x = F.conv2d(x, weight=self.weight, bias=self.bias, groups=x.shape[1])\n        return x\n\n        \nclass LayerNorm2d(nn.LayerNorm):\n    def __init__(self, dim):\n        super().__init__(normalized_shape=dim, eps=1e-6)\n    \n    def forward(self, x):\n        x = rearrange(x, 'b c h w -> b h w c')\n        x = super().forward(x)\n        x = rearrange(x, 'b h w c -> b c h w')\n        return x.contiguous()\n\n\nclass GRN(nn.Module):\n    \"\"\" GRN (Global Response Normalization) layer\n    Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808)\n    This implementation is more efficient than the original (https://github.com/facebookresearch/ConvNeXt-V2)\n    We assume the inputs to this layer are (N, C, H, W)\n    \"\"\"\n    def __init__(self, dim, use_bias=True):\n        super().__init__()\n        self.use_bias = use_bias\n        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))\n        if self.use_bias:\n            self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))\n\n    def forward(self, x):\n        Gx = torch.norm(x, p=2, dim=(-1, -2), keepdim=True)\n        Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)\n        if self.use_bias:\n            return (self.gamma * Nx + 1) * x + self.beta\n        else:\n            return (self.gamma * Nx + 1) * x\n    \n\n\nclass DilatedReparamBlock(nn.Module):\n    \"\"\"\n    Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet)\n    We assume the inputs to this block are (N, C, H, W)\n    \"\"\"\n    def __init__(self, channels, kernel_size, 
deploy, use_sync_bn=False, attempt_use_lk_impl=True):\n        super().__init__()\n        self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,\n                                    padding=kernel_size//2, dilation=1, groups=channels, bias=deploy,\n                                    attempt_use_lk_impl=attempt_use_lk_impl)\n        self.attempt_use_lk_impl = attempt_use_lk_impl\n\n        #   Default settings. We did not tune them carefully. Different settings may work better.\n        if kernel_size == 19:\n            self.kernel_sizes = [5, 7, 9, 9, 3, 3, 3]\n            self.dilates = [1, 1, 1, 2, 4, 5, 7]\n        elif kernel_size == 17:\n            self.kernel_sizes = [5, 7, 9, 3, 3, 3]\n            self.dilates = [1, 1, 2, 4, 5, 7]\n        elif kernel_size == 15:\n            self.kernel_sizes = [5, 7, 7, 3, 3, 3]\n            self.dilates = [1, 1, 2, 3, 5, 7]\n        elif kernel_size == 13:\n            self.kernel_sizes = [5, 7, 7, 3, 3, 3]\n            self.dilates = [1, 1, 2, 3, 4, 5]\n        elif kernel_size == 11:\n            self.kernel_sizes = [5, 7, 5, 3, 3, 3]\n            self.dilates = [1, 1, 2, 3, 4, 5]\n        elif kernel_size == 9:\n            self.kernel_sizes = [5, 7, 5, 3, 3]\n            self.dilates = [1, 1, 2, 3, 4]\n        elif kernel_size == 7:\n            self.kernel_sizes = [5, 3, 3, 3]\n            self.dilates = [1, 1, 2, 3]\n        elif kernel_size == 5:\n            self.kernel_sizes = [3, 3]\n            self.dilates = [1, 2]\n        else:\n            raise ValueError('Dilated Reparam Block requires kernel_size >= 5')\n\n        if not deploy:\n            self.origin_bn = get_bn(channels, use_sync_bn)\n            for k, r in zip(self.kernel_sizes, self.dilates):\n                self.__setattr__('dil_conv_k{}_{}'.format(k, r),\n                                 nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,\n                                           padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,\n                                           bias=False))\n                self.__setattr__('dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn))\n\n    def forward(self, x):\n        if not hasattr(self, 'origin_bn'): # deploy mode\n            return self.lk_origin(x)\n        out = self.origin_bn(self.lk_origin(x))\n        for k, r in zip(self.kernel_sizes, self.dilates):\n            conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))\n            bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))\n            out = out + bn(conv(x))\n        return out\n\n    def merge_dilated_branches(self):\n        if hasattr(self, 'origin_bn'):\n            origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn)\n            for k, r in zip(self.kernel_sizes, self.dilates):\n                conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))\n                bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))\n                branch_k, branch_b = fuse_bn(conv, bn)\n                origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r)\n                origin_b += branch_b\n            merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1,\n                                    padding=origin_k.size(2)//2, dilation=1, groups=origin_k.size(0), bias=True,\n                                    attempt_use_lk_impl=self.attempt_use_lk_impl)\n            merged_conv.weight.data = origin_k\n            merged_conv.bias.data = 
origin_b\n            self.lk_origin = merged_conv\n            self.__delattr__('origin_bn')\n            for k, r in zip(self.kernel_sizes, self.dilates):\n                self.__delattr__('dil_conv_k{}_{}'.format(k, r))\n                self.__delattr__('dil_bn_k{}_{}'.format(k, r))\n       \n\nclass CTXDownsample(nn.Module):\n    def __init__(self, dim, h_dim):\n        super().__init__()\n        \n        self.x_proj = nn.Sequential(\n            nn.Conv2d(dim, h_dim, kernel_size=3, stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(h_dim)\n        )\n        self.h_proj = nn.Sequential(\n            nn.Conv2d(h_dim//4, h_dim//4, kernel_size=3, stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(h_dim//4)\n        )\n\n    def forward(self, x, ctx):\n        x = self.x_proj(x)\n        ctx = self.h_proj(ctx)\n        return (x, ctx)\n\n\nclass ResDWConv(nn.Conv2d):\n    '''\n    Depthwise convolution with residual connection\n    '''\n    def __init__(self, dim, kernel_size=3):\n        super().__init__(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=dim)\n    \n    def forward(self, x):\n        x = x + super().forward(x)\n        return x\n\n\nclass RepConvBlock(nn.Module):\n\n    def __init__(self, \n                 dim=64,\n                 kernel_size=7,\n                 mlp_ratio=4,\n                 ls_init_value=None,\n                 res_scale=False,\n                 drop_path=0,\n                 norm_layer=LayerNorm2d,\n                 use_gemm=False,\n                 deploy=False,\n                 use_checkpoint=False):\n        super().__init__()\n        \n        self.res_scale = res_scale\n        self.use_checkpoint = use_checkpoint\n        \n        mlp_dim = int(dim*mlp_ratio)\n        \n        self.dwconv = ResDWConv(dim, kernel_size=3)\n    \n        self.proj = nn.Sequential(\n            norm_layer(dim),\n            DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),\n            nn.BatchNorm2d(dim),\n            SEModule(dim),\n            nn.Conv2d(dim, mlp_dim, kernel_size=1),\n            nn.GELU(),\n            ResDWConv(mlp_dim, kernel_size=3),\n            GRN(mlp_dim),\n            nn.Conv2d(mlp_dim, dim, kernel_size=1),\n            DropPath(drop_path) if drop_path > 0 else nn.Identity(),\n        )\n\n        self.ls = LayerScale(dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()\n        \n    def forward_features(self, x):\n        \n        x = self.dwconv(x)\n        \n        if self.res_scale:\n            x = self.ls(x) + self.proj(x)\n        else:\n            drop_path = self.proj[-1]\n            x = x + drop_path(self.ls(self.proj[:-1](x)))\n\n        return x\n    \n    def forward(self, x):\n        \n        if self.use_checkpoint and x.requires_grad:\n            x = checkpoint(self.forward_features, x, use_reentrant=False)\n        else:\n            x = self.forward_features(x)\n        \n        return x\n\n\nclass DynamicConvBlock(nn.Module):\n    def __init__(self,\n                 dim=64,\n                 ctx_dim=32,\n                 kernel_size=7,\n                 smk_size=5,\n                 num_heads=2,\n                 mlp_ratio=4,\n                 ls_init_value=None,\n                 res_scale=False,\n                 drop_path=0,\n                 norm_layer=LayerNorm2d,\n                 is_first=False,\n                 is_last=False,\n                 use_gemm=False,\n        
         deploy=False,\n                 use_checkpoint=False,\n                 **kwargs):\n        \n        super().__init__()\n        \n        ctx_dim = ctx_dim // 4\n        out_dim = dim + ctx_dim\n        mlp_dim = int(dim*mlp_ratio)\n        self.kernel_size = kernel_size\n        self.res_scale = res_scale\n        self.use_gemm = use_gemm\n        self.smk_size = smk_size\n        self.num_heads = num_heads * 2\n        head_dim = dim // self.num_heads\n        self.scale = head_dim ** -0.5\n        self.is_first = is_first\n        self.is_last = is_last\n        self.use_checkpoint = use_checkpoint\n\n        if not is_first:\n            self.x_scale = LayerScale(ctx_dim, init_value=1)\n            self.h_scale = LayerScale(ctx_dim, init_value=1)\n        \n        self.dwconv1 = ResDWConv(out_dim, kernel_size=3)\n        self.norm1 = norm_layer(out_dim)\n        \n        self.fusion = nn.Sequential(\n            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, groups=out_dim),\n            nn.BatchNorm2d(out_dim),\n            nn.GELU(),\n            nn.Conv2d(out_dim, dim, kernel_size=1),\n            GRN(dim),\n        )\n        \n        self.weight_query = nn.Sequential(\n            nn.Conv2d(dim, dim//2, kernel_size=1, bias=False),\n            nn.BatchNorm2d(dim//2),\n        )\n         \n        self.weight_key = nn.Sequential(\n            nn.AdaptiveAvgPool2d(7),\n            nn.Conv2d(ctx_dim, dim//2, kernel_size=1, bias=False),\n            nn.BatchNorm2d(dim//2),\n        )\n        \n        self.weight_proj = nn.Conv2d(49, kernel_size**2 + smk_size**2, kernel_size=1)\n        \n        self.dyconv_proj = nn.Sequential(\n            nn.Conv2d(dim, dim, kernel_size=1, bias=False),\n            nn.BatchNorm2d(dim),\n        )\n        \n        self.lepe = nn.Sequential(\n            DilatedReparamBlock(dim, kernel_size=kernel_size, deploy=deploy, use_sync_bn=False, attempt_use_lk_impl=use_gemm),\n            nn.BatchNorm2d(dim),\n        )\n        \n        self.se_layer = SEModule(dim)\n        \n        self.gate = nn.Sequential(\n            nn.Conv2d(dim, dim, kernel_size=1, bias=False),\n            nn.BatchNorm2d(dim),\n            nn.SiLU(),\n        )\n\n        self.proj = nn.Sequential(\n            nn.BatchNorm2d(dim),\n            nn.Conv2d(dim, out_dim, kernel_size=1),\n        )\n        \n        self.dwconv2 = ResDWConv(out_dim, kernel_size=3)\n        self.norm2 = norm_layer(out_dim)\n        \n        self.mlp = nn.Sequential(\n            nn.Conv2d(out_dim, mlp_dim, kernel_size=1),\n            nn.GELU(),\n            ResDWConv(mlp_dim, kernel_size=3),\n            GRN(mlp_dim),\n            nn.Conv2d(mlp_dim, out_dim, kernel_size=1),\n        )\n        \n        self.ls1 = LayerScale(out_dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()\n        self.ls2 = LayerScale(out_dim, init_value=ls_init_value) if ls_init_value is not None else nn.Identity()\n        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()\n        \n        self.get_rpb()\n\n\n    def get_rpb(self):\n        self.rpb_size1 = 2 * self.smk_size - 1\n        self.rpb1 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size1, self.rpb_size1))\n        self.rpb_size2 = 2 * self.kernel_size - 1\n        self.rpb2 = nn.Parameter(torch.empty(self.num_heads, self.rpb_size2, self.rpb_size2))\n        nn.init.zeros_(self.rpb1)\n        nn.init.zeros_(self.rpb2)\n    \n        \n    @torch.no_grad()\n    def 
generate_idx(self, kernel_size):\n        rpb_size = 2 * kernel_size - 1\n        idx_h = torch.arange(0, kernel_size)\n        idx_w = torch.arange(0, kernel_size)\n        idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).view(-1)\n        return (idx_h, idx_w, idx_k)\n    \n\n    def apply_rpb(self, attn, rpb, height, width, kernel_size, idx_h, idx_w, idx_k):\n        \"\"\"\n        RPB implementation directly borrowed from https://tinyurl.com/mrbub4t3\n        \"\"\"\n        num_repeat_h = torch.ones(kernel_size, dtype=torch.long)\n        num_repeat_w = torch.ones(kernel_size, dtype=torch.long)\n        num_repeat_h[kernel_size//2] = height - (kernel_size-1)\n        num_repeat_w[kernel_size//2] = width - (kernel_size-1)\n        bias_hw = (idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2*kernel_size-1)) + idx_w.repeat_interleave(num_repeat_w)\n        bias_idx = bias_hw.unsqueeze(-1) + idx_k\n        bias_idx = bias_idx.reshape(-1, int(kernel_size**2))\n        bias_idx = torch.flip(bias_idx, [0])\n        rpb = torch.flatten(rpb, 1, 2)[:, bias_idx]\n        rpb = rpb.reshape(1, int(self.num_heads), int(height), int(width), int(kernel_size**2))\n        return attn + rpb\n    \n\n    def _forward_inner(self, x, h_x, h_r):\n        input_resolution = x.shape[2:]\n        B, C, H, W = x.shape\n        B, C_h, H_h, W_h = h_x.shape\n        \n        if not self.is_first:\n            h_x = self.x_scale(h_x) + self.h_scale(h_r)\n\n        x_f = torch.cat([x, h_x], dim=1)\n        x_f = self.dwconv1(x_f)\n        identity = x_f\n        x_f = self.norm1(x_f)\n        x = self.fusion(x_f)\n        gate = self.gate(x)\n        lepe = self.lepe(x)\n        \n        is_pad = False\n        if min(H, W) < self.kernel_size:\n            is_pad = True\n            if H < W:\n                size = (self.kernel_size, int(self.kernel_size / H * W))\n            else:\n                size = (int(self.kernel_size / W * H), self.kernel_size)\n                \n            x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)\n            x_f = F.interpolate(x_f, size=size, mode='bilinear', align_corners=False)\n            H, W = size\n\n        query, key = torch.split(x_f, split_size_or_sections=[C, C_h], dim=1)\n        query = self.weight_query(query) * self.scale\n        key = self.weight_key(key)\n        query = rearrange(query, 'b (g c) h w -> b g c (h w)', g=self.num_heads)\n        key = rearrange(key, 'b (g c) h w -> b g c (h w)', g=self.num_heads)\n        weight = einsum(query, key, 'b g c n, b g c l -> b g n l')\n        weight = rearrange(weight, 'b g n l -> b l g n').contiguous()\n        weight = self.weight_proj(weight)\n        weight = rearrange(weight, 'b l g (h w) -> b g h w l', h=H, w=W)\n\n        attn1, attn2 = torch.split(weight, split_size_or_sections=[self.smk_size**2, self.kernel_size**2], dim=-1)\n        rpb1_idx = self.generate_idx(self.smk_size)\n        rpb2_idx = self.generate_idx(self.kernel_size)\n        attn1 = self.apply_rpb(attn1, self.rpb1, H, W, self.smk_size, *rpb1_idx)\n        attn2 = self.apply_rpb(attn2, self.rpb2, H, W, self.kernel_size, *rpb2_idx)\n        attn1 = torch.softmax(attn1, dim=-1)\n        attn2 = torch.softmax(attn2, dim=-1)\n        value = rearrange(x, 'b (m g c) h w -> m b g h w c', m=2, g=self.num_heads)\n\n        x1 = na2d_av(attn1, value[0], kernel_size=self.smk_size)\n        x2 = na2d_av(attn2, value[1], kernel_size=self.kernel_size)\n\n        x = torch.cat([x1, x2], dim=1)\n        x = rearrange(x, 
'b g h w c -> b (g c) h w', h=H, w=W)\n        \n        if is_pad:\n            x = F.adaptive_avg_pool2d(x, input_resoltion)\n        \n        x = self.dyconv_proj(x)\n\n        x = x + lepe\n        x = self.se_layer(x)\n\n        x = gate * x\n        x = self.proj(x)\n\n        if self.res_scale:\n            x = self.ls1(identity) + self.drop_path(x)\n        else:\n            x = identity + self.drop_path(self.ls1(x))\n        \n        x = self.dwconv2(x)\n         \n        if self.res_scale:\n            x = self.ls2(x) + self.drop_path(self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))\n\n        if self.is_last:\n            return (x, None)\n        else:\n            l_x, h_x = torch.split(x, split_size_or_sections=[C, C_h], dim=1)\n            return (l_x, h_x)\n    \n    def forward(self, x, h_x, h_r):\n        if self.use_checkpoint and x.requires_grad:\n            x = checkpoint(self._forward_inner, x, h_x, h_r, use_reentrant=False)\n        else:\n            x = self._forward_inner(x, h_x, h_r)\n        return x\n\n\nclass OverLoCK(nn.Module):\n    '''\n    An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels\n    https://arxiv.org/abs/2502.20087\n    '''\n    def __init__(self, \n                 depth=[2, 2, 2, 2],\n                 sub_depth=[4, 2],\n                 in_chans=3, \n                 embed_dim=[96, 192, 384, 768],\n                 kernel_size=[7, 7, 7, 7],\n                 mlp_ratio=[4, 4, 4, 4],\n                 sub_mlp_ratio=[4, 4],\n                 sub_num_heads=[4, 8],\n                 ls_init_value=[None, None, 1, 1],\n                 res_scale=True,\n                 smk_size=5,\n                 deploy=False,\n                 use_gemm=True,\n                 use_ds=True,\n                 drop_rate=0,\n                 drop_path_rate=0,\n                 norm_layer=LayerNorm2d,\n                 projection=1024,\n                 num_classes=1000,\n                 use_checkpoint=[0, 0, 0, 0],\n            ):\n \n        super().__init__()\n        \n        fusion_dim = embed_dim[-1] + embed_dim[-1]//4\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim\n\n        self.patch_embed1 = stem(in_chans, embed_dim[0])\n        self.patch_embed2 = downsample(embed_dim[0], embed_dim[1])\n        self.patch_embed3 = downsample(embed_dim[1], embed_dim[2])\n        self.patch_embed4 = downsample(embed_dim[2], embed_dim[3])\n        self.high_level_proj = nn.Conv2d(embed_dim[-1], embed_dim[-1]//4, kernel_size=1)\n        self.patch_embedx = CTXDownsample(embed_dim[2], embed_dim[3])\n        \n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth) + sum(sub_depth))]\n\n        self.blocks1 = nn.ModuleList()\n        self.blocks2 = nn.ModuleList()\n        self.blocks3 = nn.ModuleList()\n        self.blocks4 = nn.ModuleList()\n        self.sub_blocks3 = nn.ModuleList()\n        self.sub_blocks4 = nn.ModuleList()\n        \n        for i in range(depth[0]):\n            self.blocks1.append(\n                RepConvBlock(\n                    dim=embed_dim[0],\n                    kernel_size=kernel_size[0],\n                    mlp_ratio=mlp_ratio[0],\n                    ls_init_value=ls_init_value[0],\n                    res_scale=res_scale,\n                    drop_path=dpr[i],\n                    norm_layer=norm_layer,\n                    use_gemm=use_gemm,\n                    
deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[0]),\n                )\n            )\n        \n        for i in range(depth[1]):\n            self.blocks2.append(\n                RepConvBlock(\n                    dim=embed_dim[1],\n                    kernel_size=kernel_size[1],\n                    mlp_ratio=mlp_ratio[1],\n                    ls_init_value=ls_init_value[1],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+depth[0]],\n                    norm_layer=norm_layer,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[1]),\n                )\n            )\n            \n        for i in range(depth[2]):\n            self.blocks3.append(\n                RepConvBlock(\n                    dim=embed_dim[2],\n                    kernel_size=kernel_size[2],\n                    mlp_ratio=mlp_ratio[2],\n                    ls_init_value=ls_init_value[2],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+sum(depth[:2])],\n                    norm_layer=norm_layer,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[2]),\n                )\n            )\n\n        for i in range(depth[3]):\n            self.blocks4.append(\n                RepConvBlock(\n                    dim=embed_dim[3],\n                    kernel_size=kernel_size[3],\n                    mlp_ratio=mlp_ratio[3],\n                    ls_init_value=ls_init_value[3],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+sum(depth[:3])],\n                    norm_layer=norm_layer,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    use_checkpoint=(i<use_checkpoint[3]),\n                )\n            )\n            \n        for i in range(sub_depth[0]):\n            self.sub_blocks3.append(\n                DynamicConvBlock(\n                    dim=embed_dim[2],\n                    ctx_dim=embed_dim[-1],\n                    kernel_size=kernel_size[2],\n                    num_heads=sub_num_heads[0],\n                    pool_size=7,\n                    mlp_ratio=sub_mlp_ratio[0],\n                    ls_init_value=ls_init_value[2],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+sum(depth)],\n                    norm_layer=norm_layer,\n                    smk_size=smk_size,\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    is_first=(i==0),\n                    use_checkpoint=(i<use_checkpoint[2]),\n                )\n            )\n        \n        for i in range(sub_depth[1]):\n            self.sub_blocks4.append(\n                DynamicConvBlock(\n                    dim=embed_dim[3],\n                    ctx_dim=embed_dim[-1],\n                    kernel_size=kernel_size[-1],\n                    num_heads=sub_num_heads[1],\n                    pool_size=7,\n                    mlp_ratio=sub_mlp_ratio[1],\n                    ls_init_value=ls_init_value[3],\n                    res_scale=res_scale,\n                    drop_path=dpr[i+sum(depth)+sub_depth[0]],\n                    norm_layer=norm_layer,\n                    smk_size=smk_size,\n                    is_first=False,\n                    is_last=(i==sub_depth[1]-1),\n                    use_gemm=use_gemm,\n                    deploy=deploy,\n                    
use_checkpoint=(i<use_checkpoint[3]),\n                )\n            )\n\n        # Aux Cls Head\n        if use_ds:\n            self.aux_head = nn.Sequential(\n                nn.BatchNorm2d(embed_dim[-1]),\n                nn.AdaptiveAvgPool2d(1),\n                nn.Conv2d(embed_dim[-1], num_classes, kernel_size=1) if num_classes > 0 else nn.Identity()\n            )\n        \n        # Main Cls Head\n        self.head = nn.Sequential(\n            nn.Conv2d(fusion_dim, projection, kernel_size=1, bias=False),\n            nn.BatchNorm2d(projection),\n            nn.SiLU(),\n            nn.AdaptiveAvgPool2d(1),\n            nn.Conv2d(projection, num_classes, kernel_size=1) if num_classes > 0 else nn.Identity()\n        )\n        \n        self.extra_norm = nn.ModuleList()\n        for idx in range(4):\n            dim = embed_dim[idx]\n            if idx >= 2:\n                dim = dim + embed_dim[-1]//4\n            self.extra_norm.append(norm_layer(dim))\n        \n        del self.aux_head\n        del self.head\n        \n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d)):\n            nn.init.trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.BatchNorm1d)):\n            nn.init.constant_(m.weight, 1.0)\n            nn.init.constant_(m.bias, 0)\n    \n    \n    def _convert_sync_batchnorm(self):\n        if torch.distributed.is_initialized():\n            self = nn.SyncBatchNorm.convert_sync_batchnorm(self)\n            \n    def forward_pre_features(self, x):\n        \n        outs = []\n        \n        x = self.patch_embed1(x)\n        for blk in self.blocks1:\n            x = blk(x)\n        \n        outs.append(self.extra_norm[0](x))\n            \n        x = self.patch_embed2(x)\n        for blk in self.blocks2:\n            x = blk(x)\n\n        outs.append(self.extra_norm[1](x))\n        \n        return outs\n    \n    \n    def forward_base_features(self, x):\n        \n        x = self.patch_embed3(x)\n        for blk in self.blocks3:\n            x = blk(x)\n            \n        ctx = self.patch_embed4(x)\n        for blk in self.blocks4:\n            ctx = blk(ctx)\n\n        return (x, ctx)\n    \n\n    def forward_sub_features(self, x, ctx):\n        \n        outs = []\n        \n        # ctx_cls = ctx\n        ctx_ori = self.high_level_proj(ctx)\n        ctx_up = F.interpolate(ctx_ori, size=x.shape[2:], mode='bilinear', align_corners=False)\n        \n        for idx, blk in enumerate(self.sub_blocks3):\n            if idx == 0:\n                ctx = ctx_up\n            x, ctx = blk(x, ctx, ctx_up)\n        \n        outs.append(self.extra_norm[2](torch.cat([x, ctx], dim=1)))\n        \n        x, ctx = self.patch_embedx(x, ctx)\n        for idx, blk in enumerate(self.sub_blocks4):\n            x, ctx = blk(x, ctx, ctx_ori)\n        \n        outs.append(self.extra_norm[3](x))\n        \n        return outs\n\n    def forward_features(self, x):\n        \n        x0, x1 = self.forward_pre_features(x)\n        x, ctx = self.forward_base_features(x1)\n        x2, x3 = self.forward_sub_features(x, ctx)\n\n        return (x0, x1, x2, x3)\n\n    def forward(self, x):\n        \n        x = self.forward_features(x)\n        \n        return x\n\n\n@MODELS.register_module()\ndef overlock_xt(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model 
= OverLoCK(\n        depth=[2, 2, 3, 2],\n        sub_depth=[6, 2],\n        embed_dim=[56, 112, 256, 336],\n        kernel_size=[17, 15, 13, 7],\n        mlp_ratio=[4, 4, 4, 4],\n        sub_num_heads=[4, 6],\n        sub_mlp_ratio=[3, 3],\n        **kwargs\n    )\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_xt_in1k_224.pth'\n        logger = get_root_logger()\n        load_checkpoint(model, pretrained, logger=logger)\n    model._convert_sync_batchnorm()\n    return model\n\n\n@MODELS.register_module()\ndef overlock_t(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model = OverLoCK(\n        depth=[4, 4, 6, 2],\n        sub_depth=[12, 2],\n        embed_dim=[64, 128, 256, 512],\n        kernel_size=[17, 15, 13, 7],\n        mlp_ratio=[4, 4, 4, 4],\n        sub_num_heads=[4, 8],\n        sub_mlp_ratio=[3, 3],\n        **kwargs\n    )\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224.pth'\n        logger = get_root_logger()\n        load_checkpoint(model, pretrained, logger=logger)\n    model._convert_sync_batchnorm()\n    return model\n\n\n@MODELS.register_module()\ndef overlock_s(pretrained=False, pretrained_cfg=None, **kwargs):\n    \n    model = OverLoCK(\n        depth=[6, 6, 8, 3],\n        sub_depth=[16, 3],\n        embed_dim=[64, 128, 320, 512],\n        kernel_size=[17, 15, 13, 7],\n        mlp_ratio=[4, 4, 4, 4],\n        sub_num_heads=[8, 16],\n        sub_mlp_ratio=[3, 3],\n        **kwargs\n    )\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224.pth'\n        logger = get_root_logger()\n        load_checkpoint(model, pretrained, logger=logger)\n    model._convert_sync_batchnorm()\n    return model\n\n\n@MODELS.register_module()\ndef overlock_b(pretrained=None, pretrained_cfg=None, **kwargs):\n    \n    model = OverLoCK(\n        depth=[8, 8, 10, 4],\n        sub_depth=[20, 4],\n        embed_dim=[80, 160, 384, 576],\n        kernel_size=[17, 15, 13, 7],\n        mlp_ratio=[4, 4, 4, 4],\n        sub_num_heads=[6, 9],\n        sub_mlp_ratio=[3, 3],\n        **kwargs\n    )\n\n    if pretrained:\n        pretrained = 'https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224.pth'\n        logger = get_root_logger()\n        load_checkpoint(model, pretrained, logger=logger)\n    model._convert_sync_batchnorm()\n    return model\n"
  },
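  {
    "path": "segmentation/models/overlock_example.py",
    "content": "'''\nHypothetical smoke-test sketch, not part of the original repo: it exercises\ntwo pieces of overlock.py under stated assumptions. It assumes it is run from\nthe segmentation/ directory (like train.py, which does ``import models``) and\nneeds the same dependencies importable (torch, timm, natten, mmcv/mmseg).\n'''\nimport torch\n\nfrom models.overlock import DilatedReparamBlock, overlock_t\n\n# 1) Re-parameterization check: in eval mode, merging the dilated branches of\n# DilatedReparamBlock into a single large kernel should leave the output\n# numerically unchanged (up to float error).\nblock = DilatedReparamBlock(32, kernel_size=13, deploy=False,\n                            attempt_use_lk_impl=False).eval()\nx = torch.randn(1, 32, 28, 28)\nwith torch.no_grad():\n    before = block(x)\n    block.merge_dilated_branches()\n    after = block(x)\nprint('reparam max diff:', (before - after).abs().max().item())\n\n# 2) Backbone forward: the segmentation variant returns four pyramid features\n# (strides 4/8/16/32); the last two carry the extra concatenated context\n# channels that the extra_norm dims account for.\nmodel = overlock_t(pretrained=False).eval()\nwith torch.no_grad():\n    feats = model(torch.randn(1, 3, 512, 512))\nfor i, f in enumerate(feats):\n    print(f'stage {i}:', tuple(f.shape))\n"
  },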
  {
    "path": "segmentation/readme.md",
    "content": "# Applying OverLoCK to Semantic Segmentation   \n\n## 1. Requirements\n\n```\npip install mmcv-full==1.7.2 --no-cache-dir\npip install mmsegmentation==0.30.0 --no-cache-dir\n```\n💡 To enable torch>=2.1.0 to support mmcv 1.7.2, you need to make the following changes:  \n> 1️⃣ https://goo.su/XhU5vWr     \n> 2️⃣ https://goo.su/ogm4yO\n\n\n\n## 2. Data Preparation\n\nPrepare ADE20K dataset according to the [guidelines](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md).  \n\n## 3. Main Results on ADE20K using UperNet framework\n\n|    Backbone   |   Pretrain  | Schedule | mIoU |                         Config                          | Download |\n|:-------------:|:-----------:|:--------:|--------|:-------------------------------------------------------:|:----------:|\n| OverLoCK-T | [ImageNet-1K](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_t_in1k_224.pth)|    160K    |  50.3     |    [config](configs/overlock/upernet_overlock_tiny_ade20k_8xb2.py)    |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/upernet_overlock_tiny_ade20k.pth)          |\n| OverLoCK-S | [ImageNet-1K](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_s_in1k_224.pth)|    160K    |51.3       |    [config](configs/overlock/upernet_overlock_small_ade20k_8xb2.py)    |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/upernet_overlock_small_ade20k.pth)           |\n| OverLoCK-B | [ImageNet-1K](https://github.com/LMMMEng/OverLoCK/releases/download/v1/overlock_b_in1k_224.pth) |    160K    |51.7        |    [config](configs/overlock/upernet_overlock_base_ade20k_8xb2.py)    |[model](https://github.com/LMMMEng/OverLoCK/releases/download/v1/upernet_overlock_base_ade20k.pth)           |\n\n## 4. Train\nTo train ``OverLoCK-T + UperNet`` model on ADE20K dataset with 8 gpus (single node), run:\n```\nbash scripts/dist_train.sh configs/overlock/upernet_overlock_tiny_ade20k_8xb2.py 8\n```\n\n## 5. Validation\nTo evaluate ``OverLoCK-T + UperNet`` model on ADE20K dataset, run:\n```\nbash scripts/dist_test.sh configs/overlock/upernet_overlock_tiny_ade20k_8xb2.py path-to-checkpoint 8 --eval mIoU\n```\n\n## Citation\nIf you find this project useful for your research, please consider citing:\n```\n@inproceedings{lou2025overlock,\n  title={OverLoCK: An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels},\n  author={Lou, Meng and Yu, Yizhou},\n  booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n  pages={128--138},\n  year={2025}\n}\n```\n"
  },
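  {
    "path": "segmentation/configs/overlock/backbone_usage_example.py",
    "content": "# Hypothetical sketch, not one of the repo's actual configs: it illustrates\n# how the backbone factories registered in models/overlock.py (e.g. overlock_t)\n# could be referenced from an mmseg 0.x config. The authoritative configs are\n# the ones linked in the readme under configs/overlock/. The in_channels below\n# are assumed from the overlock_t stage widths (64/128/384/640 once the context\n# channels are concatenated); the remaining UPerHead fields are omitted for\n# brevity, so this snippet is not runnable as-is.\nmodel = dict(\n    type='EncoderDecoder',\n    backbone=dict(type='overlock_t', pretrained=True),\n    decode_head=dict(\n        type='UPerHead',\n        in_channels=[64, 128, 384, 640],\n        in_index=[0, 1, 2, 3],\n        channels=512,\n        num_classes=150))\n"
  },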
  {
    "path": "segmentation/scripts/dist_test.sh",
    "content": "#!/usr/bin/env bash\nCONFIG=$1\nCHECKPOINT=$2\nGPUS=$3\nPORT=$((RANDOM+10000))\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\ntorchrun --nproc_per_node=$GPUS --master_port=$PORT test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}"
  },
  {
    "path": "segmentation/scripts/dist_train.sh",
    "content": "#!/usr/bin/env bash\nCONFIG=$1\nGPUS=$2\nPORT=$((RANDOM+10000))\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\ntorchrun --nproc_per_node=$GPUS --master_port=$PORT train.py $CONFIG --launcher pytorch ${@:3}"
  },
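  {
    "path": "segmentation/scripts/single_gpu_test_example.sh",
    "content": "#!/usr/bin/env bash\n# Hypothetical convenience sketch, not part of the original repo: evaluate a\n# checkpoint on one GPU without torchrun. test.py defaults to --launcher none,\n# which takes the non-distributed MMDataParallel code path.\nCONFIG=$1\nCHECKPOINT=$2\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\npython test.py $CONFIG $CHECKPOINT --eval mIoU ${@:3}"
  },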
  {
    "path": "segmentation/test.py",
    "content": "import warnings\nimport argparse\nimport os\n\nimport mmcv\nimport torch\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import (get_dist_info, init_dist, load_checkpoint,\n                         wrap_fp16_model)\nfrom mmcv.utils import DictAction\n\nfrom mmseg.apis import multi_gpu_test, single_gpu_test\nfrom mmseg.datasets import build_dataloader, build_dataset\nfrom mmseg.models import build_segmentor\n\nimport models\nimport mmseg_custom\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='mmseg test (and eval) a model')\n    parser.add_argument('config', help='test config file path')\n    parser.add_argument('checkpoint', help='checkpoint file')\n    parser.add_argument(\n        '--aug-test', action='store_true', help='Use Flip and Multi scale aug')\n    parser.add_argument('--out', help='output result file in pickle format')\n    parser.add_argument(\n        '--format-only',\n        action='store_true',\n        help='Format the output results without perform evaluation. It is'\n        'useful when you want to format the result to a specific format and '\n        'submit it to the test server')\n    parser.add_argument(\n        '--eval',\n        type=str,\n        nargs='+',\n        help='evaluation metrics, which depends on the dataset, e.g., \"mIoU\"'\n        ' for generic datasets, and \"cityscapes\" for Cityscapes')\n    parser.add_argument('--show', action='store_true', help='show results')\n    parser.add_argument(\n        '--show-dir', help='directory where painted images will be saved')\n    parser.add_argument(\n        '--gpu-collect',\n        action='store_true',\n        help='whether to use gpu to collect results.')\n    parser.add_argument(\n        '--tmpdir',\n        help='tmp directory used for collecting results from multiple '\n        'workers, available when gpu_collect is not specified')\n    parser.add_argument(\n        '--options', nargs='+', action=DictAction, help='custom options')\n    parser.add_argument(\n        '--eval-options',\n        nargs='+',\n        action=DictAction,\n        help='custom options for evaluation')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument(\n        '--opacity',\n        type=float,\n        default=0.5,\n        help='Opacity of painted segmentation map. 
In (0, 1] range.')\n    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    assert args.out or args.eval or args.format_only or args.show \\\n        or args.show_dir, \\\n        ('Please specify at least one operation (save/eval/format/show the '\n         'results / save the results) with the argument \"--out\", \"--eval\"'\n         ', \"--format-only\", \"--show\" or \"--show-dir\"')\n\n    if args.eval and args.format_only:\n        raise ValueError('--eval and --format_only cannot both be specified')\n\n    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):\n        raise ValueError('The output file must be a pkl file.')\n\n    cfg = mmcv.Config.fromfile(args.config)\n    if args.options is not None:\n        cfg.merge_from_dict(args.options)\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n    if args.aug_test:\n        # hard code index\n        cfg.data.test.pipeline[1].img_ratios = [\n            0.5, 0.75, 1.0, 1.25, 1.5, 1.75\n        ]\n        cfg.data.test.pipeline[1].flip = True\n    cfg.model.pretrained = None\n    cfg.data.test.test_mode = True\n\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n\n    # build the dataloader\n    # TODO: support multiple images per gpu (only minor changes are needed)\n    dataset = build_dataset(cfg.data.test)\n    data_loader = build_dataloader(\n        dataset,\n        samples_per_gpu=1,\n        workers_per_gpu=cfg.data.workers_per_gpu,\n        dist=distributed,\n        shuffle=False)\n\n    # build the model and load checkpoint\n    cfg.model.train_cfg = None\n    model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))\n    fp16_cfg = cfg.get('fp16', None)\n    if fp16_cfg is not None:\n        wrap_fp16_model(model)\n        \n    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')\n    try:\n        model.CLASSES = checkpoint['meta']['CLASSES']\n        model.PALETTE = checkpoint['meta']['PALETTE']\n    except KeyError:\n        warnings.warn(\"'CLASSES' and 'PALETTE' are not in the checkpoint.\")\n\n    print(f'Config:\\n{cfg.pretty_text}')\n    \n    efficient_test = False\n    if args.eval_options is not None:\n        efficient_test = args.eval_options.get('efficient_test', False)\n\n    if not distributed:\n        model = MMDataParallel(model, device_ids=[0])\n        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,\n                                  efficient_test, args.opacity)\n    else:\n        model = MMDistributedDataParallel(\n            model.cuda(),\n            device_ids=[torch.cuda.current_device()],\n            broadcast_buffers=False)\n        outputs = multi_gpu_test(model, data_loader, args.tmpdir, args.gpu_collect, efficient_test)\n\n    rank, _ = get_dist_info()\n    if rank == 0:\n        if args.out:\n            print(f'\\nwriting results to {args.out}')\n            mmcv.dump(outputs, args.out)\n        kwargs = {} if args.eval_options is None else args.eval_options\n        if args.format_only:\n            dataset.format_results(outputs, **kwargs)\n        if args.eval:\n       
     dataset.evaluate(outputs, args.eval, **kwargs)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "segmentation/train.py",
    "content": "import argparse\nimport warnings\nimport copy\nimport os\nimport os.path as osp\nimport time\n\nimport mmcv\nimport torch\nfrom mmcv.runner import get_dist_info, init_dist\nfrom mmcv.utils import Config, DictAction, get_git_hash\n\nfrom mmseg import __version__\nfrom mmseg.apis import set_random_seed, train_segmentor\nfrom mmseg.datasets import build_dataset\nfrom mmseg.models import build_segmentor\nfrom mmseg.utils import collect_env, get_root_logger\n\nimport models\nimport mmseg_custom\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Train a segmentor')\n    parser.add_argument('config', help='train config file path')\n    parser.add_argument('--work-dir', help='the dir to save logs and models')\n    parser.add_argument(\n        '--load-from', help='the checkpoint file to load weights from')\n    parser.add_argument(\n        '--resume-from', help='the checkpoint file to resume from')\n    parser.add_argument(\n        '--no-validate',\n        action='store_true',\n        help='whether not to evaluate the checkpoint during training')\n    group_gpus = parser.add_mutually_exclusive_group()\n    group_gpus.add_argument(\n        '--gpus',\n        type=int,\n        help='number of gpus to use '\n        '(only applicable to non-distributed training)')\n    group_gpus.add_argument(\n        '--gpu-ids',\n        type=int,\n        nargs='+',\n        help='ids of gpus to use '\n        '(only applicable to non-distributed training)')\n    parser.add_argument('--seed', type=int, default=None, help='random seed')\n    parser.add_argument(\n        '--deterministic',\n        action='store_true',\n        help='whether to set deterministic options for CUDNN backend.')\n    parser.add_argument(\n        '--options', nargs='+', action=DictAction, help='custom options')\n    parser.add_argument(\n        '--drop-path',\n        default=-1,\n        type=float,\n        help='drop-path-rate of the backbone network')\n    parser.add_argument(\n        '--freeze-bn',\n        action='store_true',\n        default=False,\n        help='freeze the BN layer of the backbone model during training')\n    parser.add_argument(\n        '--amp',\n        action='store_true',\n        default=False,\n        help='mixed precision training')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    return args\n\n\ndef main():\n    args = parse_args()\n    \n    cfg = Config.fromfile(args.config)\n    if args.options is not None:\n        cfg.merge_from_dict(args.options)\n    # set cudnn_benchmark\n    if cfg.get('cudnn_benchmark', False):\n        torch.backends.cudnn.benchmark = True\n\n    # work_dir is determined in this priority: CLI > segment in file > filename\n    if args.work_dir is not None:\n        # update configs according to CLI args if args.work_dir is not None\n        cfg.work_dir = args.work_dir\n    elif cfg.get('work_dir', None) is None:\n        # use config filename as default work_dir if cfg.work_dir is None\n        cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])\n    \n    if args.load_from is not None:\n        cfg.load_from = args.load_from\n    if args.resume_from is not None:\n        
cfg.resume_from = args.resume_from\n    if args.gpu_ids is not None:\n        cfg.gpu_ids = args.gpu_ids\n    else:\n        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)\n    \n    if args.drop_path >= 0:\n        try:\n            cfg.model.backbone.drop_path_rate = args.drop_path\n        except AttributeError:\n            # the root logger has not been created yet at this point\n            warnings.warn('drop_path is not defined in the config file')\n    \n    if args.freeze_bn:\n        try:\n            cfg.model.backbone.freeze_bn = True\n        except AttributeError:\n            warnings.warn('freeze_bn is not defined in the config file')\n    \n    if args.amp:\n        loss_scale = 'dynamic'\n        if cfg.get('fp16', None) is None:\n            cfg.fp16 = dict()\n        else:\n            warnings.warn('fp16 has been defined in the config file')\n        cfg.optimizer_config.type = 'Fp16OptimizerHook'\n        cfg.optimizer_config.loss_scale = loss_scale\n    \n    # init distributed env first, since logger depends on the dist info.\n    cfg.device = 'cuda'\n    if args.launcher == 'none':\n        distributed = False\n    else:\n        distributed = True\n        init_dist(args.launcher, **cfg.dist_params)\n    \n    # create work_dir\n    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))\n    # dump config\n    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))\n    # init the logger before other steps\n    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')\n    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)\n    \n    # init the meta dict to record some important information such as\n    # environment info and seed, which will be logged\n    meta = dict()\n    # log env info\n    env_info_dict = collect_env()\n    env_info = '\\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])\n    dash_line = '-' * 60 + '\\n'\n    logger.info('Environment info:\\n' + dash_line + env_info + '\\n' +\n                dash_line)\n    meta['env_info'] = env_info\n\n    # log some basic info\n    logger.info(f'Distributed training: {distributed}')\n    logger.info(f'Config:\\n{cfg.pretty_text}')\n\n    # set random seeds\n    if args.seed is not None:\n        logger.info(f'Set random seed to {args.seed}, deterministic: '\n                    f'{args.deterministic}')\n        set_random_seed(args.seed, deterministic=args.deterministic)\n    cfg.seed = args.seed\n    meta['seed'] = args.seed\n    meta['exp_name'] = osp.basename(args.config)\n\n    model = build_segmentor(\n        cfg.model,\n        train_cfg=cfg.get('train_cfg'),\n        test_cfg=cfg.get('test_cfg'))\n\n    logger.info(model)\n\n    datasets = [build_dataset(cfg.data.train)]\n    if len(cfg.workflow) == 2:\n        val_dataset = copy.deepcopy(cfg.data.val)\n        val_dataset.pipeline = cfg.data.train.pipeline\n        datasets.append(build_dataset(val_dataset))\n    if cfg.checkpoint_config is not None:\n        # save mmseg version, config file content and class names in\n        # checkpoints as meta data\n        cfg.checkpoint_config.meta = dict(\n            mmseg_version=f'{__version__}+{get_git_hash()[:7]}',\n            config=cfg.pretty_text,\n            CLASSES=datasets[0].CLASSES,\n            PALETTE=datasets[0].PALETTE)\n    # add an attribute for visualization convenience\n    model.CLASSES = datasets[0].CLASSES\n    train_segmentor(\n        model,\n        datasets,\n        cfg,\n        distributed=distributed,\n        validate=(not args.no_validate),\n        
timestamp=timestamp,\n        meta=meta)\n\n\nif __name__ == '__main__':\n    main()"
  },
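  {
    "path": "examples/resolve_work_dir_sketch.py",
    "content": "\"\"\"Illustrative sketch only -- not part of the original code above.\n\nRestates the work_dir precedence used by the segmentation training script\n(CLI --work-dir > cfg.work_dir > config filename) as a standalone helper, so\nthe rule can be exercised without installing mmcv. The file path and the\nfunction name are assumptions added for documentation purposes.\n\"\"\"\nimport os.path as osp\n\n\ndef resolve_work_dir(cli_work_dir, cfg_work_dir, config_path):\n    # the CLI flag has the highest priority\n    if cli_work_dir is not None:\n        return cli_work_dir\n    # next, an explicit work_dir set inside the config file\n    if cfg_work_dir is not None:\n        return cfg_work_dir\n    # fall back to ./work_dirs/<config filename without extension>\n    return osp.join('./work_dirs', osp.splitext(osp.basename(config_path))[0])\n\n\nif __name__ == '__main__':\n    print(resolve_work_dir('/tmp/run', None, 'configs/seg.py'))  # /tmp/run\n    print(resolve_work_dir(None, None, 'configs/seg.py'))  # ./work_dirs/seg (POSIX)\n"
  },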
  {
    "path": "train.py",
    "content": "#!/usr/bin/env python3\n\"\"\" ImageNet Training Script\n\nThis is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet\ntraining results with some of the latest networks and training techniques. It favours canonical PyTorch\nand standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed\nand training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.\n\nThis script was started from an early version of the PyTorch ImageNet example\n(https://github.com/pytorch/examples/tree/master/imagenet)\n\nNVIDIA CUDA specific speedups adopted from NVIDIA Apex examples\n(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)\n\nHacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)\n\"\"\"\nimport argparse\nimport datetime\nimport json\nimport logging\nimport os\nimport time\nfrom collections import OrderedDict\nfrom contextlib import suppress\n\nimport torch\nimport torchvision\nimport yaml\nfrom timm.data import (AugMixDataset, FastCollateMixup, Mixup, create_dataset,\n                       create_loader, resolve_data_config)\nfrom timm.loss import *\nfrom timm.models import (convert_splitbn_model, create_model, load_checkpoint,\n                         model_parameters, resume_checkpoint, safe_model_name)\nfrom timm.optim import create_optimizer_v2, optimizer_kwargs\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import *\nfrom timm.utils import ApexScaler, NativeScaler\nfrom torch import nn\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\n\ntry:\n    from mmcv.runner import load_checkpoint as load_ckpt\nexcept:\n    from mmengine.runner import load_checkpoint as load_ckpt\n\ntry:\n    from apex import amp\n    from apex.parallel import DistributedDataParallel as ApexDDP\n    from apex.parallel import convert_syncbn_model\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\ntry:\n    import wandb\n    has_wandb = True\nexcept ImportError:\n    has_wandb = False\n\nimport models\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('train')\n\n\ndef get_args_parser():\n    # The first arg parser parses out only the --config argument, this argument is used to\n    # load a yaml file containing key-values that override the defaults for the main parser below\n    # config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)\n    # parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', help='YAML config file specifying default arguments')\n    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training', add_help=False)\n\n    # Dataset parameters\n    parser.add_argument('--data-dir', metavar='DIR', help='path to dataset')\n    parser.add_argument('--dataset', '-d', metavar='NAME', default='',\n                        help='dataset type (default: ImageFolder/ImageTar if empty)')\n    parser.add_argument('--train-split', metavar='NAME', default='train',\n                        help='dataset train split (default: train)')\n    parser.add_argument('--val-split', metavar='NAME', default='validation',\n                        help='dataset validation split (default: validation)')\n    parser.add_argument('--dataset-download', action='store_true', 
default=False,\n                        help='Allow download of dataset for torch/ and tfds/ datasets that support it.')\n    parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',\n                        help='path to class to idx mapping file (default: \"\")')\n\n    # Model parameters\n    parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',\n                        help='Name of model to train (default: \"resnet50\")')\n    parser.add_argument('--pretrained', action='store_true', default=False,\n                        help='Start with pretrained version of specified network (if avail)')\n    parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                        help='Initialize model from this checkpoint (default: none)')\n    parser.add_argument('--resume', default=None, type=str, metavar='PATH',\n                        help='Resume full model and optimizer state from checkpoint (default: none)')\n    parser.add_argument('--finetune', default=None, type=str, metavar='PATH',\n                        help='Fine-tune model from this checkpoint (default: none)')\n    parser.add_argument('--no-resume-opt', action='store_true', default=False,\n                        help='prevent resume of optimizer state when resuming model')\n    parser.add_argument('--num-classes', type=int, default=1000, metavar='N',\n                        help='number of label classes (Model default if None)')\n    parser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                        help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')\n    parser.add_argument('--img-size', type=int, default=None, metavar='N',\n                        help='Image patch size (default: None => model default)')\n    parser.add_argument('--input-size', default=None, nargs=3, type=int,\n                        metavar='N N N', help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty')\n    parser.add_argument('--crop-pct', default=None, type=float,\n                        metavar='N', help='Input image center crop percent (for validation only)')\n    parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                        help='Override mean pixel value of dataset')\n    parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                        help='Override std deviation of dataset')\n    parser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                        help='Image resize interpolation type (overrides model)')\n    parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                        help='input batch size for training (default: 128)')\n    parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',\n                        help='validation batch size override (default: None)')\n    parser.add_argument('--grad-checkpoint', action='store_true', default=False,\n                        help='Use gradient checkpointing to save GPU memory')\n    parser.add_argument('--ckpt-stg', default=[0, 0, 0, 0], type=int, nargs='+',\n                        help='stage for using grad checkpoint')\n    parser.add_argument('--val-freq', default=1, type=int,\n                        help='do evaluation per n epoch')\n    parser.add_argument('--val-start-epoch', default=0, type=int,\n                        help='do evaluation per epoch after n-th epoch')\n    # Optimizer parameters\n    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                        help='Optimizer (default: \"adamw\")')\n    parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',\n                        help='Optimizer Epsilon (default: None, use opt default)')\n    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                        help='Optimizer Betas (default: None, use opt default)')\n    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                        help='Optimizer momentum (default: 0.9)')\n    parser.add_argument('--weight-decay', type=float, default=0.05,\n                        help='weight decay (default: 0.05)')\n    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                        help='Clip gradient norm (default: None, no clipping)')\n    parser.add_argument('--clip-mode', type=str, default='norm',\n                        help='Gradient clipping mode. 
One of (\"norm\", \"value\", \"agc\")')\n    parser.add_argument('--aux-loss-ratio', type=float, default=0.4,\n                        help='Aux loss weight')\n\n    # Learning rate schedule parameters\n    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                        help='LR scheduler (default: \"step\"')\n    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',\n                        help='learning rate (default: 1e-3)')\n    parser.add_argument('--auto-lr', action='store_true',\n                        default=False, help='auto scaling learning rate')\n    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                        help='learning rate noise on/off epoch percentages')\n    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                        help='learning rate noise limit percent (default: 0.67)')\n    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                        help='learning rate noise std-dev (default: 1.0)')\n    parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                        help='learning rate cycle len multiplier (default: 1.0)')\n    parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',\n                        help='amount to decay each learning rate cycle (default: 0.5)')\n    parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                        help='learning rate cycle limit, cycles enabled if > 1')\n    parser.add_argument('--lr-k-decay', type=float, default=1.0,\n                        help='learning rate k-decay for cosine/poly (default: 1.0)')\n    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                        help='warmup learning rate (default: 1e-6)')\n    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\n    parser.add_argument('--epochs', type=int, default=300, metavar='N',\n                        help='number of epochs to train (default: 300)')\n    parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',\n                        help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')\n    parser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                        help='manual epoch number (useful on restarts)')\n    parser.add_argument('--decay-epochs', type=float, default=100, metavar='N',\n                        help='epoch interval to decay LR')\n    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                        help='epochs to warmup LR, if scheduler supports')\n    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\n    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                        help='patience epochs for Plateau LR scheduler (default: 10')\n    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                        help='LR decay rate (default: 0.1)')\n\n    # Augmentation & regularization parameters\n    parser.add_argument('--no-aug', action='store_true', default=False,\n                        help='Disable all 
training augmentation, override other train aug args')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.08 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--hflip', type=float, default=0.5,\n                        help='Horizontal flip training aug probability')\n    parser.add_argument('--vflip', type=float, default=0.,\n                        help='Vertical flip training aug probability')\n    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                        help='Color jitter factor (default: 0.4)')\n    parser.add_argument('--aa', type=str, default=\"rand-m9-mstd0.5-inc1\", metavar='NAME',\n                        help='Use AutoAugment policy. \"v0\" or \"original\". (default: rand-m9-mstd0.5-inc1)')\n    parser.add_argument('--aug-repeats', type=int, default=0,\n                        help='Number of augmentation repetitions (distributed training only) (default: 0)')\n    parser.add_argument('--aug-splits', type=int, default=0,\n                        help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\n    parser.add_argument('--jsd-loss', action='store_true', default=False,\n                        help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')\n    parser.add_argument('--bce-loss', action='store_true', default=False,\n                        help='Enable BCE loss w/ Mixup/CutMix use.')\n    parser.add_argument('--bce-target-thresh', type=float, default=None,\n                        help='Threshold for binarizing softened BCE targets (default: None, disabled)')\n    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                        help='Random erase prob (default: 0.25)')\n    parser.add_argument('--remode', type=str, default='pixel',\n                        help='Random erase mode (default: \"pixel\")')\n    parser.add_argument('--recount', type=int, default=1,\n                        help='Random erase count (default: 1)')\n    parser.add_argument('--resplit', action='store_true', default=False,\n                        help='Do not random erase first (clean) augmentation split')\n    parser.add_argument('--mixup', type=float, default=0.8,\n                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')\n    parser.add_argument('--cutmix', type=float, default=1.0,\n                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')\n    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\n    parser.add_argument('--mixup-prob', type=float, default=1.0,\n                        help='Probability of performing mixup or cutmix when either/both is enabled')\n    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                        help='Probability of switching to cutmix when both mixup and cutmix enabled')\n    parser.add_argument('--mixup-mode', type=str, default='batch',\n                        help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\n    parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                        help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\n    parser.add_argument('--smoothing', type=float, default=0.1,\n                        help='Label smoothing (default: 0.1)')\n    parser.add_argument('--train-interpolation', type=str, default='random',\n                        help='Training interpolation (random, bilinear, bicubic default: \"random\")')\n    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                        help='Dropout rate (default: 0.)')\n    parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                        help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\n    parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',\n                        help='Drop path rate (default: None)')\n    parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                        help='Drop block rate (default: None)')\n\n    # Batch norm parameters (only works with gen_efficientnet based models currently)\n    parser.add_argument('--bn-tf', action='store_true', default=False,\n                        help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')\n    parser.add_argument('--bn-momentum', type=float, default=None,\n                        help='BatchNorm momentum override (if not None)')\n    parser.add_argument('--bn-eps', type=float, default=None,\n                        help='BatchNorm epsilon override (if not None)')\n    parser.add_argument('--no-sync-bn', action='store_true', default=False,\n                        help='Disable NVIDIA Apex or Torch synchronized BatchNorm.')\n    parser.add_argument('--dist-bn', type=str, default='reduce',\n                        help='Distribute BatchNorm stats between nodes after each epoch (\"broadcast\", \"reduce\", or \"\")')\n    parser.add_argument('--split-bn', action='store_true',\n                        help='Enable separate BN layers per augmentation split.')\n\n    # Model Exponential Moving Average\n    parser.add_argument('--model-ema', action='store_true', default=False,\n                        help='Enable tracking moving average of model weights')\n    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                        help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.')\n    parser.add_argument('--model-ema-decay', type=float, default=0.99984,\n                        help='decay factor for model weights moving average (default: 0.99984)')\n\n    # Misc\n    parser.add_argument('--seed', type=int, default=42, metavar='S',\n                        help='random seed (default: 42)')\n    parser.add_argument('--worker-seeding', type=str, default='all',\n                        help='worker seed mode (default: all)')\n    parser.add_argument('--log-interval', type=int, default=100, metavar='N',\n                        help='how many batches to wait before logging training status')\n    parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                        help='how many batches to wait before writing recovery checkpoint')\n    parser.add_argument('--checkpoint-hist', type=int, default=5, metavar='N',\n                        help='number of checkpoints to keep')\n    parser.add_argument('--checkpoint-freq', type=int, default=50, metavar='N',\n                        help='freq of saving checkpoints')\n    parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                        help='how many training processes to use (default: 8)')\n    parser.add_argument('--save-images', action='store_true', default=False,\n                        help='save images of input batches every log interval for debugging')\n    parser.add_argument('--amp', action='store_true', default=False,\n                        help='use NVIDIA Apex AMP or Native AMP for mixed precision training')\n    parser.add_argument('--apex-amp', action='store_true', default=False,\n                        help='Use NVIDIA Apex AMP mixed precision')\n    parser.add_argument('--native-amp', action='store_true', default=False,\n                        help='Use Native Torch AMP mixed precision')\n    parser.add_argument('--compile', action='store_true', default=False,\n                        help='Use torch.compile to accelerate training')\n    parser.add_argument('--debug-loss', action='store_true', default=False,\n                        help='Use Anomaly Detection')\n    parser.add_argument('--no-ddp-bb', action='store_true', default=False,\n                        help='Force broadcast buffers for native DDP to off.')\n    parser.add_argument('--channels-last', action='store_true', default=False,\n                        help='Use channels_last memory layout')\n    parser.add_argument('--no-pin-mem', action='store_true', default=False,\n                        help='Disable Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\n    parser.add_argument('--no-prefetcher', action='store_true', default=False,\n                        help='disable fast prefetcher')\n    parser.add_argument('--output', default='', type=str, metavar='PATH',\n                        help='path to output folder (default: none, current dir)')\n    parser.add_argument('--experiment', default='', type=str, metavar='NAME',\n                        help='name of train experiment, name of sub-folder for output')\n    parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                        help='Best metric (default: \"top1\")')\n    parser.add_argument('--tta', type=int, default=0, metavar='N',\n                        help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)')\n    # parser.add_argument(\"--local_rank\", default=0, type=int)\n    parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                        help='use the multi-epochs-loader to save time at the beginning of every epoch')\n    parser.add_argument('--torchscript', dest='torchscript', action='store_true',\n                        help='convert model torchscript for inference')\n    parser.add_argument('--log-wandb', action='store_true', default=False,\n                        help='log training and validation metrics to wandb')\n    parser.add_argument('--gpu-limit', default=1, type=float)\n    parser.add_argument('--local-rank', default=0, type=int)\n    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')\n    \n    return parser\n\n\ndef main(args):\n\n    # Cache the args as a text string to save them in the output dir later\n    args.device_name = torch.cuda.get_device_name()\n    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)\n    \n    setup_default_logging()\n    # args, args_text = _parse_args()\n    if args.local_rank == 0:\n        _logger.info(args_text)\n    if args.log_wandb:\n        if has_wandb:\n            wandb.init(project=args.experiment, config=args)\n        else:\n            _logger.warning(\"You've requested to log metrics to wandb but package not found. \"\n                            \"Metrics not being logged to wandb, try `pip install wandb`\")\n\n    args.prefetcher = not args.no_prefetcher\n    args.distributed = False\n    # args.device = 'cuda:0'\n    # args.world_size = 1\n    # args.rank = 0  # global rank\n    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:\n        args.distributed = True\n        args.rank = int(os.environ[\"RANK\"])\n        args.world_size = int(os.environ['WORLD_SIZE'])\n        args.gpu = int(os.environ['LOCAL_RANK'])\n        # args.device = 'cuda:%d' % args.local_rank\n        \n    elif 'SLURM_PROCID' in os.environ and 'WORLD_SIZE' in os.environ:\n        args.distributed = True\n        args.rank = int(os.environ['SLURM_PROCID'])\n        # WORLD_SIZE is guaranteed by the elif condition above; without this line\n        # init_process_group below would fail with an unset args.world_size\n        args.world_size = int(os.environ['WORLD_SIZE'])\n        args.gpu = args.rank % torch.cuda.device_count()\n        \n    if args.distributed:\n        torch.cuda.set_device(args.gpu)\n        torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url,\n                                             world_size=args.world_size, rank=args.rank)\n        # args.world_size = torch.distributed.get_world_size()\n        # args.rank = torch.distributed.get_rank()\n        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. 
Process %d, total %d.'\n                     % (args.rank, args.world_size))\n    else:\n        args.device = 'cuda'\n        args.world_size = 1\n        args.rank = 0  # global rank\n        _logger.info('Training with a single process on 1 GPU.')\n    assert args.rank >= 0\n\n    if args.gpu_limit < 1:\n        torch.cuda.set_per_process_memory_fraction(args.gpu_limit)\n    \n    # resolve AMP arguments based on PyTorch / Apex availability\n    use_amp = None\n    if args.amp:\n        # `--amp` chooses native amp before apex (APEX ver not actively maintained)\n        if has_native_amp:\n            args.native_amp = True\n        elif has_apex:\n            args.apex_amp = True\n    if args.apex_amp and has_apex:\n        use_amp = 'apex'\n    elif args.native_amp and has_native_amp:\n        use_amp = 'native'\n    elif args.apex_amp or args.native_amp:\n        _logger.warning(\"Neither APEX nor native PyTorch AMP is available, using float32. \"\n                        \"Install NVIDIA Apex or upgrade to PyTorch 1.6+\")\n\n    random_seed(args.seed, args.rank)\n\n    \n    model = create_model(\n        args.model,\n        pretrained=args.pretrained,\n        num_classes=args.num_classes,\n        drop_rate=args.drop,\n        drop_path_rate=args.drop_path,\n        # img_size=args.input_size,\n        use_checkpoint=args.ckpt_stg if args.grad_checkpoint else [0] * 4,\n    )\n    \n    if args.finetune:\n        load_ckpt(model, args.finetune)\n    \n    if args.num_classes is None:\n        assert hasattr(\n            model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'\n        # FIXME handle model default vs config num_classes more elegantly\n        args.num_classes = model.num_classes\n\n    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)\n\n    # setup augmentation batch splits for contrastive loss or split bn\n    num_aug_splits = 0\n    if args.aug_splits > 0:\n        assert args.aug_splits > 1, 'A split of 1 makes no sense'\n        num_aug_splits = args.aug_splits\n\n    # enable split bn (separate bn stats per batch-portion)\n    if args.split_bn:\n        assert num_aug_splits > 1 or args.resplit\n        model = convert_splitbn_model(model, max(num_aug_splits, 2))\n\n    if args.local_rank == 0:\n        _logger.info(model)\n        _logger.info(\n            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')\n        model_str = str(model)\n           \n    # move model to GPU, enable channels last layout if set\n    model.cuda()\n    if args.channels_last:\n        model = model.to(memory_format=torch.channels_last)\n\n    # setup synchronized BatchNorm for distributed training\n    if args.distributed and not args.no_sync_bn:\n        assert not args.split_bn\n        if has_apex and use_amp == 'apex':\n            # Apex SyncBN preferred unless native amp is activated\n            model = convert_syncbn_model(model)\n        else:\n            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n        if args.local_rank == 0:\n            _logger.info(\n                'Converted model to use Synchronized BatchNorm. 
WARNING: You may have issues if using '\n                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')\n\n    if args.torchscript:\n        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'\n        # the parser defines --no-sync-bn, so there is no args.sync_bn attribute;\n        # check the condition that actually enables SyncBN above instead\n        assert not (args.distributed and not args.no_sync_bn), 'Cannot use SyncBatchNorm with torchscripted model'\n        model = torch.jit.script(model)\n\n    if args.auto_lr:\n        # scale LR linearly with the global batch size (base LR 1e-3 at batch 1024)\n        args.lr = (args.batch_size * args.world_size / 1024) * 1e-3\n        # re-dump the cached args text so the saved args.yaml reflects the scaled LR\n        args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)\n    if args.local_rank == 0:\n        _logger.info(f'Initial learning rate: {args.lr}.')\n    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))\n\n    # setup automatic mixed-precision (AMP) loss scaling and op casting\n    amp_autocast = suppress  # do nothing\n    loss_scaler = None\n    if use_amp == 'apex':\n        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')\n        loss_scaler = ApexScaler()\n        if args.local_rank == 0:\n            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')\n    elif use_amp == 'native':\n        amp_autocast = torch.cuda.amp.autocast\n        loss_scaler = NativeScaler()\n        if args.local_rank == 0:\n            _logger.info(\n                'Using native PyTorch AMP. Training in mixed precision.')\n    else:\n        if args.local_rank == 0:\n            _logger.info('AMP not enabled. Training in float32.')\n\n    if args.grad_checkpoint and args.local_rank == 0:\n        _logger.info(\n            f'Using gradient checkpointing, checkpoint stage: {args.ckpt_stg}.')\n\n    # optionally resume from a checkpoint\n    resume_epoch = None\n    if args.resume:\n        resume_epoch = resume_checkpoint(\n            model, args.resume,\n            optimizer=None if args.no_resume_opt else optimizer,\n            loss_scaler=None if args.no_resume_opt else loss_scaler,\n            log_info=args.local_rank == 0)\n\n    # setup exponential moving average of model weights, SWA could be used here too\n    model_ema = None\n    if args.model_ema:\n        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper\n        # model_ema_decay = args.model_ema_decay ** (args.batch_size * args.world_size / 512.0)\n        model_ema = ModelEmaV2(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)\n        if args.resume:\n            load_checkpoint(model_ema.module, args.resume, use_ema=True)\n\n    # setup distributed training\n    if args.distributed:\n        if has_apex and use_amp == 'apex':\n            # Apex DDP preferred unless native amp is activated\n            model = ApexDDP(model, delay_allreduce=True)\n            if args.local_rank == 0:\n                _logger.info(\"Using NVIDIA APEX DistributedDataParallel.\")\n        else:\n            model = NativeDDP(model, device_ids=[args.gpu], broadcast_buffers=not args.no_ddp_bb)\n            if args.local_rank == 0:\n                _logger.info(\"Using native PyTorch DistributedDataParallel.\")\n        # NOTE: EMA model does not need to be wrapped by DDP\n\n    # setup learning rate schedule and starting epoch\n    lr_scheduler, num_epochs = create_scheduler(args, optimizer)\n    start_epoch = 0\n    if args.start_epoch is not None:\n        # a specified start_epoch will always override the resume 
epoch\n        start_epoch = args.start_epoch\n    elif resume_epoch is not None:\n        start_epoch = resume_epoch\n    if lr_scheduler is not None and start_epoch > 0:\n        lr_scheduler.step(start_epoch)\n\n\n    _logger.info(\"Creating Dataset ...\")\n    # create the train and eval datasets\n    dataset_train = create_dataset(\n        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,\n        class_map=args.class_map,\n        download=args.dataset_download,\n        batch_size=args.batch_size,\n        repeats=args.epoch_repeats)\n    dataset_eval = create_dataset(\n        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,\n        class_map=args.class_map,\n        download=args.dataset_download,\n        batch_size=args.batch_size)\n\n    # setup mixup / cutmix\n    collate_fn = None\n    mixup_fn = None\n    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None\n    if mixup_active:\n        mixup_args = dict(\n            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,\n            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,\n            label_smoothing=args.smoothing, num_classes=args.num_classes)\n        if args.prefetcher:\n            # collate conflict (need to support deinterleaving in collate mixup)\n            assert not num_aug_splits\n            collate_fn = FastCollateMixup(**mixup_args)\n        else:\n            mixup_fn = Mixup(**mixup_args)\n\n    # wrap dataset in AugMix helper\n    if num_aug_splits > 1:\n        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)\n\n    # create data loaders w/ augmentation pipeline\n    train_interpolation = args.train_interpolation\n    if args.no_aug or not train_interpolation:\n        train_interpolation = data_config['interpolation']\n        \n\n    _logger.info(\"Creating Dataloader ...\")\n    loader_train = create_loader(\n        dataset_train,\n        input_size=data_config['input_size'],\n        batch_size=args.batch_size,\n        is_training=True,\n        use_prefetcher=args.prefetcher,\n        no_aug=args.no_aug,\n        re_prob=args.reprob,\n        re_mode=args.remode,\n        re_count=args.recount,\n        re_split=args.resplit,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n        vflip=args.vflip,\n        color_jitter=args.color_jitter,\n        auto_augment=args.aa,\n        num_aug_repeats=args.aug_repeats,\n        num_aug_splits=num_aug_splits,\n        interpolation=train_interpolation,\n        mean=data_config['mean'],\n        std=data_config['std'],\n        num_workers=args.workers,\n        distributed=args.distributed,\n        collate_fn=collate_fn,\n        pin_memory=not args.no_pin_mem,\n        use_multi_epochs_loader=args.use_multi_epochs_loader,\n        worker_seeding=args.worker_seeding,\n    )\n\n    loader_eval = create_loader(\n        dataset_eval,\n        input_size=data_config['input_size'],\n        batch_size=args.validation_batch_size or args.batch_size,\n        is_training=False,\n        use_prefetcher=args.prefetcher,\n        interpolation=data_config['interpolation'],\n        mean=data_config['mean'],\n        std=data_config['std'],\n        num_workers=args.workers,\n        distributed=args.distributed,\n        crop_pct=data_config['crop_pct'],\n        pin_memory=not args.no_pin_mem,\n    )\n\n    \n    # setup loss function\n    if 
args.jsd_loss:\n        assert num_aug_splits > 1  # JSD only valid with aug splits set\n        train_loss_fn = JsdCrossEntropy(\n            num_splits=num_aug_splits, smoothing=args.smoothing)\n    elif mixup_active:\n        # smoothing is handled with mixup target transform which outputs sparse, soft targets\n        if args.bce_loss:\n            train_loss_fn = BinaryCrossEntropy(\n                target_threshold=args.bce_target_thresh)\n        else:\n            train_loss_fn = SoftTargetCrossEntropy()\n    elif args.smoothing:\n        if args.bce_loss:\n            train_loss_fn = BinaryCrossEntropy(\n                smoothing=args.smoothing, target_threshold=args.bce_target_thresh)\n        else:\n            # train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)\n            train_loss_fn = nn.CrossEntropyLoss(label_smoothing=args.smoothing)\n    else:\n        train_loss_fn = nn.CrossEntropyLoss()\n    train_loss_fn = train_loss_fn.cuda()\n    validate_loss_fn = nn.CrossEntropyLoss().cuda()\n\n    # setup checkpoint saver and eval metric tracking\n    eval_metric = args.eval_metric\n    if args.model_ema:\n        ema_save_tag = 'ema'\n        eval_metric_ema = f'{args.eval_metric}_{ema_save_tag}'\n    \n    best_metric = None\n    best_metric_ema = None\n    best_epoch = None\n    best_epoch_ema = None\n    saver = None\n    saver_ema = None\n    output_dir = None\n    resume_mode = args.resume\n    \n    if args.rank == 0:\n        if args.experiment:\n            exp_name = args.experiment\n        else:\n            exp_name = '-'.join([\n                datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"),\n                safe_model_name(args.model),\n                str(data_config['input_size'][-1])\n            ])\n        output_dir = get_outdir(\n            args.output if args.output else './output/train', exp_name)\n        decreasing = eval_metric == 'loss'\n        \n        checkpoint_dir = f'{output_dir}/checkpoints/'\n        os.makedirs(checkpoint_dir, exist_ok=True)\n        saver = CheckpointSaver(model=model, optimizer=optimizer, \n                                args=args, model_ema=model_ema, \n                                amp_scaler=loss_scaler, checkpoint_dir=checkpoint_dir, \n                                recovery_dir=output_dir, decreasing=decreasing, \n                                max_history=args.checkpoint_hist)\n        \n        \n        if model_ema is not None:\n            checkpoint_dir = f'{output_dir}/checkpoints_ema/'\n            os.makedirs(checkpoint_dir, exist_ok=True)\n            saver_ema = CheckpointSaver(model=model, optimizer=optimizer, \n                                        args=args, model_ema=model_ema,\n                                        checkpoint_prefix='checkpoint-ema',\n                                        amp_scaler=loss_scaler, checkpoint_dir=checkpoint_dir, \n                                        recovery_dir=output_dir, decreasing=decreasing, \n                                        max_history=args.checkpoint_hist)\n\n        \n        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:\n            f.write(args_text)\n        with open(os.path.join(output_dir, 'model.log'), 'w') as f:\n            f.write(model_str)\n\n    # torch.compile must be applied on every rank, not only rank 0\n    if args.compile:\n        if args.local_rank == 0:\n            _logger.info(\"Compiling model ...\")\n        model = torch.compile(model)\n\n    if args.rank == 0:\n        _logger.info(\"Start Training.\")\n        _logger.info('Scheduled epochs: {}'.format(num_epochs))\n 
           \n    # pre-fill eval_metrics so an epoch that skips validation (see --val-freq\n    # and --val-start-epoch) does not hit a NameError before the first real eval\n    eval_metrics = {eval_metric: float('-inf')}\n    if args.model_ema:\n        eval_metrics[eval_metric_ema] = float('-inf')\n    try:\n        for epoch in range(start_epoch, num_epochs):\n            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):\n                loader_train.sampler.set_epoch(epoch)\n\n            train_metrics = train_one_epoch(\n                epoch, model, loader_train, \n                optimizer, train_loss_fn, args,\n                lr_scheduler=lr_scheduler, \n                saver=saver, output_dir=output_dir,\n                amp_autocast=amp_autocast, \n                loss_scaler=loss_scaler, \n                model_ema=model_ema, \n                mixup_fn=mixup_fn,\n                eta_meter=AverageMeter())\n\n            if (epoch % args.val_freq == 0) or epoch > args.val_start_epoch or resume_mode:\n                resume_mode = False\n                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                    if args.local_rank == 0:\n                        _logger.info(\"Distributing BatchNorm running means and vars\")\n                    distribute_bn(model, args.world_size,\n                                  args.dist_bn == 'reduce')\n\n                eval_metrics = validate(model, loader_eval, validate_loss_fn, \n                                        args, amp_autocast=amp_autocast)\n\n                if model_ema is not None and not args.model_ema_force_cpu:\n                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                        distribute_bn(model_ema, args.world_size,\n                                      args.dist_bn == 'reduce')\n                    ema_eval_metrics = validate(model_ema.module, loader_eval,\n                                                validate_loss_fn, args,\n                                                amp_autocast=amp_autocast,\n                                                log_suffix=f'_{ema_save_tag}')\n                    eval_metrics.update(ema_eval_metrics)\n            else:\n                eval_metrics = {key: float('-inf') for key in eval_metrics}\n\n            if lr_scheduler is not None:\n                # step LR for next epoch\n                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])\n\n            if output_dir is not None:\n                update_summary(epoch,\n                               train_metrics,\n                               eval_metrics,\n                               os.path.join(output_dir, 'summary.csv'),\n                               write_header=best_metric is None,\n                               log_wandb=args.log_wandb and has_wandb)\n\n            if saver is not None:\n                # save proper checkpoint with eval metric\n                save_metric = eval_metrics[eval_metric]\n                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)\n\n            if saver_ema is not None:\n                # save proper checkpoint with eval metric (ema)\n                save_metric = eval_metrics[eval_metric_ema]\n                best_metric_ema, best_epoch_ema = saver_ema.save_checkpoint(epoch, metric=save_metric)\n\n            \n            if best_metric is not None:\n                \n                if args.local_rank == 0:\n                    _logger.info(f\"Currently Best Accuracy: {best_metric:.2f} at Epoch {best_epoch}\")\n                    if best_metric_ema is not None:\n                        _logger.info(f\"Currently Best Accuracy (EMA): {best_metric_ema:.2f} at Epoch {best_epoch_ema}\")\n                    _logger.info('\\n')\n   
                 \n                with open(os.path.join(output_dir, 'best-metric.json'), 'w') as f:\n                    \n                    best_metric_info = {\n                        eval_metric: best_metric,\n                        'epoch': best_epoch,\n                    }\n                    \n                    if best_metric_ema is not None:\n                        best_metric_info[eval_metric_ema] = best_metric_ema\n                        best_metric_info['epoch_ema'] = best_epoch_ema\n                        \n                    json.dump(best_metric_info, f, indent=4)\n\n    except KeyboardInterrupt:\n        pass\n    \n    if best_metric is not None:\n        # result_info = f'*** Best metric ({eval_metric}): {best_metric} (epoch {best_epoch}) ***, \n        #                 *** Best metric (ema) ({eval_metric_ema}): {best_metric_ema} (epoch {best_epoch_ema}) ***'\n        \n        best_metric_info = {\n            eval_metric: best_metric,\n            'epoch': best_epoch,\n        }\n        \n        if best_metric_ema is not None:\n            best_metric_info[eval_metric_ema] = best_metric_ema\n            best_metric_info['epoch_ema'] = best_epoch_ema        \n        \n        # _logger.info(f'best_metric: {best_metric_info}')\n        print(f'best_metric: {best_metric_info}')\n        \n\ndef train_one_epoch(epoch, model, loader, optimizer, loss_fn, args,\n                    lr_scheduler=None, saver=None, output_dir=None,\n                    amp_autocast=suppress, loss_scaler=None,\n                    model_ema=None, mixup_fn=None, eta_meter=None):\n    # NOTE: eta_meter defaults to None rather than AverageMeter() so a single\n    # meter instance is not silently shared across calls (mutable default)\n\n    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:\n        if args.prefetcher and loader.mixup_enabled:\n            loader.mixup_enabled = False\n        elif mixup_fn is not None:\n            mixup_fn.mixup_enabled = False\n\n    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    losses_aux_m = AverageMeter()\n    is_ds = False\n\n    model.train()\n    # torch.cuda.empty_cache()\n\n    end = time.time()\n    end_epoch = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n    \n    for batch_idx, (input, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        data_time_m.update(time.time() - end)\n        if not args.prefetcher:\n            input = input.cuda()\n            target = target.cuda()\n            if mixup_fn is not None:\n                input, target = mixup_fn(input, target)\n        if args.channels_last:\n            input = input.contiguous(memory_format=torch.channels_last)\n\n        if args.debug_loss:\n            detect_anomaly = torch.autograd.detect_anomaly\n        else:\n            detect_anomaly = suppress\n        \n        with amp_autocast():\n            with detect_anomaly():\n               output = model(input)\n               if isinstance(output, (tuple, list, dict)):\n                   # deep-supervision output: a main head plus an auxiliary head\n                   is_ds = True\n                   if isinstance(output, dict):\n                       output_main, output_aux = output['main'], output['aux']\n                   else:\n                       # positional variant of the same two-head output\n                       output_main, output_aux = output[0], output[1]\n                   loss_main = loss_fn(output_main, target)\n                   loss_aux = loss_fn(output_aux, target)\n                   loss = loss_main + args.aux_loss_ratio * loss_aux\n               else:\n                   loss = loss_fn(output, target)\n\n        if not args.distributed:\n            \n            if is_ds:\n                
losses_m.update(loss_main.item(), input.size(0))\n                losses_aux_m.update(loss_aux.item(), input.size(0))\n            else:\n                losses_m.update(loss.item(), input.size(0))\n            \n        optimizer.zero_grad()\n        \n        if loss_scaler is not None:\n            loss_scaler(\n                loss, optimizer,\n                clip_grad=args.clip_grad, clip_mode=args.clip_mode,\n                parameters=model_parameters(\n                    model, exclude_head='agc' in args.clip_mode),\n                create_graph=second_order)\n        else:\n            loss.backward(create_graph=second_order)\n            if args.clip_grad is not None:\n                dispatch_clip_grad(\n                    model_parameters(\n                        model, exclude_head='agc' in args.clip_mode),\n                    value=args.clip_grad, mode=args.clip_mode)\n            optimizer.step()\n        \n        if model_ema is not None:\n            model_ema.update(model)\n\n        torch.cuda.synchronize()\n        num_updates += 1\n        batch_time_m.update(time.time() - end)\n        \n        if eta_meter is not None:\n            eta_meter.update(time.time() - end)\n            \n        if last_batch or batch_idx % args.log_interval == 0:\n            \n            if args.distributed:\n                \n                if is_ds:\n                    reduced_loss = reduce_tensor(loss_main.data, args.world_size)\n                    reduced_loss_aux = reduce_tensor(loss_aux.data, args.world_size)\n                    losses_m.update(reduced_loss.item(), input.size(0))\n                    losses_aux_m.update(reduced_loss_aux.item(), input.size(0))\n                else:\n                    reduced_loss = reduce_tensor(loss.data, args.world_size)\n                    losses_m.update(reduced_loss.item(), input.size(0))\n                \n                \n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n            \n            eta_remaining_batches = (args.epochs + args.cooldown_epochs - epoch - 1) * len(loader) + len(loader) - batch_idx - 1\n            batch_time_avg = eta_meter.avg if eta_meter is not None else batch_time_m.avg\n            eta_remaining_time = eta_remaining_batches * batch_time_avg\n            eta_td = datetime.timedelta(seconds=eta_remaining_time)\n            eta_str = str(eta_td)\n            current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))\n            gpu_memory = torch.cuda.max_memory_allocated()\n            \n            if args.local_rank == 0:\n                \n                base_info = (\n                    '[{}]  Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                    'LR: {lr:.3e}  '\n                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g})  '\n                )\n\n                aux_info = (\n                    'Loss (Aux): {loss_aux.val:#.4g} ({loss_aux.avg:#.3g})  '\n                )\n\n                other_info = (\n                    'ETA: {eta}  '\n                    'Memory: {gpu_memory:d}  '\n                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})  '\n                )\n\n                if is_ds:\n                    base_info += aux_info\n\n                base_info += other_info\n\n                _logger.info(\n                    
base_info.format(\n                        current_time,\n                        epoch,\n                        batch_idx, len(loader),\n                        100. * batch_idx / last_idx,\n                        loss=losses_m,\n                        loss_aux=losses_aux_m if is_ds else None,\n                        batch_time=batch_time_m,\n                        rate=input.size(0) * args.world_size / batch_time_m.val,\n                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,\n                        lr=lr,\n                        data_time=data_time_m,\n                        eta=eta_str.split('.')[0],\n                        gpu_memory=int(gpu_memory/(1024**2))\n                    )\n                )\n                \n                if args.save_images and output_dir:\n                    torchvision.utils.save_image(\n                        input,\n                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),\n                        padding=0,\n                        normalize=True)\n\n        if saver is not None and args.recovery_interval and (\n                last_batch or (batch_idx + 1) % args.recovery_interval == 0):\n            saver.save_recovery(epoch, batch_idx=batch_idx)\n\n        if lr_scheduler is not None:\n            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n        \n        end = time.time()\n        # end for\n\n    epoch_time = time.time() - end_epoch\n    \n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n    \n    train_info_dict = {\n        'epoch_time(min)': round(epoch_time / 60, 1),\n        'lr': lr,\n        'loss': losses_m.avg,\n    }\n    \n    if is_ds:\n        train_info_dict['loss_aux'] = losses_aux_m.avg\n    \n    return OrderedDict(train_info_dict)\n\n\ndef validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.eval()\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    with torch.no_grad():\n        for batch_idx, (input, target) in enumerate(loader):\n            last_batch = batch_idx == last_idx\n            if not args.prefetcher:\n                input = input.cuda()\n                target = target.cuda()\n            if args.channels_last:\n                input = input.contiguous(memory_format=torch.channels_last)\n\n            with amp_autocast():\n                output = model(input)\n                if isinstance(output, (tuple, list)):\n                    output = output[0]\n                loss = loss_fn(output, target)\n\n            # augmentation reduction\n            reduce_factor = args.tta\n            if reduce_factor > 1:\n                output = output.unfold(\n                    0, reduce_factor, reduce_factor).mean(dim=2)\n                target = target[0:target.size(0):reduce_factor]\n\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = reduce_tensor(acc5, args.world_size)\n            else:\n                reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), input.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            
top5_m.update(acc5.item(), output.size(0))\n\n            batch_time_m.update(time.time() - end)\n            end = time.time()\n            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):\n                log_name = 'Test' + log_suffix\n                _logger.info(\n                    '{0}: [{1:>4d}/{2}]  '\n                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(\n                        log_name, batch_idx, last_idx, batch_time=batch_time_m,\n                        loss=losses_m, top1=top1_m, top5=top5_m))\n\n    metrics = OrderedDict([(f'loss{log_suffix}', losses_m.avg),\n                           (f'top1{log_suffix}', top1_m.avg),\n                           (f'top5{log_suffix}', top5_m.avg)])\n\n    return metrics\n\n\nif __name__ == '__main__':\n    torch.cuda.empty_cache()\n    args = get_args_parser().parse_args()\n    main(args)\n"
  },
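  {
    "path": "examples/auto_lr_sketch.py",
    "content": "\"\"\"Illustrative sketch only -- not part of the original code above.\n\nRestates the --auto-lr rule from train.py: the learning rate is scaled\nlinearly with the global batch size, with a base LR of 1e-3 at a global\nbatch of 1024. The file path and helper name are assumptions added for\ndocumentation purposes.\n\"\"\"\n\n\ndef auto_scale_lr(batch_size, world_size, base_lr=1e-3, base_batch=1024):\n    # global batch = per-process batch size * number of processes\n    return (batch_size * world_size / base_batch) * base_lr\n\n\nif __name__ == '__main__':\n    print(auto_scale_lr(128, 8))  # 0.001  (global batch 1024 -> base LR)\n    print(auto_scale_lr(128, 4))  # 0.0005 (global batch 512  -> LR halved)\n"
  },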
  {
    "path": "validate.py",
    "content": "#!/usr/bin/env python3\n\"\"\" ImageNet Validation Script\n\nThis is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained\nmodels or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes\ncanonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.\n\nHacked together by Ross Wightman (https://github.com/rwightman)\n\"\"\"\nimport os\nimport csv\nimport glob\nimport time\nimport torch\nimport logging\nimport argparse\nfrom tqdm import tqdm\nfrom torch import nn\nfrom contextlib import suppress\nfrom collections import OrderedDict\n\nfrom timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models\nfrom timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet\nfrom timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy\n\nimport models\n\nhas_apex = False\ntry:\n    from apex import amp\n    has_apex = True\nexcept ImportError:\n    pass\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('validate')\n\n\nparser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')\nparser.add_argument('data', metavar='DIR',\n                    help='path to dataset')\nparser.add_argument('--dataset', '-d', metavar='NAME', default='',\n                    help='dataset type (default: ImageFolder/ImageTar if empty)')\nparser.add_argument('--split', metavar='NAME', default='validation',\n                    help='dataset split (default: validation)')\nparser.add_argument('--dataset-download', action='store_true', default=False,\n                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')\nparser.add_argument('--model', '-m', metavar='NAME', default='dpn92',\n                    help='model architecture (default: dpn92)')\nparser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                    help='number of data loading workers (default: 4)')\nparser.add_argument('-b', '--batch-size', default=256, type=int,\n                    metavar='N', help='mini-batch size (default: 256)')\nparser.add_argument('--img-size', default=None, type=int,\n                    metavar='N', help='Input image dimension, uses model default if empty')\nparser.add_argument('--input-size', default=None, nargs=3, type=int,\n                    metavar='N N N', help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='Input image center crop pct')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float,  nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\nparser.add_argument('--num-classes', type=int, default=None,\n                    help='Number classes in dataset')\nparser.add_argument('--class-map', default='', type=str, metavar='FILENAME',\n                    help='path to class to idx mapping file (default: \"\")')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')\nparser.add_argument('--log-freq', default=25, type=int,\n                    metavar='N', help='batch logging frequency (default: 25)')\nparser.add_argument('--checkpoint', default='', type=str, metavar='PATH',\n                    help='path to latest checkpoint (default: none)')\nparser.add_argument('--pretrained', dest='pretrained', action='store_true',\n                    help='use pre-trained model')\nparser.add_argument('--num-gpu', type=int, default=1,\n                    help='Number of GPUs to use')\nparser.add_argument('--test-pool', dest='test_pool', action='store_true',\n                    help='enable test time pool')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\nparser.add_argument('--pin-mem', action='store_true', default=True,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--amp', action='store_true', default=False,\n                    help='Use AMP mixed precision. 
parser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--tf-preprocessing', action='store_true', default=False,\n                    help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')\nparser.add_argument('--use-ema', dest='use_ema', action='store_true',\n                    help='use ema version of weights if present')\nparser.add_argument('--torchscript', dest='torchscript', action='store_true',\n                    help='convert model to torchscript for inference')\nparser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',\n                    help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')\nparser.add_argument('--results-file', default=None, type=str, metavar='FILENAME',\n                    help='Output csv file for validation results (summary)')\nparser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',\n                    help='Real labels JSON file for imagenet evaluation')\nparser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',\n                    help='Valid label indices txt file for validation of partial label space')\n
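\n# Example invocations (dataset path, script name, and model choices are illustrative):\n#   python validate.py /path/to/imagenet --model resnet50 --pretrained --amp\n#   python validate.py /path/to/imagenet --model all --results-file results-all.csv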
\n\n\ndef validate(args):\n    # might as well try to validate something\n    args.pretrained = args.pretrained or not args.checkpoint\n    args.prefetcher = not args.no_prefetcher\n    amp_autocast = suppress  # do nothing\n    if args.amp:\n        if has_native_amp:\n            args.native_amp = True\n        elif has_apex:\n            args.apex_amp = True\n        else:\n            _logger.warning(\"Neither APEX nor native Torch AMP is available.\")\n    assert not args.apex_amp or not args.native_amp, \"Only one AMP mode should be set.\"\n    if args.native_amp:\n        amp_autocast = torch.cuda.amp.autocast\n        _logger.info('Validating in mixed precision with native PyTorch AMP.')\n    elif args.apex_amp:\n        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')\n    else:\n        _logger.info('Validating in float32. AMP not enabled.')\n\n    if args.legacy_jit:\n        set_jit_legacy()\n\n    # create model\n    model = create_model(\n        args.model,\n        pretrained=args.pretrained,\n        num_classes=args.num_classes,\n        in_chans=3,\n        global_pool=args.gp,\n        scriptable=args.torchscript)\n\n    if args.num_classes is None:\n        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'\n        args.num_classes = model.num_classes\n    if args.checkpoint:\n        load_checkpoint(model, args.checkpoint, args.use_ema)\n\n    param_count = sum(m.numel() for m in model.parameters())\n    _logger.info('Model %s created, param count: %d' % (args.model, param_count))\n\n    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)\n    test_time_pool = False\n    if args.test_pool:\n        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)\n\n    if args.torchscript:\n        torch.jit.optimized_execution(True)\n        model = torch.jit.script(model)\n\n    model = model.cuda()\n    if args.apex_amp:\n        model = amp.initialize(model, opt_level='O1')\n\n    if args.channels_last:\n        model = model.to(memory_format=torch.channels_last)\n\n    if args.num_gpu > 1:\n        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))\n\n    criterion = nn.CrossEntropyLoss().cuda()\n\n    dataset = create_dataset(\n        root=args.data, name=args.dataset, split=args.split,\n        download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)\n\n    if args.valid_labels:\n        with open(args.valid_labels, 'r') as f:\n            valid_labels = {int(line.rstrip()) for line in f}\n            valid_labels = [i in valid_labels for i in range(args.num_classes)]\n    else:\n        valid_labels = None\n\n    if args.real_labels:\n        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)\n    else:\n        real_labels = None\n\n    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']\n    loader = create_loader(\n        dataset,\n        input_size=data_config['input_size'],\n        batch_size=args.batch_size,\n        use_prefetcher=args.prefetcher,\n        interpolation=data_config['interpolation'],\n        mean=data_config['mean'],\n        std=data_config['std'],\n        num_workers=args.workers,\n        crop_pct=crop_pct,\n        pin_memory=args.pin_mem,\n        tf_preprocessing=args.tf_preprocessing)\n\n    batch_time = AverageMeter()\n    losses = AverageMeter()\n    top1 = AverageMeter()\n    top5 = AverageMeter()\n\n    model.eval()\n    with torch.no_grad():\n        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non\n        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()\n        if args.channels_last:\n            input = input.contiguous(memory_format=torch.channels_last)\n        model(input)\n        pbar = tqdm(total=len(dataset))\n        end = time.time()\n\n        for input, target in loader:\n            if args.no_prefetcher:\n                target = target.cuda(non_blocking=args.pin_mem)\n                input = input.cuda(non_blocking=args.pin_mem)\n            if args.channels_last:\n                input = input.contiguous(memory_format=torch.channels_last)\n\n            # compute output\n            with amp_autocast():\n
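                # forward pass and loss run under amp_autocast: reduced precision where\n                # supported when AMP is enabled; otherwise amp_autocast is\n                # contextlib.suppress, i.e. a no-op context manager\n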
                output = model(input)\n                if valid_labels is not None:\n                    output = output[:, valid_labels]\n                loss = criterion(output, target)\n            if real_labels is not None:\n                real_labels.add_result(output)\n\n            # measure accuracy and record loss\n            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))\n            losses.update(loss.item(), input.size(0))\n            top1.update(acc1.item(), input.size(0))\n            top5.update(acc5.item(), input.size(0))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n            # advance by the actual batch size so the bar stays accurate on a partial final batch\n            pbar.update(input.size(0))\n\n    pbar.close()\n\n    if real_labels is not None:\n        # real labels mode replaces topk values at the end\n        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)\n    else:\n        top1a, top5a = top1.avg, top5.avg\n\n    results = OrderedDict(top1=round(top1a, 4), top1_err=round(100 - top1a, 4),\n                          top5=round(top5a, 4), top5_err=round(100 - top5a, 4),\n                          param_count=round(param_count / 1e6, 2),\n                          img_size=data_config['input_size'][-1],\n                          crop_pct=crop_pct,\n                          interpolation=data_config['interpolation'])\n\n    _logger.info(\n        dict(model=args.model,\n             loss=losses.avg,\n             top1=top1.avg,\n             top5=top5.avg)\n    )\n    _logger.info(f\"Accuracy@top-1 of the network on the {len(dataset)} test images: {top1a:.1f}%\")\n    return results\n\n\ndef main():\n    setup_default_logging()\n    args = parser.parse_args()\n    model_cfgs = []\n    model_names = []\n    if os.path.isdir(args.checkpoint):\n        # validate all checkpoints in a path with same model\n        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')\n        checkpoints += glob.glob(args.checkpoint + '/*.pth')\n        model_names = list_models(args.model)\n        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]\n    else:\n        if args.model == 'all':\n            # validate all models in a list of names with pretrained checkpoints\n            args.pretrained = True\n            model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k'])\n            model_cfgs = [(n, '') for n in model_names]\n        elif not is_model(args.model):\n            # model name doesn't exist, try as wildcard filter\n            model_names = list_models(args.model)\n            model_cfgs = [(n, '') for n in model_names]\n\n        if not model_cfgs and os.path.isfile(args.model):\n            with open(args.model) as f:\n                model_names = [line.rstrip() for line in f]\n            model_cfgs = [(n, '') for n in model_names if n]\n\n    if model_cfgs:\n        results_file = args.results_file or './results-all.csv'\n        _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))\n        results = []\n        try:\n            start_batch_size = args.batch_size\n            for m, c in model_cfgs:\n                batch_size = start_batch_size\n                args.model = m\n                args.checkpoint = c\n                result = OrderedDict(model=args.model)\n                r = {}\n                while not r and batch_size >= args.num_gpu:\n                    torch.cuda.empty_cache()\n
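                    # on RuntimeError (typically CUDA out-of-memory), halve the batch size\n                    # and retry until it would drop below one sample per GPU\n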
                    try:\n                        args.batch_size = batch_size\n                        print('Validating with batch size: %d' % args.batch_size)\n                        r = validate(args)\n                    except RuntimeError as e:\n                        if batch_size <= args.num_gpu:\n                            print(\"Validation failed with no ability to reduce batch size. Exiting.\")\n                            raise e\n                        batch_size = max(batch_size // 2, args.num_gpu)\n                        print(\"Validation failed, reducing batch size by 50%\")\n                result.update(r)\n                if args.checkpoint:\n                    result['checkpoint'] = args.checkpoint\n                results.append(result)\n        except KeyboardInterrupt:\n            pass\n        results = sorted(results, key=lambda x: x['top1'], reverse=True)\n        if results:\n            write_results(results_file, results)\n    else:\n        validate(args)\n\n\ndef write_results(results_file, results):\n    with open(results_file, mode='w', newline='') as cf:\n        dw = csv.DictWriter(cf, fieldnames=results[0].keys())\n        dw.writeheader()\n        for r in results:\n            dw.writerow(r)\n        cf.flush()\n\n\nif __name__ == '__main__':\n    main()"
  }
]