[
  {
    "path": "README.md",
    "content": "# SCTransNet: Spatial-channel Cross Transformer Network for Infrared Small Target Detection [[Paper]](https://ieeexplore.ieee.org/document/10486932) [[Weight]](https://drive.google.com/file/d/1Kxs2wKG2uq2YiGJOBGWoVz7B1-8DJoz3/view?usp=sharing) \n\nShuai Yuan, Hanlin Qin, Xiang Yan, Naveed Akhtar, Aimal Main, IEEE Transactions on Geoscience and Remote Sensing 2024.\n\n# SCTransNet 是PRCV 2024、ICPR 2024 Track 1、ICPR 2024 Track 2 三项比赛冠军方案的 Baseline, 同时也是多个优胜算法的Baselines. [[Paper]](https://arxiv.org/abs/2408.09615)\n\n# Bilibili 视频分享\nhttps://www.bilibili.com/video/BV1kr421M7wx/\n\n# 极市平台 推文分享\nhttps://mp.weixin.qq.com/s/H7KLmtFX7j09f-Xc6X1FRw\n\n# If the implementation of this repo is helpful to you, just star it！⭐⭐⭐\n\n# Challenges and inspiration   \n![Image text](https://github.com/xdFai/SCTransNet/blob/main/Fig/picture01.png)\n\n# Structure\n![Image text](https://github.com/xdFai/SCTransNet/blob/main/Fig/picture2.png)\n\n![Image text](https://github.com/xdFai/SCTransNet/blob/main/Fig/picture03.png)\n\n\n# Introduction\n\nWe present a Spatial-channel Cross Transformer Network (SCTransNet) to the IRSTD task. Experiments on both public (e.g., SIRST, NUDT-SIRST, IRSTD-1K) demonstrate the effectiveness of our method. Our main contributions are as follows:\n\n1. We propose SCTransNet, leveraging spatial-channel cross transformer blocks (SCTB)  to predict the context of targets and backgrounds in the deeper network layers.\n\n2. A spatial-embedded single-head channel-cross attention (SSCA) module is utilized to foster semantic interactions across all feature levels and learn the long-range context.\n\n3. We devise a novel complementary feed-forward network (CFN) by crossing spatial-channel information to enhance the semantic difference between the target and background.\n\n\n## Usage\n\n#### 1. 
Data\n\nThe **SIRST3** dataset, which combines **IRSTD-1K**, **NUDT-SIRST**, and **SIRST-v1**, is used to train SCTransNet.\n* **SIRST-v1** &nbsp; [[download]](https://github.com/YimianDai/sirst) &nbsp; [[paper]](https://arxiv.org/pdf/2009.14530.pdf)\n* **NUDT-SIRST** &nbsp; [[download]](https://github.com/YeRen123455/Infrared-Small-Target-Detection) &nbsp; [[paper]](https://ieeexplore.ieee.org/abstract/document/9864119)\n* **IRSTD-1K** &nbsp; [[download dir]](https://github.com/RuiZhang97/ISNet) &nbsp; [[paper]](https://ieeexplore.ieee.org/document/9880295)\n\n* Apologies for misnaming the **SIRST-v1** dataset as **NUAA-SIRST** in both the article and code. We will follow the original authors’ naming convention in future work.\n\n* **Our project has the following structure:**\n  ```\n  ├──./datasets/\n  │    ├── IRSTD-1K\n  │    │    ├── images\n  │    │    │    ├── XDU0.png\n  │    │    │    ├── XDU1.png\n  │    │    │    ├── ...\n  │    │    ├── masks\n  │    │    │    ├── XDU0.png\n  │    │    │    ├── XDU1.png\n  │    │    │    ├── ...\n  │    │    ├── img_idx\n  │    │    │    ├── train_IRSTD-1K.txt\n  │    │    │    ├── test_IRSTD-1K.txt\n  │    ├── NUDT-SIRST\n  │    │    ├── images\n  │    │    │    ├── 000001.png\n  │    │    │    ├── 000002.png\n  │    │    │    ├── ...\n  │    │    ├── masks\n  │    │    │    ├── 000001.png\n  │    │    │    ├── 000002.png\n  │    │    │    ├── ...\n  │    │    ├── img_idx\n  │    │    │    ├── train_NUDT-SIRST.txt\n  │    │    │    ├── test_NUDT-SIRST.txt\n  │    ├── SIRSTv1 (~which is misnamed as NUAA-SIRST~)\n  │    │    ├── images\n  │    │    │    ├── Misc_1.png\n  │    │    │    ├── Misc_2.png\n  │    │    │    ├── ...\n  │    │    ├── masks\n  │    │    │    ├── Misc_1.png\n  │    │    │    ├── Misc_2.png\n  │    │    │    ├── ...\n  │    │    ├── img_idx\n  │    │    │    ├── train_NUAA-SIRST.txt\n  │    │    │    ├── test_NUAA-SIRST.txt\n  │    ├── SIRST3 (~The sum of SIRSTv1, NUDT-SIRST and IRSTD-1K~)\n  │    
│    ├── images\n  │    │    │    ├── XDU0.png\n  │    │    │    ├── XDU1.png\n  │    │    │    ├── ...\n  │    │    ├── masks\n  │    │    │    ├── XDU0.png\n  │    │    │    ├── XDU1.png\n  │    │    │    ├── ...\n  │    │    ├── img_idx\n  │    │    │    ├── train_SIRST3.txt\n  │    │    │    ├── test_SIRST3.txt\n  \n  ```\n\n\n#### 2. Train\n```bash\npython train.py\n```\n\n#### 3. Test and demo\nBaidu Netdisk link for the weights: [https://pan.baidu.com/s/1_hlEaqnJI246GWN4N8k8wA?pwd=t28j](https://pan.baidu.com/s/1B0mANHXSfJaQjHr00XIwgQ?pwd=s7nh)\n\nGoogle Drive link for the weights: https://drive.google.com/file/d/1Kxs2wKG2uq2YiGJOBGWoVz7B1-8DJoz3/view?usp=sharing\n```bash\npython test.py\n```\n\n## Results and Trained Models\n\n#### Qualitative Results\n![Image text](https://github.com/xdFai/SCTransNet/blob/main/Fig/picture06.png)\n\n\n#### Quantitative Results on Mixed SIRSTv1, NUDT-SIRST, and IRSTD-1K (i.e., One Set of Weights for All Three Datasets)\n\n| Dataset | mIoU (×10⁻²) | nIoU (×10⁻²) | F-measure (×10⁻²) | Pd (×10⁻²) | Fa (×10⁻⁶) |\n| ------------- |:-------------:|:-----:|:-----:|:-----:|:-----:|\n| SIRSTv1 | 77.50 | 81.08 | 87.32 | 96.95 | 13.92 |\n| NUDT-SIRST | 94.09 | 94.38 | 96.95 | 98.62 | 4.29 |\n| IRSTD-1K | 68.03 | 68.15 | 80.96 | 93.27 | 10.74 |\n| [[Weights]](https://drive.google.com/file/d/1Kxs2wKG2uq2YiGJOBGWoVz7B1-8DJoz3/view?usp=sharing) | | | | | |\n\n\n*This code borrows heavily from [IRSTD-Toolbox](https://github.com/XinyiYing/BasicIRSTD). Thanks to Xinyi Ying.\n\n*This code borrows heavily from [UCTransNet](https://github.com/McGregorWwww/UCTransNet). Thanks to Haonan Wang.\n\n*The overall repository style borrows heavily from [DNA-Net](https://github.com/YeRen123455/Infrared-Small-Target-Detection). 
Thanks to Boyang Li.\n\n## Citation\n\nIf you find the code useful, please consider citing our papers using the following BibTeX entries.\n\n```\n@ARTICLE{SCTransNet,\n  author={Yuan, Shuai and Qin, Hanlin and Yan, Xiang and Akhtar, Naveed and Mian, Ajmal},\n  journal={IEEE Transactions on Geoscience and Remote Sensing}, \n  title={SCTransNet: Spatial-Channel Cross Transformer Network for Infrared Small Target Detection}, \n  year={2024},\n  volume={62},\n  number={},\n  pages={1-15},\n  keywords={Semantics;Transformers;Decoding;Feature extraction;Task analysis;Object detection;Visualization;Convolutional neural network (CNN);cross-attention;deep learning;infrared small target detection (IRSTD);transformer},\n  doi={10.1109/TGRS.2024.3383649}}\n\n\n@article{SP-KAN,\ntitle = {SP-KAN: Sparse-sine perception Kolmogorov–Arnold networks for infrared small target detection},\njournal = {ISPRS Journal of Photogrammetry and Remote Sensing},\nvolume = {234},\npages = {1-19},\nyear = {2026},\nissn = {0924-2716},\ndoi = {https://doi.org/10.1016/j.isprsjprs.2026.02.019},\nurl = {https://www.sciencedirect.com/science/article/pii/S0924271626000705},\nauthor = {Shuai Yuan and Yu Liu and Xiaopei Zhang and Xiang Yan and Hanlin Qin and Naveed Akhtar},\n}\n```\n\n\n## Contact\n**Feel free to open an issue or email [yuansy@stu.xidian.edu.cn](mailto:yuansy@stu.xidian.edu.cn) with any questions about SCTransNet.**\n"
  },
  {
    "path": "dataset.py",
    "content": "from utils import *\r\nimport matplotlib.pyplot as plt\r\nimport os\r\n\r\nos.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'\r\n\r\n\r\nclass TrainSetLoader(Dataset):\r\n    def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg=None):\r\n        super(TrainSetLoader).__init__()\r\n        self.dataset_name = dataset_name\r\n        self.dataset_dir = dataset_dir + '/' + dataset_name\r\n        self.patch_size = patch_size\r\n        with open(self.dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:\r\n            self.train_list = f.read().splitlines()\r\n        if img_norm_cfg == None:\r\n            self.img_norm_cfg = get_img_norm_cfg(dataset_name, dataset_dir)\r\n        else:\r\n            self.img_norm_cfg = img_norm_cfg\r\n        self.tranform = augumentation()\r\n\r\n    def __getitem__(self, idx):\r\n        try:\r\n            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//', '/')).convert(\r\n                'I')  # read image base on version ”I“\r\n            # img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//','/'))\r\n            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.png').replace('//', '/'))\r\n        except:\r\n            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.bmp').replace('//', '/')).convert('I')\r\n            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.bmp').replace('//', '/'))\r\n        img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg)  # convert PIL to numpy  and  normalize\r\n        mask = np.array(mask, dtype=np.float32) / 255.0\r\n        if len(mask.shape) > 2:\r\n            mask = mask[:, :, 0]\r\n\r\n        # rnd_bn = np.random.normal(0, 0.03)#0.03\r\n        # img += rnd_bn\r\n        #\r\n        # minm = img.min()\r\n        # rng = img.max() - minm\r\n        # gamma = 
np.random.uniform(0.5, 1.6)\r\n        # x=((img - minm) / rng)\r\n        # img = np.power(x, gamma)\r\n        # img = img * rng + minm\r\n\r\n        img_patch, mask_patch = random_crop(img, mask, self.patch_size, pos_prob=0.5)  # pad the shorter side up to patch_size (e.g. 256), randomly crop patch_size along the longer side; output: patch_size x patch_size\r\n\r\n        img_patch, mask_patch = self.tranform(img_patch, mask_patch)  # random flip/transpose augmentation\r\n        img_patch, mask_patch = img_patch[np.newaxis, :], mask_patch[np.newaxis, :]  # add a channel dimension: (H, W) -> (1, H, W)\r\n        img_patch = torch.from_numpy(np.ascontiguousarray(img_patch))  # numpy -> torch tensor\r\n        mask_patch = torch.from_numpy(np.ascontiguousarray(mask_patch))  # numpy -> torch tensor\r\n        return img_patch, mask_patch\r\n\r\n    def __len__(self):\r\n        return len(self.train_list)\r\n\r\nclass TrainSetLoader02(Dataset):\r\n    def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg=None):\r\n        super().__init__()\r\n        self.dataset_name = dataset_name\r\n        self.dataset_dir = dataset_dir + '/' + dataset_name\r\n        self.patch_size = patch_size\r\n        with open(self.dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:\r\n            self.train_list = f.read().splitlines()\r\n        if img_norm_cfg is None:\r\n            self.img_norm_cfg = get_img_norm_cfg(dataset_name, dataset_dir)\r\n        else:\r\n            self.img_norm_cfg = img_norm_cfg\r\n        self.tranform = augumentation()\r\n\r\n    def __getitem__(self, idx):\r\n        try:\r\n            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//', '/')).convert(\r\n                'I')  # read the image in PIL mode 'I' (32-bit signed integer pixels)\r\n            # img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//','/'))\r\n            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.png').replace('//', '/'))\r\n        except OSError:  # fall back to .bmp when the .png file is missing or unreadable\r\n            img = Image.open((self.dataset_dir + '/images/' + 
self.train_list[idx] + '.bmp').replace('//', '/')).convert('I')\r\n            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.bmp').replace('//', '/'))\r\n        img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg)  # convert PIL image to numpy and normalize\r\n        mask = np.array(mask, dtype=np.float32) / 255.0\r\n        if len(mask.shape) > 2:\r\n            mask = mask[:, :, 0]\r\n\r\n        rnd_bn = np.random.normal(0, 0.03)  # random global intensity offset (one scalar per image, std 0.03)\r\n        img += rnd_bn\r\n        #\r\n        # minm = img.min()\r\n        # rng = img.max() - minm\r\n        # gamma = np.random.uniform(0.5, 1.6)\r\n        # x=((img - minm) / rng)\r\n        # img = np.power(x, gamma)\r\n        # img = img * rng + minm\r\n\r\n        img_patch, mask_patch = random_crop(img, mask, self.patch_size, pos_prob=0.5)  # pad the shorter side up to patch_size (e.g. 256), randomly crop patch_size along the longer side; output: patch_size x patch_size\r\n\r\n        img_patch, mask_patch = self.tranform(img_patch, mask_patch)  # random flip/transpose augmentation\r\n        img_patch, mask_patch = img_patch[np.newaxis, :], mask_patch[np.newaxis, :]  # add a channel dimension: (H, W) -> (1, H, W)\r\n        img_patch = torch.from_numpy(np.ascontiguousarray(img_patch))  # numpy -> torch tensor\r\n        mask_patch = torch.from_numpy(np.ascontiguousarray(mask_patch))  # numpy -> torch tensor\r\n        return img_patch, mask_patch\r\n\r\n    def __len__(self):\r\n        return len(self.train_list)\r\n\r\n\r\nclass TrainSetLoader03(Dataset):\r\n    def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg=None):\r\n        super().__init__()\r\n        self.dataset_name = dataset_name\r\n        self.dataset_dir = dataset_dir + '/' + dataset_name\r\n        self.patch_size = patch_size\r\n        with open(self.dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:\r\n            self.train_list = f.read().splitlines()\r\n        if img_norm_cfg is None:\r\n            self.img_norm_cfg = get_img_norm_cfg(dataset_name, dataset_dir)\r\n        else:\r\n            self.img_norm_cfg = 
img_norm_cfg\r\n        self.tranform = augumentation()\r\n\r\n    def __getitem__(self, idx):\r\n        try:\r\n            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//', '/')).convert(\r\n                'I')  # read the image in PIL mode 'I' (32-bit signed integer pixels)\r\n            # img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//','/'))\r\n            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.png').replace('//', '/'))\r\n        except OSError:  # fall back to .bmp when the .png file is missing or unreadable\r\n            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.bmp').replace('//', '/')).convert('I')\r\n            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.bmp').replace('//', '/'))\r\n        img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg)  # convert PIL image to numpy and normalize\r\n        mask = np.array(mask, dtype=np.float32) / 255.0\r\n        if len(mask.shape) > 2:\r\n            mask = mask[:, :, 0]\r\n\r\n        # rnd_bn = np.random.normal(0, 0.03)#0.03\r\n        # img += rnd_bn\r\n\r\n        # random gamma augmentation: rescale to [0, 1], apply x ** gamma with gamma ~ U(0.5, 1.6), then map back to the original range\r\n        minm = img.min()\r\n        rng = max(img.max() - minm, 1e-8)  # guard against division by zero on constant images\r\n        gamma = np.random.uniform(0.5, 1.6)\r\n        x = (img - minm) / rng\r\n        img = np.power(x, gamma)\r\n        img = img * rng + minm\r\n\r\n        img_patch, mask_patch = random_crop(img, mask, self.patch_size, pos_prob=0.5)  # pad the shorter side up to patch_size (e.g. 256), randomly crop patch_size along the longer side; output: patch_size x patch_size\r\n\r\n        img_patch, mask_patch = self.tranform(img_patch, mask_patch)  # random flip/transpose augmentation\r\n        img_patch, mask_patch = img_patch[np.newaxis, :], mask_patch[np.newaxis, :]  # add a channel dimension: (H, W) -> (1, H, W)\r\n        img_patch = torch.from_numpy(np.ascontiguousarray(img_patch))  # numpy -> torch tensor\r\n        mask_patch = torch.from_numpy(np.ascontiguousarray(mask_patch))  # numpy -> torch tensor\r\n        return img_patch, mask_patch\r\n\r\n    def __len__(self):\r\n        return len(self.train_list)\r\n\r\n\r\nclass TrainSetLoader04(Dataset):\r\n    
def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg=None):\r\n        super().__init__()\r\n        self.dataset_name = dataset_name\r\n        self.dataset_dir = dataset_dir + '/' + dataset_name\r\n        self.patch_size = patch_size\r\n        with open(self.dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:\r\n            self.train_list = f.read().splitlines()\r\n        if img_norm_cfg is None:\r\n            self.img_norm_cfg = get_img_norm_cfg(dataset_name, dataset_dir)\r\n        else:\r\n            self.img_norm_cfg = img_norm_cfg\r\n        self.tranform = augumentation()\r\n\r\n    def __getitem__(self, idx):\r\n        try:\r\n            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//', '/')).convert(\r\n                'I')  # read the image in PIL mode 'I' (32-bit signed integer pixels)\r\n            # img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//','/'))\r\n            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.png').replace('//', '/'))\r\n        except OSError:  # fall back to .bmp when the .png file is missing or unreadable\r\n            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.bmp').replace('//', '/')).convert('I')\r\n            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.bmp').replace('//', '/'))\r\n        img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg)  # convert PIL image to numpy and normalize\r\n        mask = np.array(mask, dtype=np.float32) / 255.0\r\n        if len(mask.shape) > 2:\r\n            mask = mask[:, :, 0]\r\n\r\n        rnd_bn = np.random.normal(0, 0.03)  # random global intensity offset (one scalar per image, std 0.03)\r\n        img += rnd_bn\r\n\r\n        # random gamma augmentation: rescale to [0, 1], apply x ** gamma with gamma ~ U(0.5, 1.6), then map back to the original range\r\n        minm = img.min()\r\n        rng = max(img.max() - minm, 1e-8)  # guard against division by zero on constant images\r\n        gamma = np.random.uniform(0.5, 1.6)\r\n        x = (img - minm) / rng\r\n        img = np.power(x, gamma)\r\n        img = img * rng + minm\r\n\r\n        img_patch, mask_patch = random_crop(img, mask, 
self.patch_size, pos_prob=0.5)  # pad the shorter side up to patch_size (e.g. 256), randomly crop patch_size along the longer side; output: patch_size x patch_size\r\n\r\n        img_patch, mask_patch = self.tranform(img_patch, mask_patch)  # random flip/transpose augmentation\r\n        img_patch, mask_patch = img_patch[np.newaxis, :], mask_patch[np.newaxis, :]  # add a channel dimension: (H, W) -> (1, H, W)\r\n        img_patch = torch.from_numpy(np.ascontiguousarray(img_patch))  # numpy -> torch tensor\r\n        mask_patch = torch.from_numpy(np.ascontiguousarray(mask_patch))  # numpy -> torch tensor\r\n        return img_patch, mask_patch\r\n\r\n    def __len__(self):\r\n        return len(self.train_list)\r\n\r\nclass TestSetLoader(Dataset):\r\n    def __init__(self, dataset_dir, train_dataset_name, test_dataset_name, img_norm_cfg=None):\r\n        super().__init__()\r\n        self.dataset_dir = dataset_dir + '/' + test_dataset_name\r\n        with open(self.dataset_dir + '/img_idx/test_' + test_dataset_name + '.txt', 'r') as f:\r\n            self.test_list = f.read().splitlines()\r\n        if img_norm_cfg is None:\r\n            self.img_norm_cfg = get_img_norm_cfg(train_dataset_name, dataset_dir)\r\n        else:\r\n            self.img_norm_cfg = img_norm_cfg\r\n\r\n    def __getitem__(self, idx):\r\n        try:\r\n            img = Image.open((self.dataset_dir + '/images/' + self.test_list[idx] + '.png').replace('//', '/')).convert('I')\r\n            mask = Image.open((self.dataset_dir + '/masks/' + self.test_list[idx] + '.png').replace('//', '/'))\r\n        except OSError:  # fall back to .bmp when the .png file is missing or unreadable\r\n            img = Image.open((self.dataset_dir + '/images/' + self.test_list[idx] + '.bmp').replace('//', '/')).convert('I')\r\n            mask = Image.open((self.dataset_dir + '/masks/' + self.test_list[idx] + '.bmp').replace('//', '/'))\r\n\r\n        img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg)\r\n        mask = np.array(mask, dtype=np.float32) / 255.0\r\n 
       if len(mask.shape) > 2:\r\n            mask = mask[:, :, 0]\r\n\r\n        h, w = img.shape\r\n\r\n        img = PadImg(img)\r\n        mask = PadImg(mask)\r\n\r\n        img, mask = img[np.newaxis, :], mask[np.newaxis, :]\r\n\r\n        img = torch.from_numpy(np.ascontiguousarray(img))\r\n        mask = torch.from_numpy(np.ascontiguousarray(mask))\r\n        if img.size() != mask.size():\r\n            print('Warning: image/mask size mismatch for', self.test_list[idx])\r\n        return img, mask, [h, w], self.test_list[idx]\r\n\r\n    def __len__(self):\r\n        return len(self.test_list)\r\n\r\n\r\nclass EvalSetLoader(Dataset):\r\n    def __init__(self, dataset_dir, mask_pred_dir, test_dataset_name, model_name):\r\n        super().__init__()\r\n        self.dataset_dir = dataset_dir\r\n        self.mask_pred_dir = mask_pred_dir\r\n        self.test_dataset_name = test_dataset_name\r\n        self.model_name = model_name\r\n        with open(self.dataset_dir + '/img_idx/test_' + test_dataset_name + '.txt', 'r') as f:\r\n            self.test_list = f.read().splitlines()\r\n\r\n    def __getitem__(self, idx):\r\n        mask_pred = Image.open(\r\n            (self.mask_pred_dir + self.test_dataset_name + '/' + self.model_name + '/' + self.test_list[idx] + '.png').replace('//', '/'))\r\n        mask_gt = Image.open(self.dataset_dir + '/masks/' + self.test_list[idx] + '.png')\r\n\r\n        mask_pred = np.array(mask_pred, dtype=np.float32) / 255.0\r\n        mask_gt = np.array(mask_gt, dtype=np.float32) / 255.0\r\n\r\n        # keep only the first channel of multi-channel (e.g. RGB) masks\r\n        if len(mask_pred.shape) == 3:\r\n            mask_pred = mask_pred[:, :, 0]\r\n        if len(mask_gt.shape) == 3:\r\n            mask_gt = mask_gt[:, :, 0]\r\n\r\n        h, w = mask_pred.shape\r\n\r\n        mask_pred, mask_gt = mask_pred[np.newaxis, :], mask_gt[np.newaxis, :]\r\n\r\n        mask_pred = torch.from_numpy(np.ascontiguousarray(mask_pred))\r\n        mask_gt = torch.from_numpy(np.ascontiguousarray(mask_gt))\r\n        return mask_pred, mask_gt, [h, w]\r\n\r\n    def __len__(self):\r\n        return len(self.test_list)\r\n\r\n\r\nclass 
augumentation(object):\r\n    def __call__(self, input, target):\r\n        if random.random() < 0.5:  # random flip along axis 0 (reverse row order)\r\n            input = input[::-1, :]\r\n            target = target[::-1, :]\r\n        if random.random() < 0.5:  # random flip along axis 1 (reverse column order)\r\n            input = input[:, ::-1]\r\n            target = target[:, ::-1]\r\n        if random.random() < 0.5:  # random transpose (swap the H and W axes)\r\n            input = input.transpose(1, 0)\r\n            target = target.transpose(1, 0)\r\n        return input, target\r\n"
  },
  {
    "path": "datasets/SIRST3/img_idx/test_SIRST3.txt",
    "content": "Misc_338\r\nMisc_379\r\nMisc_422\r\nMisc_73\r\nMisc_321\r\nMisc_162\r\nMisc_372\r\nMisc_185\r\nMisc_420\r\nMisc_143\r\nMisc_137\r\nMisc_224\r\nMisc_274\r\nMisc_156\r\nMisc_121\r\nMisc_240\r\nMisc_82\r\nMisc_230\r\nMisc_51\r\nMisc_203\r\nMisc_210\r\nMisc_219\r\nMisc_178\r\nMisc_416\r\nMisc_366\r\nMisc_306\r\nMisc_94\r\nMisc_5\r\nMisc_413\r\nMisc_279\r\nMisc_110\r\nMisc_298\r\nMisc_257\r\nMisc_199\r\nMisc_368\r\nMisc_34\r\nMisc_248\r\nMisc_303\r\nMisc_99\r\nMisc_225\r\nMisc_317\r\nMisc_308\r\nMisc_66\r\nMisc_266\r\nMisc_123\r\nMisc_288\r\nMisc_348\r\nMisc_59\r\nMisc_299\r\nMisc_28\r\nMisc_227\r\nMisc_54\r\nMisc_287\r\nMisc_252\r\nMisc_21\r\nMisc_130\r\nMisc_228\r\nMisc_177\r\nMisc_24\r\nMisc_395\r\nMisc_332\r\nMisc_207\r\nMisc_30\r\nMisc_173\r\nMisc_49\r\nMisc_48\r\nMisc_75\r\nMisc_427\r\nMisc_411\r\nMisc_205\r\nMisc_57\r\nMisc_273\r\nMisc_272\r\nMisc_125\r\nMisc_262\r\nMisc_14\r\nMisc_56\r\nMisc_234\r\nMisc_412\r\nMisc_275\r\nMisc_350\r\nMisc_418\r\nMisc_113\r\nMisc_148\r\nMisc_357\r\nMisc_239\r\nMisc_385\r\nMisc_154\r\nMisc_83\r\nMisc_222\r\nMisc_111\r\nMisc_153\r\nMisc_309\r\nMisc_374\r\nMisc_312\r\nMisc_117\r\nMisc_197\r\nMisc_292\r\nMisc_145\r\nMisc_376\r\nMisc_122\r\nMisc_101\r\nMisc_394\r\nMisc_136\r\nMisc_289\r\nMisc_131\r\nMisc_25\r\nMisc_74\r\nMisc_347\r\nMisc_15\r\nMisc_1\r\nMisc_361\r\nMisc_373\r\nMisc_151\r\nMisc_97\r\nMisc_3\r\nMisc_381\r\nMisc_400\r\nMisc_277\r\nMisc_326\r\nMisc_38\r\nMisc_241\r\nMisc_334\r\nMisc_397\r\nMisc_164\r\nMisc_300\r\nMisc_297\r\nMisc_414\r\nMisc_249\r\nMisc_91\r\nMisc_107\r\nMisc_68\r\nMisc_259\r\nMisc_342\r\nMisc_189\r\nMisc_76\r\nMisc_235\r\nMisc_335\r\nMisc_22\r\nMisc_319\r\nMisc_314\r\nMisc_79\r\nMisc_346\r\nMisc_251\r\nMisc_387\r\nMisc_410\r\nMisc_345\r\nMisc_236\r\nMisc_41\r\nMisc_35\r\nMisc_9\r\nMisc_396\r\nMisc_375\r\nMisc_176\r\nMisc_188\r\nMisc_44\r\nMisc_356\r\nMisc_80\r\nMisc_60\r\nMisc_386\r\nMisc_126\r\nMisc_206\r\nMisc_307\r\nMisc_208\r\nMisc_276\r\nMisc_286\r\nMisc_193\r\nMisc_270\r\nMisc_377\r\
nMisc_37\r\nMisc_390\r\nMisc_106\r\nMisc_255\r\nMisc_328\r\nMisc_268\r\nMisc_301\r\nMisc_181\r\nMisc_371\r\nMisc_204\r\nMisc_323\r\nMisc_87\r\nMisc_53\r\nMisc_293\r\nMisc_20\r\nMisc_85\r\nMisc_327\r\nMisc_216\r\nMisc_365\r\nMisc_408\r\nMisc_62\r\nMisc_271\r\nMisc_265\r\nMisc_325\r\nMisc_340\r\nMisc_155\r\nMisc_10\r\nMisc_13\r\nMisc_124\r\nMisc_140\r\nMisc_139\r\nMisc_46\r\nMisc_19\r\nMisc_135\r\nMisc_221\r\nMisc_212\r\nMisc_359\r\nMisc_237\r\nMisc_331\r\nMisc_69\r\nMisc_186\r\nMisc_294\r\nMisc_196\r\nMisc_84\r\nMisc_380\r\n000541\r\n000832\r\n000710\r\n000316\r\n001142\r\n000409\r\n000226\r\n001133\r\n001255\r\n000558\r\n001000\r\n000763\r\n000385\r\n001319\r\n000591\r\n000484\r\n001236\r\n000334\r\n000915\r\n000523\r\n000765\r\n000215\r\n000045\r\n001303\r\n000644\r\n001190\r\n001026\r\n001089\r\n000720\r\n000195\r\n000519\r\n001111\r\n000922\r\n000965\r\n000080\r\n000751\r\n000281\r\n000849\r\n000495\r\n001257\r\n000966\r\n001267\r\n000091\r\n000935\r\n000314\r\n000128\r\n001169\r\n000701\r\n000903\r\n000144\r\n000613\r\n000313\r\n000552\r\n000787\r\n000536\r\n000423\r\n000389\r\n001235\r\n000615\r\n000366\r\n000399\r\n001155\r\n000625\r\n000610\r\n001050\r\n001148\r\n000440\r\n000032\r\n000812\r\n000830\r\n000771\r\n000324\r\n001161\r\n000209\r\n000645\r\n000897\r\n001036\r\n000200\r\n000708\r\n000474\r\n000840\r\n000554\r\n000011\r\n000603\r\n000560\r\n000550\r\n000498\r\n000243\r\n000433\r\n000106\r\n000036\r\n000574\r\n000837\r\n000608\r\n000341\r\n000863\r\n000907\r\n000600\r\n000361\r\n000555\r\n000636\r\n000509\r\n000799\r\n000062\r\n000297\r\n000933\r\n000374\r\n000928\r\n000910\r\n000131\r\n000191\r\n000216\r\n000373\r\n000556\r\n000417\r\n000485\r\n000557\r\n000153\r\n000955\r\n000511\r\n001038\r\n001239\r\n000986\r\n000740\r\n001018\r\n001068\r\n000902\r\n001221\r\n000510\r\n000212\r\n000026\r\n001320\r\n001074\r\n000189\r\n000335\r\n000295\r\n001265\r\n001206\r\n000637\r\n000349\r\n001122\r\n001178\r\n000944\r\n001135\r\n000113\r\n000041\r\n000575\r\n0
00757\r\n000882\r\n000917\r\n000909\r\n000569\r\n001167\r\n001212\r\n000801\r\n000466\r\n000449\r\n000913\r\n000607\r\n001218\r\n001096\r\n000850\r\n001151\r\n000426\r\n000908\r\n000286\r\n001280\r\n000167\r\n000332\r\n000499\r\n001066\r\n000611\r\n001188\r\n001238\r\n001195\r\n000980\r\n000194\r\n000315\r\n000391\r\n000704\r\n000930\r\n000150\r\n000687\r\n000394\r\n000988\r\n000535\r\n001278\r\n000867\r\n000371\r\n000355\r\n000722\r\n000178\r\n001215\r\n000747\r\n001162\r\n000001\r\n000405\r\n001097\r\n000506\r\n001102\r\n000326\r\n000737\r\n000302\r\n000699\r\n000715\r\n000874\r\n000713\r\n001244\r\n000589\r\n001144\r\n001213\r\n000689\r\n000266\r\n000563\r\n000395\r\n000415\r\n000749\r\n000330\r\n000300\r\n000071\r\n001264\r\n000237\r\n000973\r\n001076\r\n000666\r\n001079\r\n000532\r\n000820\r\n000207\r\n000182\r\n001176\r\n001272\r\n000047\r\n000268\r\n000906\r\n000351\r\n000633\r\n000025\r\n000957\r\n000176\r\n000220\r\n000188\r\n001006\r\n000926\r\n001129\r\n000599\r\n000684\r\n000155\r\n000703\r\n001143\r\n000620\r\n000075\r\n000839\r\n001300\r\n001186\r\n000197\r\n000811\r\n000601\r\n001040\r\n000735\r\n001029\r\n000623\r\n000436\r\n001242\r\n000813\r\n000493\r\n001053\r\n000410\r\n000936\r\n000473\r\n000478\r\n000458\r\n000709\r\n001307\r\n000781\r\n001023\r\n001004\r\n000359\r\n000287\r\n000482\r\n000941\r\n000067\r\n000284\r\n000447\r\n000894\r\n000163\r\n000059\r\n001222\r\n001268\r\n000744\r\n000921\r\n000173\r\n000378\r\n000806\r\n000629\r\n000870\r\n000860\r\n000660\r\n000664\r\n000533\r\n000617\r\n001145\r\n000961\r\n000940\r\n000365\r\n001003\r\n000305\r\n001317\r\n001314\r\n000352\r\n000898\r\n000883\r\n001309\r\n000732\r\n000336\r\n000622\r\n001232\r\n000166\r\n000392\r\n000231\r\n001281\r\n001199\r\n000893\r\n000879\r\n000084\r\n001174\r\n000377\r\n001063\r\n000628\r\n000537\r\n000772\r\n000100\r\n000043\r\n000291\r\n000918\r\n000548\r\n000729\r\n000475\r\n000731\r\n000974\r\n001047\r\n000595\r\n000508\r\n001101\r\n001248\r\n001107\r\n001127\r\n0
01253\r\n000920\r\n000021\r\n000742\r\n000140\r\n000320\r\n001209\r\n000434\r\n000120\r\n000711\r\n000137\r\n000249\r\n000995\r\n001084\r\n001273\r\n000845\r\n000452\r\n001220\r\n000553\r\n000003\r\n000984\r\n000086\r\n000967\r\n000788\r\n000416\r\n000138\r\n000112\r\n001114\r\n001308\r\n000053\r\n000762\r\n000333\r\n001081\r\n001287\r\n000821\r\n001302\r\n000480\r\n000775\r\n000525\r\n001262\r\n000124\r\n001069\r\n000782\r\n000490\r\n000010\r\n000465\r\n001286\r\n001060\r\n000685\r\n000717\r\n000939\r\n000027\r\n000971\r\n000727\r\n000170\r\n000187\r\n000671\r\n000976\r\n000562\r\n000651\r\n000831\r\n000397\r\n000157\r\n000682\r\n001227\r\n001010\r\n000294\r\n000159\r\n001088\r\n001241\r\n001229\r\n000289\r\n000169\r\n000135\r\n000090\r\n000456\r\n000691\r\n000382\r\n000483\r\n000528\r\n000889\r\n000060\r\n001326\r\n000520\r\n000269\r\n000706\r\n001020\r\n000502\r\n000492\r\n001153\r\n000521\r\n000496\r\n000538\r\n000009\r\n001294\r\n000542\r\n000127\r\n000192\r\n001237\r\n000406\r\n001055\r\n001251\r\n001315\r\n000151\r\n001180\r\n001250\r\n000901\r\n000816\r\n000273\r\n000545\r\n000912\r\n000583\r\n000285\r\n000678\r\n000992\r\n000386\r\n000022\r\n000573\r\n001075\r\n000927\r\n001304\r\n000219\r\n001291\r\n000786\r\n000348\r\n000861\r\n000785\r\n000875\r\n001191\r\n000853\r\n001008\r\n000817\r\n000721\r\n000807\r\n000122\r\n001083\r\n000261\r\n000588\r\n000587\r\n000880\r\n001311\r\n000835\r\n001021\r\n000714\r\n000081\r\n000420\r\n001103\r\n000142\r\n000270\r\n001160\r\n001224\r\n000792\r\n000019\r\n000470\r\n000529\r\n001141\r\n000646\r\n000656\r\n000815\r\n000885\r\n000924\r\n001254\r\n000111\r\n000598\r\n001120\r\n001087\r\n000020\r\n000214\r\n000724\r\n000582\r\n001234\r\n000983\r\n000759\r\n001149\r\n001044\r\n000665\r\n000107\r\n001322\r\n001210\r\n000448\r\n000260\r\n000579\r\n000996\r\n001246\r\n000698\r\n000783\r\n000240\r\n000862\r\n000616\r\n000051\r\n000400\r\n000455\r\n000809\r\n001013\r\n000110\r\n000634\r\n000838\r\n001184\r\n001077\r\n000694\r\n0
00412\r\n000836\r\n000344\r\n001240\r\n000259\r\n000439\r\n000766\r\n001202\r\n000093\r\n000507\r\n000303\r\n000262\r\n000092\r\n000890\r\n000789\r\n000784\r\n000272\r\n000186\r\n000500\r\n000734\r\n000887\r\n000476\r\n000960\r\n000202\r\n001034\r\n001214\r\n000516\r\n000808\r\n000346\r\n000680\r\n001289\r\n001042\r\n000802\r\n001139\r\n000640\r\n000803\r\n000329\r\n001123\r\n000761\r\n001028\r\n000650\r\n000133\r\n000592\r\n000148\r\n000866\r\n000230\r\n001012\r\n000055\r\n000252\r\n000421\r\n000425\r\n000596\r\n000296\r\n001325\r\n000255\r\n001092\r\n000362\r\n000450\r\n001183\r\n000046\r\n000648\r\n000370\r\n001128\r\n001175\r\n000730\r\n001154\r\n000963\r\n001098\r\n000654\r\n001112\r\n000319\r\n000398\r\n001298\r\n001057\r\n001125\r\n000012\r\n000756\r\n000141\r\n000931\r\n000581\r\n000463\r\n001259\r\n000985\r\n000076\r\n000632\r\n000162\r\n000954\r\n000779\r\n000233\r\n000614\r\n001126\r\n001288\r\n001283\r\n000790\r\n000023\r\n000380\r\n000661\r\n000923\r\n000225\r\n000156\r\n000390\r\n000597\r\n000468\r\n000606\r\n000307\r\n000630\r\n001223\r\n000038\r\n000224\r\n000673\r\n001193\r\n000016\r\n000015\r\n001110\r\n001204\r\n000683\r\n000353\r\nXDU189\r\nXDU935\r\nXDU672\r\nXDU231\r\nXDU818\r\nXDU888\r\nXDU146\r\nXDU48\r\nXDU492\r\nXDU241\r\nXDU195\r\nXDU801\r\nXDU104\r\nXDU637\r\nXDU996\r\nXDU482\r\nXDU406\r\nXDU889\r\nXDU558\r\nXDU117\r\nXDU777\r\nXDU134\r\nXDU223\r\nXDU943\r\nXDU762\r\nXDU662\r\nXDU54\r\nXDU685\r\nXDU167\r\nXDU489\r\nXDU505\r\nXDU527\r\nXDU817\r\nXDU253\r\nXDU193\r\nXDU597\r\nXDU151\r\nXDU404\r\nXDU596\r\nXDU97\r\nXDU321\r\nXDU279\r\nXDU93\r\nXDU205\r\nXDU9\r\nXDU219\r\nXDU674\r\nXDU501\r\nXDU316\r\nXDU343\r\nXDU885\r\nXDU426\r\nXDU485\r\nXDU850\r\nXDU516\r\nXDU216\r\nXDU160\r\nXDU176\r\nXDU504\r\nXDU883\r\nXDU244\r\nXDU919\r\nXDU781\r\nXDU369\r\nXDU398\r\nXDU441\r\nXDU75\r\nXDU240\r\nXDU805\r\nXDU108\r\nXDU709\r\nXDU352\r\nXDU747\r\nXDU209\r\nXDU845\r\nXDU557\r\nXDU775\r\nXDU56\r\nXDU657\r\nXDU753\r\nXDU788\r\nXDU682\r\nXDU794\r\nXDU877\r\
nXDU421\r\nXDU733\r\nXDU546\r\nXDU999\r\nXDU5\r\nXDU63\r\nXDU966\r\nXDU922\r\nXDU789\r\nXDU295\r\nXDU863\r\nXDU578\r\nXDU743\r\nXDU46\r\nXDU115\r\nXDU876\r\nXDU932\r\nXDU289\r\nXDU855\r\nXDU933\r\nXDU517\r\nXDU329\r\nXDU3\r\nXDU451\r\nXDU694\r\nXDU878\r\nXDU259\r\nXDU708\r\nXDU442\r\nXDU829\r\nXDU833\r\nXDU648\r\nXDU381\r\nXDU868\r\nXDU803\r\nXDU673\r\nXDU415\r\nXDU997\r\nXDU667\r\nXDU968\r\nXDU169\r\nXDU525\r\nXDU164\r\nXDU704\r\nXDU711\r\nXDU111\r\nXDU354\r\nXDU927\r\nXDU758\r\nXDU87\r\nXDU697\r\nXDU957\r\nXDU49\r\nXDU563\r\nXDU954\r\nXDU45\r\nXDU429\r\nXDU902\r\nXDU302\r\nXDU523\r\nXDU41\r\nXDU816\r\nXDU785\r\nXDU759\r\nXDU872\r\nXDU185\r\nXDU881\r\nXDU447\r\nXDU129\r\nXDU614\r\nXDU920\r\nXDU334\r\nXDU257\r\nXDU892\r\nXDU103\r\nXDU698\r\nXDU862\r\nXDU33\r\nXDU416\r\nXDU40\r\nXDU715\r\nXDU203\r\nXDU589\r\nXDU142\r\nXDU50\r\nXDU455\r\nXDU620\r\nXDU67\r\nXDU371\r\nXDU192\r\nXDU28\r\nXDU43\r\nXDU661\r\nXDU692\r\nXDU463\r\nXDU745\r\nXDU258\r\nXDU842\r\nXDU459\r\nXDU147\r\nXDU319\r\nXDU225\r\nXDU178\r\nXDU567\r\nXDU925\r\nXDU394\r\nXDU110\r\nXDU663\r\nXDU376\r\nXDU450\r\nXDU10\r\nXDU955\r\nXDU374\r\nXDU278\r\nXDU393\r\nXDU570\r\nXDU217\r\n"
  },
  {
    "path": "datasets/SIRST3/img_idx/train_SIRST3.txt",
    "content": "Misc_119\r\nMisc_64\r\nMisc_90\r\nMisc_364\r\nMisc_250\r\nMisc_351\r\nMisc_39\r\nMisc_313\r\nMisc_179\r\nMisc_344\r\nMisc_421\r\nMisc_398\r\nMisc_417\r\nMisc_95\r\nMisc_339\r\nMisc_426\r\nMisc_269\r\nMisc_316\r\nMisc_419\r\nMisc_144\r\nMisc_149\r\nMisc_146\r\nMisc_31\r\nMisc_58\r\nMisc_4\r\nMisc_264\r\nMisc_283\r\nMisc_284\r\nMisc_150\r\nMisc_220\r\nMisc_133\r\nMisc_77\r\nMisc_70\r\nMisc_425\r\nMisc_195\r\nMisc_304\r\nMisc_329\r\nMisc_65\r\nMisc_167\r\nMisc_174\r\nMisc_202\r\nMisc_157\r\nMisc_96\r\nMisc_320\r\nMisc_369\r\nMisc_109\r\nMisc_16\r\nMisc_40\r\nMisc_295\r\nMisc_147\r\nMisc_247\r\nMisc_423\r\nMisc_152\r\nMisc_100\r\nMisc_263\r\nMisc_352\r\nMisc_233\r\nMisc_190\r\nMisc_392\r\nMisc_281\r\nMisc_358\r\nMisc_163\r\nMisc_132\r\nMisc_405\r\nMisc_159\r\nMisc_12\r\nMisc_367\r\nMisc_172\r\nMisc_401\r\nMisc_138\r\nMisc_104\r\nMisc_86\r\nMisc_160\r\nMisc_242\r\nMisc_7\r\nMisc_305\r\nMisc_243\r\nMisc_399\r\nMisc_363\r\nMisc_61\r\nMisc_129\r\nMisc_330\r\nMisc_134\r\nMisc_315\r\nMisc_180\r\nMisc_244\r\nMisc_63\r\nMisc_391\r\nMisc_42\r\nMisc_404\r\nMisc_29\r\nMisc_238\r\nMisc_285\r\nMisc_214\r\nMisc_93\r\nMisc_253\r\nMisc_402\r\nMisc_50\r\nMisc_291\r\nMisc_128\r\nMisc_267\r\nMisc_115\r\nMisc_337\r\nMisc_370\r\nMisc_158\r\nMisc_114\r\nMisc_388\r\nMisc_170\r\nMisc_354\r\nMisc_36\r\nMisc_424\r\nMisc_336\r\nMisc_393\r\nMisc_229\r\nMisc_108\r\nMisc_105\r\nMisc_406\r\nMisc_2\r\nMisc_324\r\nMisc_47\r\nMisc_200\r\nMisc_187\r\nMisc_33\r\nMisc_72\r\nMisc_384\r\nMisc_120\r\nMisc_322\r\nMisc_360\r\nMisc_192\r\nMisc_112\r\nMisc_142\r\nMisc_403\r\nMisc_169\r\nMisc_223\r\nMisc_213\r\nMisc_161\r\nMisc_256\r\nMisc_141\r\nMisc_78\r\nMisc_296\r\nMisc_6\r\nMisc_258\r\nMisc_231\r\nMisc_52\r\nMisc_183\r\nMisc_362\r\nMisc_102\r\nMisc_88\r\nMisc_343\r\nMisc_341\r\nMisc_118\r\nMisc_165\r\nMisc_280\r\nMisc_17\r\nMisc_290\r\nMisc_67\r\nMisc_382\r\nMisc_191\r\nMisc_166\r\nMisc_8\r\nMisc_45\r\nMisc_415\r\nMisc_349\r\nMisc_98\r\nMisc_127\r\nMisc_184\r\nMisc_310\r\nMisc_198\r\nMisc_254\
r\nMisc_211\r\nMisc_103\r\nMisc_232\r\nMisc_218\r\nMisc_89\r\nMisc_201\r\nMisc_11\r\nMisc_168\r\nMisc_215\r\nMisc_383\r\nMisc_333\r\nMisc_245\r\nMisc_55\r\nMisc_27\r\nMisc_226\r\nMisc_116\r\nMisc_378\r\nMisc_355\r\nMisc_302\r\nMisc_209\r\nMisc_32\r\nMisc_23\r\nMisc_261\r\nMisc_182\r\nMisc_282\r\nMisc_409\r\nMisc_260\r\nMisc_194\r\nMisc_407\r\nMisc_175\r\nMisc_278\r\nMisc_26\r\nMisc_246\r\nMisc_217\r\nMisc_311\r\nMisc_43\r\nMisc_353\r\nMisc_81\r\nMisc_18\r\nMisc_318\r\nMisc_389\r\nMisc_171\r\nMisc_71\r\nMisc_92\r\n001137\r\n000345\r\n000774\r\n000593\r\n001002\r\n001024\r\n001285\r\n000649\r\n000282\r\n000037\r\n001150\r\n001124\r\n001185\r\n000718\r\n000968\r\n001216\r\n000494\r\n001015\r\n000945\r\n000446\r\n000193\r\n000746\r\n000733\r\n000621\r\n001131\r\n000085\r\n000064\r\n000312\r\n000213\r\n001297\r\n001039\r\n000066\r\n000002\r\n000843\r\n000659\r\n000298\r\n000227\r\n000951\r\n000299\r\n000937\r\n000547\r\n000948\r\n000859\r\n000058\r\n000422\r\n001140\r\n000873\r\n001249\r\n000822\r\n000745\r\n000609\r\n000881\r\n001271\r\n000851\r\n000848\r\n001016\r\n001100\r\n000004\r\n001208\r\n000099\r\n000073\r\n000049\r\n000158\r\n001279\r\n000063\r\n001181\r\n000515\r\n001054\r\n000844\r\n001177\r\n000934\r\n000663\r\n000841\r\n000205\r\n000911\r\n000211\r\n000267\r\n000379\r\n000168\r\n000301\r\n001061\r\n001299\r\n000818\r\n001070\r\n001305\r\n001095\r\n000695\r\n000375\r\n000061\r\n000669\r\n000325\r\n000864\r\n000174\r\n000347\r\n000748\r\n000946\r\n001274\r\n001059\r\n000253\r\n000679\r\n001041\r\n000999\r\n000794\r\n000576\r\n001094\r\n000635\r\n000825\r\n000013\r\n000627\r\n000487\r\n001201\r\n000457\r\n000916\r\n000030\r\n000109\r\n000693\r\n000183\r\n000467\r\n001164\r\n001258\r\n000970\r\n000263\r\n000858\r\n000668\r\n000526\r\n000736\r\n001146\r\n000247\r\n000367\r\n000050\r\n000773\r\n000318\r\n001163\r\n001159\r\n001306\r\n001310\r\n000571\r\n001031\r\n000461\r\n000652\r\n000444\r\n001194\r\n001156\r\n000014\r\n000250\r\n000134\r\n001301\r\n000768\r\n0
00823\r\n000040\r\n000891\r\n000561\r\n001200\r\n001233\r\n000119\r\n001324\r\n000793\r\n000293\r\n000143\r\n001011\r\n000834\r\n000095\r\n000017\r\n001116\r\n000653\r\n000755\r\n000705\r\n000723\r\n000738\r\n001266\r\n000364\r\n000791\r\n000048\r\n000228\r\n000956\r\n000570\r\n000116\r\n000350\r\n000276\r\n001022\r\n000369\r\n001225\r\n000129\r\n000947\r\n000800\r\n000154\r\n001168\r\n000602\r\n000210\r\n001284\r\n000534\r\n000814\r\n000846\r\n001132\r\n000257\r\n000688\r\n001056\r\n000430\r\n000900\r\n000471\r\n000833\r\n000171\r\n000567\r\n001043\r\n000798\r\n001085\r\n000895\r\n001121\r\n000979\r\n000363\r\n000311\r\n000229\r\n000234\r\n001073\r\n000716\r\n001005\r\n001119\r\n001327\r\n000754\r\n000245\r\n000780\r\n000277\r\n001052\r\n000856\r\n000871\r\n000531\r\n000472\r\n000083\r\n000274\r\n000147\r\n000096\r\n000152\r\n001158\r\n000453\r\n000057\r\n000411\r\n000388\r\n000082\r\n000655\r\n000962\r\n001086\r\n000539\r\n000381\r\n000527\r\n000309\r\n000065\r\n000418\r\n000306\r\n000469\r\n000088\r\n001007\r\n000690\r\n000686\r\n000428\r\n000605\r\n001067\r\n001025\r\n001245\r\n001045\r\n000767\r\n000728\r\n001065\r\n000540\r\n000658\r\n000549\r\n000522\r\n000978\r\n000317\r\n000117\r\n000184\r\n000631\r\n000459\r\n000804\r\n000819\r\n001048\r\n000271\r\n000358\r\n000670\r\n000145\r\n000118\r\n000239\r\n000126\r\n000115\r\n001171\r\n000497\r\n000847\r\n000462\r\n000604\r\n000223\r\n000810\r\n000938\r\n000460\r\n001001\r\n001108\r\n000826\r\n001252\r\n000981\r\n000429\r\n001277\r\n001080\r\n000707\r\n000221\r\n000265\r\n000672\r\n000752\r\n000481\r\n000384\r\n000236\r\n000125\r\n000892\r\n000238\r\n000639\r\n000914\r\n000432\r\n000489\r\n000513\r\n000114\r\n000719\r\n000643\r\n000994\r\n000712\r\n000242\r\n001207\r\n000201\r\n001312\r\n000524\r\n000198\r\n000543\r\n000028\r\n001323\r\n000331\r\n000222\r\n000869\r\n000692\r\n000842\r\n000795\r\n001078\r\n000896\r\n000052\r\n001256\r\n000292\r\n000070\r\n000443\r\n000343\r\n000504\r\n000203\r\n000204\r\n000943\r\n0
00886\r\n000337\r\n000146\r\n000778\r\n001051\r\n000121\r\n000196\r\n000647\r\n001182\r\n000972\r\n000750\r\n000568\r\n000585\r\n000925\r\n001295\r\n000280\r\n000741\r\n000360\r\n000969\r\n000042\r\n000087\r\n000975\r\n000877\r\n000942\r\n000998\r\n000356\r\n000697\r\n000403\r\n000208\r\n000514\r\n000401\r\n000248\r\n000758\r\n001192\r\n001282\r\n000876\r\n000764\r\n000445\r\n000865\r\n000564\r\n000149\r\n000584\r\n001019\r\n001017\r\n000308\r\n001318\r\n000018\r\n000056\r\n000232\r\n000777\r\n000328\r\n001217\r\n001219\r\n000577\r\n000383\r\n001276\r\n000929\r\n000354\r\n000029\r\n001118\r\n000006\r\n000105\r\n000486\r\n000074\r\n000275\r\n001189\r\n000612\r\n000008\r\n001093\r\n000034\r\n000949\r\n000039\r\n001138\r\n000677\r\n001198\r\n000696\r\n001230\r\n000590\r\n001243\r\n000323\r\n001032\r\n000488\r\n000413\r\n000700\r\n000854\r\n000217\r\n000101\r\n000518\r\n000578\r\n000950\r\n000031\r\n000888\r\n000256\r\n000419\r\n000035\r\n000872\r\n001147\r\n001033\r\n000991\r\n001113\r\n000884\r\n001292\r\n000235\r\n000185\r\n000404\r\n001071\r\n000340\r\n000387\r\n000990\r\n000254\r\n000953\r\n000770\r\n000241\r\n000559\r\n000357\r\n000905\r\n000068\r\n000097\r\n000108\r\n000618\r\n000982\r\n000743\r\n000279\r\n000824\r\n001058\r\n000681\r\n001196\r\n001275\r\n000760\r\n000089\r\n000054\r\n000624\r\n000987\r\n001179\r\n000338\r\n000206\r\n000177\r\n000565\r\n001134\r\n000130\r\n000952\r\n001082\r\n000136\r\n000551\r\n000288\r\n000619\r\n000852\r\n001165\r\n000321\r\n001231\r\n000501\r\n000393\r\n000372\r\n000512\r\n000424\r\n001211\r\n000977\r\n000580\r\n000626\r\n000218\r\n000451\r\n000566\r\n000165\r\n000964\r\n000407\r\n000517\r\n000572\r\n001027\r\n001106\r\n000662\r\n000102\r\n000776\r\n000160\r\n000072\r\n001062\r\n000899\r\n001091\r\n000024\r\n000304\r\n000753\r\n000594\r\n000769\r\n000464\r\n000044\r\n000007\r\n000438\r\n001313\r\n000069\r\n001187\r\n000993\r\n000161\r\n000505\r\n001290\r\n000546\r\n000258\r\n001064\r\n001228\r\n001115\r\n000674\r\n001293\r\n0
01269\r\n001172\r\n001247\r\n000667\r\n000641\r\n000172\r\n000503\r\n000181\r\n000726\r\n000033\r\n000437\r\n001263\r\n000989\r\n000376\r\n000959\r\n000491\r\n001090\r\n000530\r\n001104\r\n000796\r\n000290\r\n000179\r\n000868\r\n000264\r\n000327\r\n000657\r\n000442\r\n001035\r\n001136\r\n000251\r\n000342\r\n000098\r\n001166\r\n000829\r\n001049\r\n000104\r\n000402\r\n000339\r\n000427\r\n000078\r\n001014\r\n000139\r\n001152\r\n000479\r\n000175\r\n000435\r\n001072\r\n000246\r\n000414\r\n000638\r\n000904\r\n001321\r\n000396\r\n000094\r\n000805\r\n000005\r\n001316\r\n001046\r\n000586\r\n001226\r\n000079\r\n000725\r\n001296\r\n000827\r\n000103\r\n001099\r\n001205\r\n000368\r\n000278\r\n000190\r\n000544\r\n001260\r\n000997\r\n000431\r\n000919\r\n001197\r\n001030\r\n001173\r\n000454\r\n001157\r\n000164\r\n001037\r\n000077\r\n000642\r\n000828\r\n000675\r\n000702\r\n000310\r\n000797\r\n000739\r\n000123\r\n001270\r\n001170\r\n001117\r\n000857\r\n000477\r\n000180\r\n001130\r\n000958\r\n000855\r\n001109\r\n001009\r\n000199\r\n001203\r\n000676\r\n000132\r\n000441\r\n001261\r\n000878\r\n000283\r\n000408\r\n000244\r\n000322\r\n001105\r\n000932\r\nXDU514\r\nXDU646\r\nXDU904\r\nXDU660\r\nXDU347\r\nXDU962\r\nXDU92\r\nXDU838\r\nXDU907\r\nXDU496\r\nXDU83\r\nXDU606\r\nXDU307\r\nXDU138\r\nXDU357\r\nXDU993\r\nXDU693\r\nXDU493\r\nXDU891\r\nXDU410\r\nXDU288\r\nXDU562\r\nXDU849\r\nXDU23\r\nXDU199\r\nXDU370\r\nXDU537\r\nXDU871\r\nXDU656\r\nXDU331\r\nXDU328\r\nXDU403\r\nXDU230\r\nXDU529\r\nXDU229\r\nXDU476\r\nXDU792\r\nXDU412\r\nXDU689\r\nXDU51\r\nXDU532\r\nXDU356\r\nXDU303\r\nXDU161\r\nXDU879\r\nXDU867\r\nXDU773\r\nXDU323\r\nXDU836\r\nXDU236\r\nXDU749\r\nXDU807\r\nXDU72\r\nXDU128\r\nXDU822\r\nXDU480\r\nXDU270\r\nXDU651\r\nXDU815\r\nXDU30\r\nXDU548\r\nXDU386\r\nXDU555\r\nXDU122\r\nXDU798\r\nXDU264\r\nXDU725\r\nXDU806\r\nXDU440\r\nXDU332\r\nXDU875\r\nXDU325\r\nXDU486\r\nXDU659\r\nXDU835\r\nXDU335\r\nXDU330\r\nXDU132\r\nXDU89\r\nXDU580\r\nXDU8\r\nXDU4\r\nXDU280\r\nXDU895\r\nXDU869\r\nXDU799\r\nXD
U419\r\nXDU772\r\nXDU896\r\nXDU604\r\nXDU116\r\nXDU85\r\nXDU91\r\nXDU158\r\nXDU130\r\nXDU611\r\nXDU98\r\nXDU299\r\nXDU114\r\nXDU923\r\nXDU94\r\nXDU800\r\nXDU934\r\nXDU998\r\nXDU338\r\nXDU959\r\nXDU712\r\nXDU754\r\nXDU636\r\nXDU624\r\nXDU918\r\nXDU26\r\nXDU559\r\nXDU761\r\nXDU909\r\nXDU340\r\nXDU262\r\nXDU964\r\nXDU29\r\nXDU443\r\nXDU929\r\nXDU739\r\nXDU654\r\nXDU858\r\nXDU551\r\nXDU365\r\nXDU665\r\nXDU705\r\nXDU434\r\nXDU105\r\nXDU887\r\nXDU910\r\nXDU261\r\nXDU360\r\nXDU912\r\nXDU894\r\nXDU727\r\nXDU42\r\nXDU556\r\nXDU613\r\nXDU285\r\nXDU668\r\nXDU592\r\nXDU341\r\nXDU942\r\nXDU982\r\nXDU191\r\nXDU528\r\nXDU720\r\nXDU601\r\nXDU531\r\nXDU21\r\nXDU275\r\nXDU232\r\nXDU344\r\nXDU301\r\nXDU977\r\nXDU390\r\nXDU938\r\nXDU975\r\nXDU688\r\nXDU979\r\nXDU58\r\nXDU494\r\nXDU988\r\nXDU826\r\nXDU571\r\nXDU1\r\nXDU779\r\nXDU633\r\nXDU987\r\nXDU449\r\nXDU313\r\nXDU153\r\nXDU718\r\nXDU978\r\nXDU538\r\nXDU950\r\nXDU740\r\nXDU986\r\nXDU780\r\nXDU965\r\nXDU900\r\nXDU423\r\nXDU194\r\nXDU490\r\nXDU73\r\nXDU675\r\nXDU254\r\nXDU547\r\nXDU983\r\nXDU227\r\nXDU68\r\nXDU445\r\nXDU903\r\nXDU652\r\nXDU642\r\nXDU599\r\nXDU64\r\nXDU221\r\nXDU291\r\nXDU460\r\nXDU623\r\nXDU766\r\nXDU680\r\nXDU714\r\nXDU25\r\nXDU19\r\nXDU65\r\nXDU671\r\nXDU333\r\nXDU375\r\nXDU550\r\nXDU790\r\nXDU397\r\nXDU497\r\nXDU645\r\nXDU470\r\nXDU913\r\nXDU956\r\nXDU218\r\nXDU380\r\nXDU487\r\nXDU66\r\nXDU324\r\nXDU612\r\nXDU384\r\nXDU544\r\nXDU144\r\nXDU542\r\nXDU248\r\nXDU461\r\nXDU148\r\nXDU653\r\nXDU336\r\nXDU866\r\nXDU456\r\nXDU540\r\nXDU351\r\nXDU112\r\nXDU272\r\nXDU53\r\nXDU515\r\nXDU699\r\nXDU417\r\nXDU639\r\nXDU342\r\nXDU31\r\nXDU448\r\nXDU292\r\nXDU728\r\nXDU180\r\nXDU149\r\nXDU706\r\nXDU973\r\nXDU320\r\nXDU625\r\nXDU811\r\nXDU462\r\nXDU388\r\nXDU638\r\nXDU90\r\nXDU109\r\nXDU722\r\nXDU162\r\nXDU972\r\nXDU767\r\nXDU349\r\nXDU263\r\nXDU465\r\nXDU576\r\nXDU507\r\nXDU644\r\nXDU587\r\nXDU255\r\nXDU326\r\nXDU500\r\nXDU586\r\nXDU524\r\nXDU765\r\nXDU890\r\nXDU960\r\nXDU594\r\nXDU80\r\nXDU14\r\nXDU569\r\nXDU953\r\nXDU884\r\nXDU28
2\r\nXDU387\r\nXDU579\r\nXDU260\r\nXDU252\r\nXDU971\r\nXDU905\r\nXDU994\r\nXDU36\r\nXDU266\r\nXDU405\r\nXDU208\r\nXDU207\r\nXDU723\r\nXDU35\r\nXDU590\r\nXDU678\r\nXDU629\r\nXDU939\r\nXDU804\r\nXDU948\r\nXDU24\r\nXDU967\r\nXDU293\r\nXDU545\r\nXDU901\r\nXDU290\r\nXDU989\r\nXDU322\r\nXDU707\r\nXDU188\r\nXDU582\r\nXDU810\r\nXDU439\r\nXDU300\r\nXDU237\r\nXDU457\r\nXDU433\r\nXDU831\r\nXDU917\r\nXDU677\r\nXDU82\r\nXDU561\r\nXDU413\r\nXDU834\r\nXDU368\r\nXDU658\r\nXDU898\r\nXDU173\r\nXDU34\r\nXDU467\r\nXDU2\r\nXDU265\r\nXDU735\r\nXDU454\r\nXDU163\r\nXDU81\r\nXDU74\r\nXDU140\r\nXDU166\r\nXDU478\r\nXDU61\r\nXDU880\r\nXDU841\r\nXDU565\r\nXDU377\r\nXDU839\r\nXDU691\r\nXDU607\r\nXDU530\r\nXDU844\r\nXDU847\r\nXDU472\r\nXDU882\r\nXDU17\r\nXDU619\r\nXDU12\r\nXDU736\r\nXDU859\r\nXDU186\r\nXDU985\r\nXDU389\r\nXDU921\r\nXDU355\r\nXDU539\r\nXDU141\r\nXDU916\r\nXDU135\r\nXDU519\r\nXDU621\r\nXDU435\r\nXDU760\r\nXDU783\r\nXDU591\r\nXDU628\r\nXDU464\r\nXDU649\r\nXDU198\r\nXDU560\r\nXDU372\r\nXDU553\r\nXDU458\r\nXDU183\r\nXDU650\r\nXDU622\r\nXDU825\r\nXDU471\r\nXDU190\r\nXDU414\r\nXDU44\r\nXDU741\r\nXDU635\r\nXDU647\r\nXDU573\r\nXDU864\r\nXDU618\r\nXDU364\r\nXDU543\r\nXDU437\r\nXDU502\r\nXDU824\r\nXDU952\r\nXDU125\r\nXDU641\r\nXDU491\r\nXDU201\r\nXDU947\r\nXDU444\r\nXDU79\r\nXDU518\r\nXDU32\r\nXDU283\r\nXDU802\r\nXDU483\r\nXDU59\r\nXDU477\r\nXDU670\r\nXDU234\r\nXDU577\r\nXDU969\r\nXDU669\r\nXDU995\r\nXDU425\r\nXDU506\r\nXDU970\r\nXDU681\r\nXDU353\r\nXDU676\r\nXDU958\r\nXDU782\r\nXDU121\r\nXDU363\r\nXDU411\r\nXDU690\r\nXDU392\r\nXDU536\r\nXDU752\r\nXDU210\r\nXDU774\r\nXDU350\r\nXDU731\r\nXDU930\r\nXDU479\r\nXDU602\r\nXDU856\r\nXDU96\r\nXDU520\r\nXDU13\r\nXDU915\r\nXDU106\r\nXDU853\r\nXDU308\r\nXDU47\r\nXDU769\r\nXDU242\r\nXDU899\r\nXDU484\r\nXDU581\r\nXDU643\r\nXDU827\r\nXDU246\r\nXDU119\r\nXDU746\r\nXDU100\r\nXDU311\r\nXDU512\r\nXDU852\r\nXDU509\r\nXDU874\r\nXDU686\r\nXDU139\r\nXDU928\r\nXDU719\r\nXDU296\r\nXDU750\r\nXDU713\r\nXDU156\r\nXDU634\r\nXDU6\r\nXDU821\r\nXDU488\r\nXDU420\r\nXDU748
\r\nXDU126\r\nXDU474\r\nXDU273\r\nXDU742\r\nXDU438\r\nXDU20\r\nXDU27\r\nXDU452\r\nXDU541\r\nXDU598\r\nXDU949\r\nXDU992\r\nXDU513\r\nXDU155\r\nXDU228\r\nXDU206\r\nXDU69\r\nXDU666\r\nXDU860\r\nXDU136\r\nXDU617\r\nXDU436\r\nXDU716\r\nXDU823\r\nXDU38\r\nXDU1000\r\nXDU851\r\nXDU627\r\nXDU974\r\nXDU717\r\nXDU734\r\nXDU796\r\nXDU481\r\nXDU990\r\nXDU679\r\nXDU764\r\nXDU238\r\nXDU848\r\nXDU908\r\nXDU418\r\nXDU696\r\nXDU378\r\nXDU724\r\nXDU632\r\nXDU182\r\nXDU76\r\nXDU791\r\nXDU830\r\nXDU814\r\nXDU306\r\nXDU931\r\nXDU154\r\nXDU564\r\nXDU383\r\nXDU473\r\nXDU84\r\nXDU143\r\nXDU18\r\nXDU683\r\nXDU495\r\nXDU78\r\nXDU840\r\nXDU382\r\nXDU385\r\nXDU233\r\nXDU220\r\nXDU616\r\nXDU655\r\nXDU797\r\nXDU854\r\nXDU312\r\nXDU924\r\nXDU593\r\nXDU150\r\nXDU60\r\nXDU951\r\nXDU812\r\nXDU608\r\nXDU408\r\nXDU184\r\nXDU552\r\nXDU172\r\nXDU820\r\nXDU584\r\nXDU314\r\nXDU702\r\nXDU174\r\nXDU379\r\nXDU511\r\nXDU914\r\nXDU214\r\nXDU315\r\nXDU535\r\nXDU305\r\nXDU294\r\nXDU71\r\nXDU534\r\nXDU133\r\nXDU204\r\nXDU991\r\nXDU475\r\nXDU870\r\nXDU793\r\nXDU585\r\nXDU431\r\nXDU786\r\nXDU944\r\nXDU102\r\nXDU245\r\nXDU857\r\nXDU427\r\nXDU738\r\nXDU726\r\nXDU568\r\nXDU843\r\nXDU131\r\nXDU298\r\nXDU498\r\nXDU837\r\nXDU243\r\nXDU588\r\nXDU832\r\nXDU526\r\nXDU710\r\nXDU177\r\nXDU310\r\nXDU165\r\nXDU603\r\nXDU99\r\nXDU22\r\nXDU615\r\nXDU566\r\nXDU286\r\nXDU703\r\nXDU361\r\nXDU795\r\nXDU277\r\nXDU508\r\nXDU521\r\nXDU297\r\nXDU317\r\nXDU861\r\nXDU271\r\nXDU318\r\nXDU572\r\nXDU247\r\nXDU202\r\nXDU946\r\nXDU466\r\nXDU732\r\nXDU226\r\nXDU583\r\nXDU120\r\nXDU926\r\nXDU401\r\nXDU687\r\nXDU984\r\nXDU819\r\nXDU664\r\nXDU400\r\nXDU446\r\nXDU453\r\nXDU730\r\nXDU362\r\nXDU62\r\nXDU175\r\nXDU809\r\nXDU430\r\nXDU124\r\nXDU346\r\nXDU776\r\nXDU605\r\nXDU187\r\nXDU337\r\nXDU211\r\nXDU684\r\nXDU179\r\nXDU981\r\nXDU784\r\nXDU701\r\nXDU358\r\nXDU768\r\nXDU911\r\nXDU235\r\nXDU215\r\nXDU145\r\nXDU609\r\nXDU281\r\nXDU432\r\nXDU196\r\nXDU499\r\nXDU250\r\nXDU304\r\nXDU600\r\nXDU309\r\nXDU171\r\nXDU787\r\nXDU595\r\nXDU808\r\nXDU7\r\nXDU846\r\n
XDU428\r\nXDU287\r\nXDU729\r\nXDU213\r\nXDU828\r\nXDU941\r\nXDU395\r\nXDU756\r\nXDU897\r\nXDU239\r\nXDU610\r\nXDU251\r\nXDU373\r\nXDU533\r\nXDU95\r\nXDU57\r\nXDU945\r\nXDU222\r\nXDU168\r\nXDU137\r\nXDU961\r\nXDU906\r\nXDU937\r\nXDU0\r\nXDU770\r\nXDU268\r\nXDU963\r\nXDU113\r\nXDU771\r\nXDU763\r\nXDU339\r\nXDU52\r\nXDU737\r\nXDU755\r\nXDU159\r\nXDU626\r\nXDU16\r\nXDU118\r\nXDU77\r\nXDU574\r\nXDU402\r\nXDU407\r\nXDU88\r\nXDU778\r\nXDU391\r\nXDU15\r\nXDU940\r\nXDU886\r\nXDU359\r\nXDU424\r\nXDU721\r\nXDU399\r\nXDU345\r\nXDU157\r\nXDU107\r\nXDU873\r\nXDU865\r\nXDU39\r\nXDU893\r\nXDU976\r\nXDU695\r\nXDU367\r\nXDU700\r\nXDU422\r\nXDU936\r\nXDU123\r\nXDU503\r\nXDU366\r\nXDU101\r\nXDU631\r\nXDU276\r\nXDU549\r\nXDU212\r\nXDU197\r\nXDU640\r\nXDU200\r\nXDU37\r\nXDU469\r\nXDU522\r\nXDU575\r\nXDU256\r\nXDU409\r\nXDU152\r\nXDU224\r\nXDU86\r\nXDU630\r\nXDU980\r\nXDU813\r\nXDU70\r\nXDU249\r\nXDU396\r\nXDU11\r\nXDU327\r\nXDU269\r\nXDU284\r\nXDU757\r\nXDU348\r\nXDU554\r\nXDU127\r\nXDU267\r\nXDU751\r\nXDU181\r\nXDU468\r\nXDU274\r\nXDU55\r\nXDU510\r\nXDU170\r\nXDU744\r\n"
  },
  {
    "path": "metrics.py",
    "content": "import numpy as np\r\nimport torch\r\nfrom skimage import measure\r\n\r\n\r\nclass ROCMetric():\r\n    \"\"\"Computes ROC and precision-recall statistics over a range of score thresholds\r\n    \"\"\"\r\n\r\n    def __init__(self, nclass, bins):\r\n        # bins: number of discrete thresholds sampled along the ROC curve\r\n        # nclass: number of classes (infrared small target detection has a single class)\r\n        super(ROCMetric, self).__init__()\r\n        self.nclass = nclass\r\n        self.bins = bins\r\n        self.tp_arr = np.zeros(self.bins + 1)\r\n        self.pos_arr = np.zeros(self.bins + 1)\r\n        self.fp_arr = np.zeros(self.bins + 1)\r\n        self.neg_arr = np.zeros(self.bins + 1)\r\n        self.class_pos = np.zeros(self.bins + 1)\r\n        # self.reset()\r\n\r\n    # Accumulate statistics between the network outputs and the labels at every threshold\r\n    def update(self, preds, labels):\r\n        for iBin in range(self.bins + 1):\r\n            # score_thresh = (iBin + 0.0) / self.bins\r\n            score_thresh = -30 + iBin * (255 / self.bins)\r\n            # print(iBin, \"-th, score_thresh: \", score_thresh)\r\n            i_tp, i_pos, i_fp, i_neg, i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass, score_thresh)\r\n            self.tp_arr[iBin] += i_tp\r\n            self.pos_arr[iBin] += i_pos\r\n            self.fp_arr[iBin] += i_fp\r\n            self.neg_arr[iBin] += i_neg\r\n            self.class_pos[iBin] += i_class_pos\r\n\r\n    def get(self):\r\n        tp_rates = self.tp_arr / (self.pos_arr + 0.001)  # tp_rates = recall = TP/(TP+FN)\r\n        fp_rates = self.fp_arr / (self.neg_arr + 0.001)  # fp_rates = FP/(FP+TN)\r\n        FP = self.fp_arr / (self.neg_arr + self.pos_arr)\r\n        recall = self.tp_arr / (self.pos_arr + 0.001)  # recall = TP/(TP+FN)\r\n        precision = self.tp_arr / (self.class_pos + 0.001)  # precision = TP/(TP+FP)\r\n\r\n        return tp_rates, fp_rates, recall, precision, FP\r\n\r\n    def reset(self):\r\n        self.tp_arr = np.zeros(self.bins + 1)\r\n        self.pos_arr = np.zeros(self.bins + 1)\r\n        self.fp_arr = np.zeros(self.bins + 1)\r\n
        self.neg_arr = np.zeros(self.bins + 1)\r\n        self.class_pos = np.zeros(self.bins + 1)\r\n\r\n\r\nclass mIoU():\r\n\r\n    def __init__(self):\r\n        super(mIoU, self).__init__()\r\n        self.reset()\r\n\r\n    def update(self, preds, labels):\r\n        correct, labeled = batch_pix_accuracy(preds, labels)\r\n        inter, union = batch_intersection_union(preds, labels)\r\n        self.total_correct += correct  # number of correctly predicted target pixels\r\n        self.total_label += labeled  # number of GT target pixels\r\n        self.total_inter += inter  # intersection\r\n        self.total_union += union  # union\r\n\r\n    def get(self):\r\n        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)\r\n        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)\r\n        mIoU = IoU.mean()\r\n        return float(pixAcc), mIoU\r\n\r\n    def reset(self):\r\n        self.total_inter = 0\r\n        self.total_union = 0\r\n        self.total_correct = 0\r\n        self.total_label = 0\r\n\r\n\r\nclass PD_FA():\r\n    def __init__(self):\r\n        super(PD_FA, self).__init__()\r\n        self.image_area_total = []\r\n        self.image_area_match = []\r\n        self.dismatch_pixel = 0\r\n        self.all_pixel = 0\r\n        self.PD = 0\r\n        self.target = 0\r\n\r\n    def update(self, preds, labels, size):\r\n        predits = np.array((preds).cpu()).astype('int64')\r\n        labelss = np.array((labels).cpu()).astype('int64')\r\n\r\n        image = measure.label(predits, connectivity=2)\r\n        coord_image = measure.regionprops(image)\r\n        label = measure.label(labelss, connectivity=2)\r\n        coord_label = measure.regionprops(label)\r\n\r\n        self.target += len(coord_label)   # total number of targets = number of connected components in the GT\r\n        self.image_area_total = []   # areas of all predicted regions in the image\r\n        self.image_area_match = []\r\n        self.distance_match = []\r\n        self.dismatch = []\r\n\r\n        for K in range(len(coord_image)):\r\n            area_image = np.array(coord_image[K].area)\r\n
            self.image_area_total.append(area_image)\r\n\r\n        for i in range(len(coord_label)):   # match predicted and GT connected components by centroid distance\r\n            centroid_label = np.array(list(coord_label[i].centroid))\r\n            for m in range(len(coord_image)):\r\n                centroid_image = np.array(list(coord_image[m].centroid))\r\n                distance = np.linalg.norm(centroid_image - centroid_label)\r\n                area_image = np.array(coord_image[m].area)\r\n                if distance < 3:\r\n                    self.distance_match.append(distance)\r\n                    self.image_area_match.append(area_image)\r\n\r\n                    del coord_image[m]   # remove the matched prediction so it cannot be matched again\r\n                    break\r\n\r\n        self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match] # predicted regions that match no GT target\r\n        self.dismatch_pixel += np.sum(self.dismatch)  # Fa: false-alarm pixels\r\n        self.all_pixel += size[0] * size[1]\r\n        self.PD += len(self.distance_match)  # a target counts as detected if a predicted centroid lies within 3 pixels, so PD is the number of matched targets\r\n\r\n    def get(self):\r\n        Final_FA = self.dismatch_pixel / self.all_pixel\r\n        Final_PD = self.PD / self.target\r\n        return Final_PD, float(Final_FA.cpu().detach().numpy())\r\n\r\n    def reset(self):\r\n        self.image_area_total = []\r\n        self.image_area_match = []\r\n        self.dismatch_pixel = 0\r\n        self.all_pixel = 0\r\n        self.PD = 0\r\n        self.target = 0\r\n\r\n\r\ndef batch_pix_accuracy(output, target):\r\n    if len(target.shape) == 3:\r\n        target = target.float().unsqueeze(1)  # (B, H, W) -> (B, 1, H, W)\r\n    elif len(target.shape) == 4:\r\n        target = target.float()\r\n    else:\r\n        raise ValueError(\"Unknown target dimension\")\r\n\r\n    assert output.shape == target.shape, \"Predict and Label Shape Don't Match\"\r\n    predict = (output > 0).float()  # convert output from True/False to 1/0\r\n    pixel_labeled = (target > 0).float().sum()  # number of 1s (target pixels) in the GT\r\n    pixel_correct = (((predict == target).float()) * ((target > 0)).float()).sum()  # number of correctly predicted target pixels\r\n    assert pixel_correct <= 
pixel_labeled, \"Correct area should not exceed Labeled area\"\r\n    return pixel_correct, pixel_labeled\r\n\r\n\r\ndef batch_intersection_union(output, target):\r\n    mini = 1\r\n    maxi = 1\r\n    nbins = 1\r\n    predict = (output > 0).float()\r\n    if len(target.shape) == 3:\r\n        target = target.float().unsqueeze(1)  # (B, H, W) -> (B, 1, H, W)\r\n    elif len(target.shape) == 4:\r\n        target = target.float()\r\n    else:\r\n        raise ValueError(\"Unknown target dimension\")\r\n    intersection = predict * ((predict == target).float())\r\n\r\n    area_inter, _ = np.histogram(intersection.cpu(), bins=nbins, range=(mini, maxi))\r\n    area_pred, _ = np.histogram(predict.cpu(), bins=nbins, range=(mini, maxi))\r\n    area_lab, _ = np.histogram(target.cpu(), bins=nbins, range=(mini, maxi))\r\n    area_union = area_pred + area_lab - area_inter\r\n\r\n    assert (area_inter <= area_union).all(), \\\r\n        \"Error: Intersection area should not exceed Union area\"\r\n    return area_inter, area_union\r\n"
  },
  {
    "path": "model/Config.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author  : Shuai Yuan\n# @File    : Config.py\n# @Software: PyCharm\n# coding=utf-8\nimport os\nimport torch\nimport time\nimport ml_collections\n\n\n\n\n##########################################################################\n# SCTrans configs\n##########################################################################\ndef get_SCTrans_config():\n    config = ml_collections.ConfigDict()\n    config.transformer = ml_collections.ConfigDict()\n    config.KV_size = 480  # KV_size = Q1 + Q2 + Q3 + Q4\n    config.transformer.num_heads = 4\n    config.transformer.num_layers = 4\n    config.patch_sizes = [16, 8, 4, 2]\n    config.base_channel = 32  # base channel of U-Net\n    config.n_classes = 1\n\n    # ********** unused **********\n    config.transformer.embeddings_dropout_rate = 0.1\n    config.transformer.attention_dropout_rate = 0.1\n    config.transformer.dropout_rate = 0\n    return config\n"
  },
  {
    "path": "model/SCTransNet.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author  : Shuai Yuan\n# @File    : SCTransNet.py\n# @Software: PyCharm\n# coding=utf-8\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nimport copy\nimport math\nfrom torch.nn import Dropout, Softmax, Conv2d, LayerNorm\nfrom torch.nn.modules.utils import _pair\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport ml_collections\nfrom einops import rearrange\nimport numbers\nfrom thop import profile\n\ndef get_CTranS_config():\n    config = ml_collections.ConfigDict()\n    config.transformer = ml_collections.ConfigDict()\n    config.KV_size = 480  # KV_size = Q1 + Q2 + Q3 + Q4\n    config.transformer.num_heads = 4\n    config.transformer.num_layers = 4\n    config.patch_sizes = [16, 8, 4, 2]\n    config.base_channel = 32  # base channel of U-Net\n    config.n_classes = 1\n\n    # ********** unused **********\n    config.transformer.embeddings_dropout_rate = 0.1\n    config.transformer.attention_dropout_rate = 0.1\n    config.transformer.dropout_rate = 0\n    return config\n\n\nclass Channel_Embeddings(nn.Module):\n    def __init__(self, config, patchsize, img_size, in_channels):\n        super().__init__()\n        img_size = _pair(img_size)\n        patch_size = _pair(patchsize)\n        n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])  # number of non-overlapping patches\n\n        self.patch_embeddings = Conv2d(in_channels=in_channels,\n                                       out_channels=in_channels,\n                                       kernel_size=patch_size,\n                                       stride=patch_size)\n        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels))\n        self.dropout = Dropout(config.transformer[\"embeddings_dropout_rate\"])\n\n    def forward(self, x):\n        if x is None:\n            return None\n        x = 
self.patch_embeddings(x)\n        return x\n\n\nclass Reconstruct(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):\n        super(Reconstruct, self).__init__()\n        if kernel_size == 3:\n            padding = 1\n        else:\n            padding = 0\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)\n        self.norm = nn.BatchNorm2d(out_channels)\n        self.activation = nn.ReLU(inplace=True)\n        self.scale_factor = scale_factor\n\n    # def forward(self, x, h, w):\n    def forward(self, x):\n        if x is None:\n            return None\n\n        x = nn.Upsample(scale_factor=self.scale_factor, mode='bilinear')(x)\n\n        out = self.conv(x)\n        out = self.norm(out)\n        out = self.activation(out)\n        return out\n\n\n# spatial-embedded Single-head Channel-cross Attention (SSCA)\nclass Attention_org(nn.Module):\n    def __init__(self, config, vis, channel_num):\n        super(Attention_org, self).__init__()\n        self.vis = vis\n        self.KV_size = config.KV_size\n        self.channel_num = channel_num\n        self.num_attention_heads = 1\n        self.psi = nn.InstanceNorm2d(self.num_attention_heads)\n        self.softmax = Softmax(dim=3)\n\n        # self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))\n        self.mhead1 = nn.Conv2d(channel_num[0], channel_num[0] * self.num_attention_heads, kernel_size=1, bias=False)\n        self.mhead2 = nn.Conv2d(channel_num[1], channel_num[1] * self.num_attention_heads, kernel_size=1, bias=False)\n        self.mhead3 = nn.Conv2d(channel_num[2], channel_num[2] * self.num_attention_heads, kernel_size=1, bias=False)\n        self.mhead4 = nn.Conv2d(channel_num[3], channel_num[3] * self.num_attention_heads, kernel_size=1, bias=False)\n        self.mheadk = nn.Conv2d(self.KV_size, self.KV_size * self.num_attention_heads, kernel_size=1, bias=False)\n        self.mheadv = nn.Conv2d(self.KV_size, 
self.KV_size * self.num_attention_heads, kernel_size=1, bias=False)\n\n        self.q1 = nn.Conv2d(channel_num[0] * self.num_attention_heads, channel_num[0] * self.num_attention_heads, kernel_size=3, stride=1,\n                            padding=1,\n                            groups=channel_num[0] * self.num_attention_heads // 2, bias=False)\n        self.q2 = nn.Conv2d(channel_num[1] * self.num_attention_heads, channel_num[1] * self.num_attention_heads, kernel_size=3, stride=1,\n                            padding=1,\n                            groups=channel_num[1] * self.num_attention_heads // 2, bias=False)\n        self.q3 = nn.Conv2d(channel_num[2] * self.num_attention_heads, channel_num[2] * self.num_attention_heads, kernel_size=3, stride=1,\n                            padding=1,\n                            groups=channel_num[2] * self.num_attention_heads // 2, bias=False)\n        self.q4 = nn.Conv2d(channel_num[3] * self.num_attention_heads, channel_num[3] * self.num_attention_heads, kernel_size=3, stride=1,\n                            padding=1,\n                            groups=channel_num[3] * self.num_attention_heads // 2, bias=False)\n        self.k = nn.Conv2d(self.KV_size * self.num_attention_heads, self.KV_size * self.num_attention_heads, kernel_size=3, stride=1,\n                           padding=1, groups=self.KV_size * self.num_attention_heads, bias=False)\n        self.v = nn.Conv2d(self.KV_size * self.num_attention_heads, self.KV_size * self.num_attention_heads, kernel_size=3, stride=1,\n                           padding=1, groups=self.KV_size * self.num_attention_heads, bias=False)\n\n        self.project_out1 = nn.Conv2d(channel_num[0], channel_num[0], kernel_size=1, bias=False)\n        self.project_out2 = nn.Conv2d(channel_num[1], channel_num[1], kernel_size=1, bias=False)\n        self.project_out3 = nn.Conv2d(channel_num[2], channel_num[2], kernel_size=1, bias=False)\n        self.project_out4 = nn.Conv2d(channel_num[3], 
channel_num[3], kernel_size=1, bias=False)\n\n\n        # ****************** useless ***************************************\n        self.q1_attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n        self.q1_attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n        self.q1_attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n        self.q1_attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n\n        self.q2_attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n        self.q2_attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n        self.q2_attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n        self.q2_attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n\n        self.q3_attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n        self.q3_attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n        self.q3_attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n        self.q3_attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n\n        self.q4_attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n        self.q4_attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n        self.q4_attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n        self.q4_attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)\n\n    def forward(self, emb1, emb2, emb3, emb4, emb_all):\n        b, c, h, w = emb1.shape\n        q1 = self.q1(self.mhead1(emb1))\n        q2 = self.q2(self.mhead2(emb2))\n        q3 = self.q3(self.mhead3(emb3))\n        q4 = self.q4(self.mhead4(emb4))\n        k = self.k(self.mheadk(emb_all))\n        v = self.v(self.mheadv(emb_all))\n        # k, v = kv.chunk(2, dim=1)\n\n        q1 = rearrange(q1, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)\n        q2 = rearrange(q2, 'b 
(head c) h w -> b head c (h w)', head=self.num_attention_heads)\n        q3 = rearrange(q3, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)\n        q4 = rearrange(q4, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)\n        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)\n        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)\n\n        q1 = torch.nn.functional.normalize(q1, dim=-1)\n        q2 = torch.nn.functional.normalize(q2, dim=-1)\n        q3 = torch.nn.functional.normalize(q3, dim=-1)\n        q4 = torch.nn.functional.normalize(q4, dim=-1)\n        k = torch.nn.functional.normalize(k, dim=-1)\n\n        _, _, c1, _ = q1.shape\n        _, _, c2, _ = q2.shape\n        _, _, c3, _ = q3.shape\n        _, _, c4, _ = q4.shape\n        _, _, c, _ = k.shape\n\n        attn1 = (q1 @ k.transpose(-2, -1)) / math.sqrt(self.KV_size)\n        attn2 = (q2 @ k.transpose(-2, -1)) / math.sqrt(self.KV_size)\n        attn3 = (q3 @ k.transpose(-2, -1)) / math.sqrt(self.KV_size)\n        attn4 = (q4 @ k.transpose(-2, -1)) / math.sqrt(self.KV_size)\n\n        attention_probs1 = self.softmax(self.psi(attn1))\n        attention_probs2 = self.softmax(self.psi(attn2))\n        attention_probs3 = self.softmax(self.psi(attn3))\n        attention_probs4 = self.softmax(self.psi(attn4))\n\n        out1 = (attention_probs1 @ v)\n        out2 = (attention_probs2 @ v)\n        out3 = (attention_probs3 @ v)\n        out4 = (attention_probs4 @ v)\n\n        out_1 = out1.mean(dim=1)\n        out_2 = out2.mean(dim=1)\n        out_3 = out3.mean(dim=1)\n        out_4 = out4.mean(dim=1)\n\n        out_1 = rearrange(out_1, 'b  c (h w) -> b c h w', h=h, w=w)\n        out_2 = rearrange(out_2, 'b  c (h w) -> b c h w', h=h, w=w)\n        out_3 = rearrange(out_3, 'b  c (h w) -> b c h w', h=h, w=w)\n        out_4 = rearrange(out_4, 'b  c (h w) -> b c h w', h=h, w=w)\n\n        O1 = 
self.project_out1(out_1)\n        O2 = self.project_out2(out_2)\n        O3 = self.project_out3(out_3)\n        O4 = self.project_out4(out_4)\n        weights = None\n\n        return O1, O2, O3, O4, weights\n\n\ndef to_3d(x):\n    return rearrange(x, 'b c h w -> b (h w) c')\n\n\ndef to_4d(x, h, w):\n    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)\n\n\nclass BiasFree_LayerNorm(nn.Module):\n    def __init__(self, normalized_shape):\n        super(BiasFree_LayerNorm, self).__init__()\n        if isinstance(normalized_shape, numbers.Integral):\n            normalized_shape = (normalized_shape,)\n        normalized_shape = torch.Size(normalized_shape)\n\n        assert len(normalized_shape) == 1\n\n        self.weight = nn.Parameter(torch.ones(normalized_shape))\n        self.normalized_shape = normalized_shape\n\n    def forward(self, x):\n        sigma = x.var(-1, keepdim=True, unbiased=False)\n        return x / torch.sqrt(sigma + 1e-5) * self.weight\n\n\nclass WithBias_LayerNorm(nn.Module):\n    def __init__(self, normalized_shape):\n        super(WithBias_LayerNorm, self).__init__()\n        if isinstance(normalized_shape, numbers.Integral):\n            normalized_shape = (normalized_shape,)\n        normalized_shape = torch.Size(normalized_shape)\n\n        assert len(normalized_shape) == 1\n\n        self.weight = nn.Parameter(torch.ones(normalized_shape))\n        self.bias = nn.Parameter(torch.zeros(normalized_shape))\n        self.normalized_shape = normalized_shape\n\n    def forward(self, x):\n        mu = x.mean(-1, keepdim=True)\n        sigma = x.var(-1, keepdim=True, unbiased=False)\n        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias\n\n\nclass LayerNorm3d(nn.Module):\n    def __init__(self, dim, LayerNorm_type):\n        super(LayerNorm3d, self).__init__()\n        if LayerNorm_type == 'BiasFree':\n            self.body = BiasFree_LayerNorm(dim)\n        else:\n            self.body = WithBias_LayerNorm(dim)\n\n    
def forward(self, x):\n        h, w = x.shape[-2:]\n        return to_4d(self.body(to_3d(x)), h, w)\n\nclass eca_layer_2d(nn.Module):\n    def __init__(self, channel, k_size=3):\n        super(eca_layer_2d, self).__init__()\n        padding = k_size // 2\n        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)\n        self.conv = nn.Sequential(\n            nn.Conv1d(in_channels=1, out_channels=1, kernel_size=k_size, padding=padding, bias=False),\n            nn.Sigmoid()\n        )\n        self.channel = channel\n        self.k_size = k_size\n\n    def forward(self, x):\n        out = self.avg_pool(x)\n        out = out.view(x.size(0), 1, x.size(1))\n        out = self.conv(out)\n        out = out.view(x.size(0), x.size(1), 1, 1)\n        return out * x\n\n# Complementary Feed-forward Network (CFN)\nclass FeedForward(nn.Module):\n    def __init__(self, dim, ffn_expansion_factor, bias):\n        super(FeedForward, self).__init__()\n\n        hidden_features = int(dim * ffn_expansion_factor)\n\n        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)\n\n        self.dwconv3x3 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features,\n                                   bias=bias)\n        self.dwconv5x5 = nn.Conv2d(hidden_features, hidden_features, kernel_size=5, stride=1, padding=2, groups=hidden_features,\n                                   bias=bias)\n        self.relu3 = nn.ReLU()\n        self.relu5 = nn.ReLU()\n        self.project_out = nn.Conv2d(hidden_features * 2, dim, kernel_size=1, bias=bias)\n        self.eca = eca_layer_2d(dim)\n\n    def forward(self, x):\n        x_3,x_5 = self.project_in(x).chunk(2, dim=1)\n        x1_3 = self.relu3(self.dwconv3x3(x_3))\n        x1_5 = self.relu5(self.dwconv5x5(x_5))\n        x = torch.cat([x1_3, x1_5], dim=1)\n        x = self.project_out(x)\n        x = self.eca(x)\n        return x\n\n\n#  Spatial-channel Cross Transformer Block 
(SCTB)\nclass Block_ViT(nn.Module):\n    def __init__(self, config, vis, channel_num):\n        super(Block_ViT, self).__init__()\n        self.attn_norm1 = LayerNorm3d(channel_num[0], LayerNorm_type='WithBias')\n        self.attn_norm2 = LayerNorm3d(channel_num[1], LayerNorm_type='WithBias')\n        self.attn_norm3 = LayerNorm3d(channel_num[2], LayerNorm_type='WithBias')\n        self.attn_norm4 = LayerNorm3d(channel_num[3], LayerNorm_type='WithBias')\n        self.attn_norm = LayerNorm3d(config.KV_size, LayerNorm_type='WithBias')\n\n        self.channel_attn = Attention_org(config, vis, channel_num)\n\n        self.ffn_norm1 = LayerNorm3d(channel_num[0], LayerNorm_type='WithBias')\n        self.ffn_norm2 = LayerNorm3d(channel_num[1], LayerNorm_type='WithBias')\n        self.ffn_norm3 = LayerNorm3d(channel_num[2], LayerNorm_type='WithBias')\n        self.ffn_norm4 = LayerNorm3d(channel_num[3], LayerNorm_type='WithBias')\n\n        self.ffn1 = FeedForward(channel_num[0], ffn_expansion_factor=2.66, bias=False)\n        self.ffn2 = FeedForward(channel_num[1], ffn_expansion_factor=2.66, bias=False)\n        self.ffn3 = FeedForward(channel_num[2], ffn_expansion_factor=2.66, bias=False)\n        self.ffn4 = FeedForward(channel_num[3], ffn_expansion_factor=2.66, bias=False)\n\n\n    def forward(self, emb1, emb2, emb3, emb4):\n        embcat = []\n        org1 = emb1\n        org2 = emb2\n        org3 = emb3\n        org4 = emb4\n        for i in range(4):\n            var_name = \"emb\" + str(i + 1)\n            tmp_var = locals()[var_name]\n            if tmp_var is not None:\n                embcat.append(tmp_var)\n        emb_all = torch.cat(embcat, dim=1)\n        cx1 = self.attn_norm1(emb1) if emb1 is not None else None\n        cx2 = self.attn_norm2(emb2) if emb2 is not None else None\n        cx3 = self.attn_norm3(emb3) if emb3 is not None else None\n        cx4 = self.attn_norm4(emb4) if emb4 is not None else None\n        emb_all = self.attn_norm(emb_all)  # 1 
196 960\n        cx1, cx2, cx3, cx4, weights = self.channel_attn(cx1, cx2, cx3, cx4, emb_all)\n        cx1 = org1 + cx1 if emb1 is not None else None\n        cx2 = org2 + cx2 if emb2 is not None else None\n        cx3 = org3 + cx3 if emb3 is not None else None\n        cx4 = org4 + cx4 if emb4 is not None else None\n\n        org1 = cx1\n        org2 = cx2\n        org3 = cx3\n        org4 = cx4\n        x1 = self.ffn_norm1(cx1) if emb1 is not None else None\n        x2 = self.ffn_norm2(cx2) if emb2 is not None else None\n        x3 = self.ffn_norm3(cx3) if emb3 is not None else None\n        x4 = self.ffn_norm4(cx4) if emb4 is not None else None\n        x1 = self.ffn1(x1) if emb1 is not None else None\n        x2 = self.ffn2(x2) if emb2 is not None else None\n        x3 = self.ffn3(x3) if emb3 is not None else None\n        x4 = self.ffn4(x4) if emb4 is not None else None\n        x1 = x1 + org1 if emb1 is not None else None\n        x2 = x2 + org2 if emb2 is not None else None\n        x3 = x3 + org3 if emb3 is not None else None\n        x4 = x4 + org4 if emb4 is not None else None\n\n        return x1, x2, x3, x4, weights\n\n\nclass Encoder(nn.Module):\n    def __init__(self, config, vis, channel_num):\n        super(Encoder, self).__init__()\n        self.vis = vis\n        self.layer = nn.ModuleList()\n        self.encoder_norm1 = LayerNorm3d(channel_num[0], LayerNorm_type='WithBias')\n        self.encoder_norm2 = LayerNorm3d(channel_num[1], LayerNorm_type='WithBias')\n        self.encoder_norm3 = LayerNorm3d(channel_num[2], LayerNorm_type='WithBias')\n        self.encoder_norm4 = LayerNorm3d(channel_num[3], LayerNorm_type='WithBias')\n        for _ in range(config.transformer[\"num_layers\"]):\n            layer = Block_ViT(config, vis, channel_num)\n            self.layer.append(copy.deepcopy(layer))\n\n    def forward(self, emb1, emb2, emb3, emb4):\n        attn_weights = []\n        for layer_block in self.layer:\n            emb1, emb2, emb3, emb4, 
weights = layer_block(emb1, emb2, emb3, emb4)\n            if self.vis:\n                attn_weights.append(weights)\n        emb1 = self.encoder_norm1(emb1) if emb1 is not None else None\n        emb2 = self.encoder_norm2(emb2) if emb2 is not None else None\n        emb3 = self.encoder_norm3(emb3) if emb3 is not None else None\n        emb4 = self.encoder_norm4(emb4) if emb4 is not None else None\n        return emb1, emb2, emb3, emb4, attn_weights\n\n\nclass ChannelTransformer(nn.Module):\n    def __init__(self, config, vis, img_size, channel_num=[64, 128, 256, 512], patchSize=[32, 16, 8, 4]):\n        super().__init__()\n\n        self.patchSize_1 = patchSize[0]\n        self.patchSize_2 = patchSize[1]\n        self.patchSize_3 = patchSize[2]\n        self.patchSize_4 = patchSize[3]\n        self.embeddings_1 = Channel_Embeddings(config, self.patchSize_1, img_size=img_size, in_channels=channel_num[0])\n        self.embeddings_2 = Channel_Embeddings(config, self.patchSize_2, img_size=img_size // 2, in_channels=channel_num[1])\n        self.embeddings_3 = Channel_Embeddings(config, self.patchSize_3, img_size=img_size // 4, in_channels=channel_num[2])\n        self.embeddings_4 = Channel_Embeddings(config, self.patchSize_4, img_size=img_size // 8, in_channels=channel_num[3])\n        self.encoder = Encoder(config, vis, channel_num)\n\n        self.reconstruct_1 = Reconstruct(channel_num[0], channel_num[0], kernel_size=1, scale_factor=(self.patchSize_1, self.patchSize_1))\n        self.reconstruct_2 = Reconstruct(channel_num[1], channel_num[1], kernel_size=1, scale_factor=(self.patchSize_2, self.patchSize_2))\n        self.reconstruct_3 = Reconstruct(channel_num[2], channel_num[2], kernel_size=1, scale_factor=(self.patchSize_3, self.patchSize_3))\n        self.reconstruct_4 = Reconstruct(channel_num[3], channel_num[3], kernel_size=1, scale_factor=(self.patchSize_4, self.patchSize_4))\n\n    def forward(self, en1, en2, en3, en4):\n        emb1 = 
self.embeddings_1(en1)\n        emb2 = self.embeddings_2(en2)\n        emb3 = self.embeddings_3(en3)\n        emb4 = self.embeddings_4(en4)\n\n        encoded1, encoded2, encoded3, encoded4, attn_weights = self.encoder(emb1, emb2, emb3, emb4)  # (B, n_patch, hidden)\n\n        x1 = self.reconstruct_1(encoded1) if en1 is not None else None\n        x2 = self.reconstruct_2(encoded2) if en2 is not None else None\n        x3 = self.reconstruct_3(encoded3) if en3 is not None else None\n        x4 = self.reconstruct_4(encoded4) if en4 is not None else None\n\n        x1 = x1 + en1 if en1 is not None else None\n        x2 = x2 + en2 if en2 is not None else None\n        x3 = x3 + en3 if en3 is not None else None\n        x4 = x4 + en4 if en4 is not None else None\n\n        return x1, x2, x3, x4, attn_weights\n\n\ndef get_activation(activation_type):\n    activation_type = activation_type.lower()\n    if hasattr(nn, activation_type):\n        return getattr(nn, activation_type)()\n    else:\n        return nn.ReLU()\n\n\ndef _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):\n    layers = []\n    layers.append(CBN(in_channels, out_channels, activation))\n\n    for _ in range(nb_Conv - 1):\n        layers.append(CBN(out_channels, out_channels, activation))\n    return nn.Sequential(*layers)\n\n\nclass CBN(nn.Module):\n    def __init__(self, in_channels, out_channels, activation='ReLU'):\n        super(CBN, self).__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels,\n                              kernel_size=3, padding=1)\n        self.norm = nn.BatchNorm2d(out_channels)\n        self.activation = get_activation(activation)\n\n    def forward(self, x):\n        out = self.conv(x)\n        out = self.norm(out)\n        return self.activation(out)\n\n\nclass DownBlock(nn.Module):\n    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):\n        super(DownBlock, self).__init__()\n        self.maxpool = nn.MaxPool2d(2)\n    
    self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)\n\n    def forward(self, x):\n        out = self.maxpool(x)\n        return self.nConvs(out)\n\n\nclass Flatten(nn.Module):\n    def forward(self, x):\n        return x.view(x.size(0), -1)\n\n\nclass CCA(nn.Module):\n    def __init__(self, F_g, F_x):\n        super().__init__()\n        self.mlp_x = nn.Sequential(\n            Flatten(),\n            nn.Linear(F_x, F_x))\n        self.mlp_g = nn.Sequential(\n            Flatten(),\n            nn.Linear(F_g, F_x))\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, g, x):\n        avg_pool_x = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n        channel_att_x = self.mlp_x(avg_pool_x)\n        avg_pool_g = F.avg_pool2d(g, (g.size(2), g.size(3)), stride=(g.size(2), g.size(3)))\n        channel_att_g = self.mlp_g(avg_pool_g)\n        channel_att_sum = (channel_att_x + channel_att_g) / 2.0\n        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)\n        x_after_channel = x * scale\n        out = self.relu(x_after_channel)\n        return out\n\n\nclass UpBlock_attention(nn.Module):\n    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):\n        super().__init__()\n        self.up = nn.Upsample(scale_factor=2)\n        self.coatt = CCA(F_g=in_channels // 2, F_x=in_channels // 2)\n        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)\n\n    def forward(self, x, skip_x):\n        up = self.up(x)\n        skip_x_att = self.coatt(g=up, x=skip_x)\n        x = torch.cat([skip_x_att, up], dim=1)  # dim 1 is the channel dimension\n        return self.nConvs(x)\n\n\nclass Res_block(nn.Module):\n    def __init__(self, in_channels, out_channels, stride=1):\n        super(Res_block, self).__init__()\n        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)\n        self.bn1 = 
nn.BatchNorm2d(out_channels)\n        self.relu = nn.LeakyReLU(inplace=True)\n        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)\n        self.bn2 = nn.BatchNorm2d(out_channels)\n        # self.fca = FCA_Layer(out_channels)\n        if stride != 1 or out_channels != in_channels:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),\n                nn.BatchNorm2d(out_channels))\n        else:\n            self.shortcut = None\n\n    def forward(self, x):\n        residual = x\n        if self.shortcut is not None:\n            residual = self.shortcut(x)\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        out += residual\n        out = self.relu(out)\n        return out\n\n\nclass SCTransNet(nn.Module):\n    def __init__(self, config, n_channels=1, n_classes=1, img_size=256, vis=False, mode='train', deepsuper=True):\n        super().__init__()\n        self.vis = vis\n        self.deepsuper = deepsuper\n        print('Deep-Supervision:', deepsuper)\n        self.mode = mode\n        self.n_channels = n_channels\n        self.n_classes = n_classes\n        in_channels = config.base_channel  # basic channel 64\n        block = Res_block\n        self.pool = nn.MaxPool2d(2, 2)\n        self.inc = self._make_layer(block, n_channels, in_channels)\n        self.down_encoder1 = self._make_layer(block, in_channels, in_channels * 2, 1)  # 64 -> 128\n        self.down_encoder2 = self._make_layer(block, in_channels * 2, in_channels * 4, 1)  # 128 -> 256\n        self.down_encoder3 = self._make_layer(block, in_channels * 4, in_channels * 8, 1)  # 256 -> 512\n        self.down_encoder4 = self._make_layer(block, in_channels * 8, in_channels * 8, 1)  # 512 -> 512\n        self.mtc = ChannelTransformer(config, vis, img_size,\n                                      
channel_num=[in_channels, in_channels * 2, in_channels * 4, in_channels * 8],\n                                      patchSize=config.patch_sizes)\n        self.up_decoder4 = UpBlock_attention(in_channels * 16, in_channels * 4, nb_Conv=2)\n        self.up_decoder3 = UpBlock_attention(in_channels * 8, in_channels * 2, nb_Conv=2)\n        self.up_decoder2 = UpBlock_attention(in_channels * 4, in_channels, nb_Conv=2)\n        self.up_decoder1 = UpBlock_attention(in_channels * 2, in_channels, nb_Conv=2)\n        self.outc = nn.Conv2d(in_channels, n_classes, kernel_size=(1, 1), stride=(1, 1))\n\n        if self.deepsuper:\n            self.gt_conv5 = nn.Sequential(nn.Conv2d(in_channels * 8, 1, 1))\n            self.gt_conv4 = nn.Sequential(nn.Conv2d(in_channels * 4, 1, 1))\n            self.gt_conv3 = nn.Sequential(nn.Conv2d(in_channels * 2, 1, 1))\n            self.gt_conv2 = nn.Sequential(nn.Conv2d(in_channels * 1, 1, 1))\n            self.outconv = nn.Conv2d(5 * 1, 1, 1)\n\n    def _make_layer(self, block, input_channels, output_channels, num_blocks=1):\n        layers = []\n        layers.append(block(input_channels, output_channels))\n        for i in range(num_blocks - 1):\n            layers.append(block(output_channels, output_channels))\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x1 = self.inc(x)  # 64 224 224\n        x2 = self.down_encoder1(self.pool(x1))  # 128 112 112\n        x3 = self.down_encoder2(self.pool(x2))  # 256 56  56\n        x4 = self.down_encoder3(self.pool(x3))  # 512 28  28\n        d5 = self.down_encoder4(self.pool(x4))  # 512 14  14\n        #  CCT\n        f1 = x1\n        f2 = x2\n        f3 = x3\n        f4 = x4\n        #  CCT\n        x1, x2, x3, x4, att_weights = self.mtc(x1, x2, x3, x4)\n        x1 = x1 + f1\n        x2 = x2 + f2\n        x3 = x3 + f3\n        x4 = x4 + f4\n        #  Feature fusion\n        d4 = self.up_decoder4(d5, x4)\n        d3 = self.up_decoder3(d4, x3)\n        d2 = 
self.up_decoder2(d3, x2)\n        out = self.outc(self.up_decoder1(d2, x1))\n        # deep supervision\n        if self.deepsuper:\n            gt_5 = self.gt_conv5(d5)\n            gt_4 = self.gt_conv4(d4)\n            gt_3 = self.gt_conv3(d3)\n            gt_2 = self.gt_conv2(d2)\n            # original deep supervision: upsample each side output to the input resolution\n            gt5 = F.interpolate(gt_5, scale_factor=16, mode='bilinear', align_corners=True)\n            gt4 = F.interpolate(gt_4, scale_factor=8, mode='bilinear', align_corners=True)\n            gt3 = F.interpolate(gt_3, scale_factor=4, mode='bilinear', align_corners=True)\n            gt2 = F.interpolate(gt_2, scale_factor=2, mode='bilinear', align_corners=True)\n            d0 = self.outconv(torch.cat((gt2, gt3, gt4, gt5, out), 1))\n\n            if self.mode == 'train':\n                return (torch.sigmoid(gt5), torch.sigmoid(gt4), torch.sigmoid(gt3), torch.sigmoid(gt2), torch.sigmoid(d0), torch.sigmoid(out))\n            else:\n                return torch.sigmoid(out)\n        else:\n            return torch.sigmoid(out)\n\n\nif __name__ == '__main__':\n    config_vit = get_CTranS_config()\n    model = SCTransNet(config_vit, mode='train', deepsuper=True)\n    inputs = torch.rand(1, 1, 256, 256)\n    output = model(inputs)\n    flops, params = profile(model, (inputs,))\n\n    print(\"-\" * 50)\n    print('FLOPs = ' + str(flops / 1000 ** 3) + ' G')\n    print('Params = ' + str(params / 1000 ** 2) + ' M')\n"
  },
  {
    "path": "test.py",
    "content": "import argparse\r\nfrom torch.autograd import Variable\r\nfrom torch.utils.data import DataLoader\r\nfrom tqdm import tqdm\r\nimport threading\r\nfrom dataset import *\r\nimport time\r\nfrom collections import OrderedDict\r\nfrom model.SCTransNet import SCTransNet as SCTransNet\r\n# from loss import *\r\nimport model.Config as config\r\nimport numpy as np\r\nimport torch\r\nfrom skimage import measure\r\n\r\n\r\ndef cal_tp_pos_fp_neg(output, target, nclass, score_thresh):\r\n    predict = (output > score_thresh).float()\r\n    if len(target.shape) == 3:\r\n        # add a dimension so that target has the same size as output\r\n        target = target.unsqueeze(dim=0)\r\n        # target = np.expand_dims(target.float(), axis=1)\r\n        target = target.float()\r\n\r\n    elif len(target.shape) == 4:\r\n        target = target.float()\r\n    else:\r\n        raise ValueError(\"Unknown target dimension\")\r\n    # predict is now 1 wherever the score exceeds the threshold; target is the GT\r\n\r\n    intersection = predict * ((predict == target).float())\r\n\r\n    tp = intersection.sum()  # true positives: target pixels predicted as target\r\n    fp = (predict * ((predict != target).float())).sum()  # false positives: background predicted as target (false-alarm pixels)\r\n    tn = ((1 - predict) * ((predict == target).float())).sum()  # true negatives: background predicted as background\r\n    fn = (((predict != target).float()) * (1 - predict)).sum()  # false negatives: target pixels predicted as background\r\n    pos = tp + fn  # number of positive pixels in the label\r\n    neg = fp + tn  # number of negative pixels in the label\r\n    class_pos = tp + fp  # number of pixels predicted as positive\r\n\r\n    return tp, pos, fp, neg, class_pos\r\n\r\n\r\nclass SamplewiseSigmoidMetric(object):\r\n    \"\"\"Computes pixAcc and mIoU metric scores\r\n    \"\"\"\r\n\r\n    def __init__(self, nclass, score_thresh=0.5):\r\n        self.nclass = nclass\r\n        self.score_thresh = score_thresh\r\n        self.lock = threading.Lock()\r\n        self.reset()\r\n\r\n    def update(self, preds, labels):\r\n        \"\"\"Updates the internal evaluation result.\r\n\r\n        Parameters\r\n        ----------\r\n        labels : 'NDArray' or list of `NDArray`\r\n            The 
labels of the data.\r\n\r\n        preds : 'NDArray' or list of `NDArray`\r\n            Predicted values.\r\n        \"\"\"\r\n\r\n        def evaluate_worker(self, label, pred):\r\n            inter_arr, union_arr = batch_intersection_union_n(\r\n                pred, label, self.nclass, self.score_thresh)\r\n            with self.lock:\r\n                self.total_inter = np.append(self.total_inter, inter_arr)\r\n                self.total_union = np.append(self.total_union, union_arr)\r\n\r\n        if isinstance(preds, torch.Tensor):\r\n            evaluate_worker(self, labels, preds)\r\n        elif isinstance(preds, (list, tuple)):\r\n            threads = [threading.Thread(target=evaluate_worker,\r\n                                        args=(self, label, pred),\r\n                                        )\r\n                       for (label, pred) in zip(labels, preds)]\r\n            for thread in threads:\r\n                thread.start()\r\n            for thread in threads:\r\n                thread.join()\r\n\r\n    def get(self):\r\n        \"\"\"Gets the current evaluation result.\r\n\r\n        Returns\r\n        -------\r\n        metrics : tuple of float\r\n            pixAcc and mIoU\r\n        \"\"\"\r\n        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)\r\n        nIoU = IoU.mean()\r\n        return nIoU\r\n\r\n    def reset(self):\r\n        \"\"\"Resets the internal evaluation result to initial state.\"\"\"\r\n        self.total_inter = np.array([])\r\n        self.total_union = np.array([])\r\n        self.total_correct = np.array([])\r\n        self.total_label = np.array([])\r\n\r\n\r\ndef batch_intersection_union_n(output, target, nclass, score_thresh):\r\n    \"\"\"nIoU\"\"\"\r\n    mini = 1\r\n    maxi = 1  # nclass\r\n    nbins = 1  # nclass\r\n    outputnp = output.detach().cpu().numpy()\r\n    # outputsig = F.sigmoid(output).detach().cpu().numpy()\r\n    # outputsig = nd.sigmoid(output).asnumpy()\r\n    
predict = (outputnp > 0.5).astype('int64')\r\n    # predict = predict.detach().cpu().numpy()\r\n    # predict = (output.asnumpy() > 0).astype('int64') # P\r\n    if len(target.shape) == 3:\r\n        target = np.expand_dims(target.cpu().numpy(), axis=1).astype('int64')  # T\r\n    elif len(target.shape) == 4:\r\n        target = target.cpu().numpy().astype('int64')  # T\r\n    else:\r\n        raise ValueError(\"Unknown target dimension\")\r\n    intersection = predict * (predict == target)  # TP (intersection)\r\n\r\n    num_sample = intersection.shape[0]\r\n    area_inter_arr = np.zeros(num_sample)\r\n    area_pred_arr = np.zeros(num_sample)\r\n    area_lab_arr = np.zeros(num_sample)\r\n    area_union_arr = np.zeros(num_sample)\r\n    for b in range(num_sample):\r\n        # areas of intersection and union\r\n        area_inter, _ = np.histogram(intersection[b], bins=nbins, range=(mini, maxi))\r\n        area_inter_arr[b] = area_inter\r\n\r\n        area_pred, _ = np.histogram(predict[b], bins=nbins, range=(mini, maxi))\r\n        area_pred_arr[b] = area_pred\r\n\r\n        area_lab, _ = np.histogram(target[b], bins=nbins, range=(mini, maxi))\r\n        area_lab_arr[b] = area_lab\r\n\r\n        area_union = area_pred + area_lab - area_inter\r\n        area_union_arr[b] = area_union\r\n\r\n        assert (area_inter <= area_union).all(), \\\r\n            \"Intersection area should be smaller than Union area\"\r\n\r\n    return area_inter_arr, area_union_arr\r\n\r\n\r\nclass ROCMetric05():\r\n    \"\"\"Accumulates per-threshold TP/FP statistics for ROC, precision-recall and F1 scores\r\n    \"\"\"\r\n\r\n    def __init__(self, nclass, bins):\r\n        # bins: number of discrete threshold values sampled along the ROC curve\r\n        # nclass: number of classes (infrared small target detection has only one class)\r\n        super(ROCMetric05, self).__init__()\r\n        self.nclass = nclass\r\n        self.bins = bins\r\n        self.tp_arr = np.zeros(self.bins + 1)\r\n        self.pos_arr = np.zeros(self.bins + 1)\r\n        self.fp_arr = np.zeros(self.bins + 1)\r\n        self.neg_arr = 
np.zeros(self.bins + 1)\r\n        self.class_pos = np.zeros(self.bins + 1)\r\n        # self.reset()\r\n\r\n    # accumulate per-threshold statistics between the network output and the labels\r\n    def update(self, preds, labels):\r\n        for iBin in range(self.bins + 1):\r\n            # score_thresh = (iBin + 0.0) / self.bins\r\n            score_thresh = (0.0 + iBin) / self.bins\r\n            # print(iBin, \"-th, score_thresh: \", score_thresh)\r\n            i_tp, i_pos, i_fp, i_neg, i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass, score_thresh)\r\n            self.tp_arr[iBin] += i_tp\r\n            self.pos_arr[iBin] += i_pos\r\n            self.fp_arr[iBin] += i_fp  # false-alarm pixel count\r\n            self.neg_arr[iBin] += i_neg\r\n            self.class_pos[iBin] += i_class_pos\r\n\r\n    def get(self):\r\n        tp_rates = self.tp_arr / (self.pos_arr + 0.001)  # tp_rates = recall = TP/(TP+FN)\r\n        fp_rates = self.fp_arr / (self.neg_arr + 0.001)  # fp_rates = FP/(FP+TN)\r\n        FP = self.fp_arr / (self.neg_arr + self.pos_arr)\r\n        recall = self.tp_arr / (self.pos_arr + 0.001)  # recall = TP/(TP+FN)\r\n        precision = self.tp_arr / (self.class_pos + 0.001)  # precision = TP/(TP+FP)\r\n        # F1 at threshold index 5 (score_thresh = 0.5 when bins = 10)\r\n        f1_score = (2.0 * recall[5] * precision[5]) / (recall[5] + precision[5] + 0.00001)\r\n\r\n        return tp_rates, fp_rates, recall, precision, FP, f1_score\r\n\r\n    def reset(self):\r\n        self.tp_arr = np.zeros(self.bins + 1)\r\n        self.pos_arr = np.zeros(self.bins + 1)\r\n        self.fp_arr = np.zeros(self.bins + 1)\r\n        self.neg_arr = np.zeros(self.bins + 1)\r\n        self.class_pos = np.zeros(self.bins + 1)\r\n\r\n\r\nclass mIoU():\r\n\r\n    def __init__(self):\r\n        super(mIoU, self).__init__()\r\n        self.reset()\r\n\r\n    def update(self, preds, labels):\r\n        correct, labeled = batch_pix_accuracy(preds, labels)  # labeled: number of target pixels in the GT; correct: number of correctly predicted target pixels\r\n        inter, union = batch_intersection_union(preds, labels)\r\n        self.total_correct += correct\r\n        self.total_label += labeled\r\n    
    self.total_inter += inter\r\n        self.total_union += union\r\n\r\n    def get(self):\r\n        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)\r\n        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)\r\n        mIoU = IoU.mean()\r\n        return float(pixAcc), mIoU\r\n\r\n    def reset(self):\r\n        self.total_inter = 0\r\n        self.total_union = 0\r\n        self.total_correct = 0\r\n        self.total_label = 0\r\n\r\n\r\nclass PDFA():\r\n    def __init__(self, ):\r\n        super(PDFA, self).__init__()\r\n        self.image_area_total = []\r\n        self.image_area_match = []\r\n        self.dismatch_pixel = 0\r\n        self.all_pixel = 0\r\n        self.PD = 0\r\n        self.target = 0\r\n\r\n    def update(self, preds, labels, size):\r\n        predits = np.array((preds).cpu()).astype('int64')\r\n        labelss = np.array((labels).cpu()).astype('int64')\r\n\r\n        image = measure.label(predits, connectivity=2)\r\n        coord_image = measure.regionprops(image)\r\n        label = measure.label(labelss, connectivity=2)\r\n        coord_label = measure.regionprops(label)\r\n\r\n        self.target += len(coord_label)\r\n        self.image_area_total = []\r\n        self.image_area_match = []\r\n        self.distance_match = []\r\n        self.dismatch = []\r\n\r\n        for K in range(len(coord_image)):\r\n            area_image = np.array(coord_image[K].area)\r\n            self.image_area_total.append(area_image)\r\n\r\n        for i in range(len(coord_label)):\r\n            centroid_label = np.array(list(coord_label[i].centroid))\r\n            for m in range(len(coord_image)):\r\n                centroid_image = np.array(list(coord_image[m].centroid))\r\n                distance = np.linalg.norm(centroid_image - centroid_label)\r\n                area_image = np.array(coord_image[m].area)\r\n                if distance < 3:\r\n                    self.distance_match.append(distance)\r\n    
                self.image_area_match.append(area_image)\r\n\r\n                    del coord_image[m]\r\n                    break\r\n\r\n        self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match]\r\n        self.dismatch_pixel += np.sum(self.dismatch)\r\n        self.all_pixel += size[0] * size[1]\r\n        self.PD += len(self.distance_match)\r\n\r\n    def get(self):\r\n        Final_FA = self.dismatch_pixel / self.all_pixel\r\n        Final_PD = self.PD / self.target\r\n        return Final_PD, float(Final_FA.cpu().detach().numpy())\r\n\r\n    def reset(self):\r\n        self.FA = np.zeros([self.bins + 1])\r\n        self.PD = np.zeros([self.bins + 1])\r\n\r\n\r\ndef batch_pix_accuracy(output, target):\r\n    if len(target.shape) == 3:\r\n        target = np.expand_dims(target.float(), axis=1)\r\n    elif len(target.shape) == 4:\r\n        target = target.float()\r\n    else:\r\n        raise ValueError(\"Unknown target dimension\")\r\n\r\n    assert output.shape == target.shape, \"Predict and Label Shape Don't Match\"\r\n    predict = (output > 0).float()\r\n    pixel_labeled = (target > 0).float().sum()\r\n    pixel_correct = (((predict == target).float()) * ((target > 0)).float()).sum()\r\n    assert pixel_correct <= pixel_labeled, \"Correct area should be smaller than Labeled\"\r\n    return pixel_correct, pixel_labeled\r\n\r\n\r\ndef batch_intersection_union(output, target):\r\n    mini = 1\r\n    maxi = 1\r\n    nbins = 1\r\n    predict = (output > 0).float()\r\n    if len(target.shape) == 3:\r\n        target = np.expand_dims(target.float(), axis=1)\r\n    elif len(target.shape) == 4:\r\n        target = target.float()\r\n    else:\r\n        raise ValueError(\"Unknown target dimension\")\r\n    intersection = predict * ((predict == target).float())\r\n\r\n    area_inter, _ = np.histogram(intersection.cpu(), bins=nbins, range=(mini, maxi))\r\n    area_pred, _ = np.histogram(predict.cpu(), bins=nbins, range=(mini, 
maxi))\r\n    area_lab, _ = np.histogram(target.cpu(), bins=nbins, range=(mini, maxi))\r\n    area_union = area_pred + area_lab - area_inter\r\n\r\n    assert (area_inter <= area_union).all(), \\\r\n        \"Error: Intersection area should not exceed Union area\"\r\n    return area_inter, area_union\r\n\r\n\r\nclass PD_FA():\r\n    def __init__(self, ):\r\n        super(PD_FA, self).__init__()\r\n        self.image_area_total = []\r\n        self.image_area_match = []\r\n        self.dismatch_pixel = 0\r\n        self.all_pixel = 0\r\n        self.PD = 0\r\n        self.target = 0\r\n\r\n    def update(self, preds, labels, size):\r\n        predits = np.array((preds).cpu()).astype('int64')\r\n        labelss = np.array((labels).cpu()).astype('int64')\r\n\r\n        image = measure.label(predits, connectivity=2)\r\n        coord_image = measure.regionprops(image)\r\n        label = measure.label(labelss, connectivity=2)\r\n        coord_label = measure.regionprops(label)\r\n\r\n        self.target += len(coord_label)  # total targets = number of GT connected components\r\n        self.image_area_total = []  # areas of all predicted regions in this image\r\n        self.image_area_match = []\r\n        self.distance_match = []\r\n        self.dismatch = []\r\n\r\n        for K in range(len(coord_image)):\r\n            area_image = np.array(coord_image[K].area)\r\n            self.image_area_total.append(area_image)\r\n\r\n        for i in range(len(coord_label)):  # match predicted and GT connected components by centroid distance\r\n            centroid_label = np.array(list(coord_label[i].centroid))\r\n            for m in range(len(coord_image)):\r\n                centroid_image = np.array(list(coord_image[m].centroid))\r\n                distance = np.linalg.norm(centroid_image - centroid_label)\r\n                area_image = np.array(coord_image[m].area)\r\n                if distance < 3:\r\n                    self.distance_match.append(distance)\r\n                    self.image_area_match.append(area_image)\r\n\r\n                    del coord_image[m]  # remove a matched prediction so it cannot match another target\r\n                    break\r\n\r\n        self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match]  # areas of predictions left unmatched (compared by area value)\r\n\r\n        self.dismatch_pixel += np.sum(self.dismatch)  # Fa: accumulate false-alarm pixels\r\n        # print(self.dismatch_pixel)\r\n        self.all_pixel += size[0] * size[1]\r\n        self.PD += len(self.distance_match)  # a centroid distance below 3 px counts as a detection, so PD is the number of matched targets\r\n\r\n    def get(self):\r\n        Final_FA = self.dismatch_pixel / self.all_pixel\r\n        Final_PD = self.PD / self.target\r\n        return Final_PD, float(Final_FA.cpu().detach().numpy())\r\n\r\n    def reset(self):\r\n        # clear the accumulated counts before a new evaluation run\r\n        self.image_area_total = []\r\n        self.image_area_match = []\r\n        self.dismatch_pixel = 0\r\n        self.all_pixel = 0\r\n        self.PD = 0\r\n        self.target = 0\r\n\r\n\r\n\r\n\r\nos.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'\r\nparser = argparse.ArgumentParser(description=\"PyTorch BasicIRSTD test\")\r\nparser.add_argument('--ROC_thr', type=int, default=10, help='num')\r\nparser.add_argument(\"--model_names\", default=['SCTrans'], type=list,\r\n                    help=\"model_name: 'ACM', 'Ours01', 'DNANet', 'ISNet', 'ACMNet', 'Ours01', 'ISTDU-Net', 'U-Net', 'RISTDnet'\")\r\nparser.add_argument(\"--pth_dirs\", default=['SIRST3/SCTransNet_NUAA_NUDT_IRSTD1K.pth.tar'], type=list)\r\nparser.add_argument(\"--dataset_dir\", default=r'D:\\05TGARS\\Upload\\datasets', type=str, help=\"train_dataset_dir\")\r\nparser.add_argument(\"--dataset_names\", default=['NUAA-SIRST', 'NUDT-SIRST', 'IRSTD-1K'], type=list,\r\n                    help=\"dataset_name: 'NUAA-SIRST', 'NUDT-SIRST', 'IRSTD-1K', 'SIRST3', 'NUDT-SIRST-Sea'\")\r\nparser.add_argument(\"--img_norm_cfg\", default=None, type=dict,\r\n                    help=\"specify an img_norm_cfg; default=None (use each dataset's own img_norm_cfg values)\")\r\nparser.add_argument(\"--save_img\", default=False, type=bool, help=\"whether to save the predicted images\")\r\nparser.add_argument(\"--save_img_dir\", type=str, default=r'D:\\SCI\\01_02_SCI\\Result/',\r\n     
               help=\"directory for saved prediction images\")\r\nparser.add_argument(\"--save_log\", type=str, default=r'D:\\05TGARS\\Upload\\log/', help=\"directory holding the .pth checkpoints; test logs are also written here\")\r\nparser.add_argument(\"--threshold\", type=float, default=0.5)\r\n\r\nglobal opt\r\nopt = parser.parse_args()\r\n\r\n\r\ndef test():\r\n    test_set = TestSetLoader(opt.dataset_dir, opt.train_dataset_name, opt.test_dataset_name, opt.img_norm_cfg)\r\n    test_loader = DataLoader(dataset=test_set, num_workers=1, batch_size=1, shuffle=False)\r\n    # ************************* fixed threshold *************************\r\n    # mIoU\r\n    IOU = mIoU()\r\n    # nIoU\r\n    nIoU_metric = SamplewiseSigmoidMetric(nclass=1, score_thresh=0)\r\n\r\n    # Pd / Fa\r\n    eval_05 = PD_FA()\r\n    ROC_05 = ROCMetric05(nclass=1, bins=10)\r\n    config_vit = config.get_SCTrans_config()\r\n\r\n\r\n    # CPU\r\n    net = SCTransNet(config_vit, mode='test', deepsuper=True)\r\n    state_dict = torch.load(opt.pth_dir, map_location='cpu')\r\n    # # CUDA\r\n    # net = SCTransNet(config_vit, mode='test', deepsuper=True).cuda()\r\n    # state_dict = torch.load(opt.pth_dir)\r\n\r\n    new_state_dict = OrderedDict()\r\n    #\r\n    for k, v in state_dict['state_dict'].items():\r\n        name = k[6:]  # strip the 6-character 'model.' prefix added by the Net wrapper in train.py\r\n        new_state_dict[name] = v  # keep each tensor under its stripped key\r\n    net.load_state_dict(new_state_dict)\r\n    net.eval()\r\n    tbar = tqdm(test_loader)\r\n    with torch.no_grad():\r\n        for idx_iter, (img, gt_mask, size, img_dir) in enumerate(tbar):\r\n            # img = Variable(img)\r\n\r\n            # CPU\r\n            pred = net.forward(img)\r\n            pred = pred[:, :, :size[0], :size[1]]\r\n            gt_mask = gt_mask[:, :, :size[0], :size[1]]\r\n\r\n            # # CUDA:\r\n            # pred = net.forward(img).cuda()\r\n            # pred = pred[:, :, :size[0], :size[1]].cuda()\r\n            # gt_mask = gt_mask[:, :, :size[0], 
:size[1]].cuda()\r\n\r\n            # Fixed threshold ##########################################################\r\n            # IoU\r\n            IOU.update((pred > opt.threshold), gt_mask)  # pixel-level\r\n            # nIoU\r\n            nIoU_metric.update(pred, gt_mask)  # pixel-level\r\n            eval_05.update((pred[0, 0, :, :] > opt.threshold).cpu(), gt_mask[0, 0, :, :], size)  # target-level\r\n            ROC_05.update(pred, gt_mask)\r\n            # save img\r\n            if opt.save_img:\r\n                img_save = transforms.ToPILImage()((pred[0, 0, :, :]).cpu())\r\n                if not os.path.exists(opt.save_img_dir + opt.test_dataset_name + '/' + opt.model_name):\r\n                    os.makedirs(opt.save_img_dir + opt.test_dataset_name + '/' + opt.model_name)\r\n                img_save.save(opt.save_img_dir + opt.test_dataset_name + '/' + opt.model_name + '/' + img_dir[0] + '.png')\r\n\r\n        # metrics at the fixed threshold\r\n        # IoU\r\n        pixAcc, mIOU = IOU.get()\r\n        # nIoU\r\n        nIoU = nIoU_metric.get()\r\n        # Pd / Fa\r\n        results2 = eval_05.get()\r\n        # ROC-based metrics\r\n        true_positive_rate, false_positive_rate, recall, precision, FP, F1_score = ROC_05.get()\r\n\r\n        print('pixAcc: %.4f| mIoU: %.4f | nIoU: %.4f | Pd: %.4f| Fa: %.4f |F1: %.4f'\r\n              % (pixAcc * 100, mIOU * 100, nIoU * 100, results2[0] * 100, results2[1] * 1e+6, F1_score * 100))\r\n\r\n\r\n\r\n\r\nif __name__ == '__main__':\r\n    opt.f = open(opt.save_log + 'test_' + (time.ctime()).replace(' ', '_').replace(':', '_') + '.txt', 'w')\r\n    if opt.pth_dirs is None:\r\n        for i in range(len(opt.model_names)):\r\n            opt.model_name = opt.model_names[i]\r\n            print(opt.model_name)\r\n            opt.f.write(opt.model_name + '_400.pth.tar' + '\\n')\r\n            for dataset_name in opt.dataset_names:\r\n                opt.dataset_name = dataset_name\r\n                opt.train_dataset_name = opt.dataset_name\r\n                opt.test_dataset_name = opt.dataset_name\r\n                print(dataset_name)\r\n                opt.f.write(opt.dataset_name + '\\n')\r\n                opt.pth_dir = opt.save_log + opt.dataset_name + '/' + opt.model_name + '_400.pth.tar'\r\n                test()\r\n            print('\\n')\r\n            opt.f.write('\\n')\r\n        opt.f.close()\r\n    else:\r\n        for model_name in opt.model_names:\r\n            for dataset_name in opt.dataset_names:\r\n                for pth_dir in opt.pth_dirs:\r\n                    # if dataset_name in pth_dir and model_name in pth_dir:\r\n                    opt.test_dataset_name = dataset_name\r\n                    opt.model_name = model_name\r\n                    opt.train_dataset_name = pth_dir.split('/')[0]\r\n                    print(pth_dir)\r\n                    opt.f.write(pth_dir + '\\n')\r\n                    print(opt.test_dataset_name)\r\n                    opt.f.write(opt.test_dataset_name + '\\n')\r\n                    opt.pth_dir = opt.save_log + pth_dir\r\n                    test()\r\n                    print('\\n')\r\n                    opt.f.write('\\n')\r\n        opt.f.close()\r\n"
  },
  {
    "path": "train.py",
    "content": "import argparse\r\nimport time\r\nimport os\r\nimport cv2\r\nos.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\r\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\r\nfrom torch.autograd import Variable\r\nfrom torch.utils.data import DataLoader\r\nfrom dataset import *\r\nfrom metrics import *\r\nfrom utils import *\r\nimport model.Config as config\r\nfrom torch.utils.tensorboard import SummaryWriter\r\nfrom model.SCTransNet import SCTransNet as SCTransNet\r\n\r\nparser = argparse.ArgumentParser(description=\"PyTorch BasicIRSTD train\")\r\nparser.add_argument(\"--model_names\", default=['SCTransNet'], type=list, help=\"'ACM', 'ALCNet', 'DNANet', 'ISNet', 'UIUNet', 'RDIAN', 'RISTDnet'\")\r\nparser.add_argument(\"--dataset_names\", default=['SIRST3'], type=list)\r\n# SIRST3 = NUAA-SIRST + NUDT-SIRST + IRSTD-1K\r\nparser.add_argument(\"--optimizer_name\", default='Adam', type=str, help=\"optimizer name: AdamW, Adam, Adagrad, SGD\")\r\nparser.add_argument(\"--epochs\", default=1000, type=int, help=\"number of training epochs\")\r\nparser.add_argument(\"--begin_test\", default=500, type=int)\r\nparser.add_argument(\"--every_test\", default=1, type=int)\r\nparser.add_argument(\"--every_save_pth\", default=1000, type=int)\r\nparser.add_argument(\"--every_print\", default=10, type=int)\r\nparser.add_argument(\"--dataset_dir\", default=r'./datasets')\r\nparser.add_argument(\"--batchSize\", type=int, default=16, help=\"Training batch size\")\r\nparser.add_argument(\"--patchSize\", type=int, default=256, help=\"Training patch size\")\r\nparser.add_argument(\"--save\", default=r'./log', type=str, help=\"Save path of checkpoints\")\r\nparser.add_argument(\"--log_dir\", type=str, default=\"./otherlogs/SCTransNet\", help='path of log files')\r\nparser.add_argument(\"--img_norm_cfg\", default=None, type=dict)\r\nparser.add_argument(\"--threads\", type=int, default=0, help=\"Number of threads for data loader to use\")\r\nparser.add_argument(\"--threshold\", type=float, 
default=0.5, help=\"Threshold for test\")\r\nparser.add_argument(\"--seed\", type=int, default=42, help=\"Random seed\")\r\nparser.add_argument(\"--resume\", default=False, type=list, help=\"Resume from existing checkpoints (default: False)\")\r\n\r\nglobal opt\r\nopt = parser.parse_args()\r\n\r\nseed_pytorch(opt.seed)\r\n\r\nconfig_vit = config.get_SCTrans_config()\r\n\r\n\r\ndef weights_init_kaiming(m):\r\n    classname = m.__class__.__name__\r\n    if classname.find('Conv') != -1:\r\n        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\r\n    elif classname.find('Linear') != -1:\r\n        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\r\n    elif classname.find('BatchNorm') != -1:\r\n        init.normal_(m.weight.data, 1.0, 0.02)\r\n        init.constant_(m.bias.data, 0.0)\r\n\r\n\r\ndef train():\r\n    train_set = TrainSetLoader(dataset_dir=opt.dataset_dir, dataset_name=opt.dataset_name, patch_size=opt.patchSize,\r\n                               img_norm_cfg=opt.img_norm_cfg)\r\n    train_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)\r\n    net = Net(model_name=opt.model_name, mode='train').cuda()\r\n    net.apply(weights_init_kaiming)\r\n    net.train()\r\n\r\n    epoch_state = 0\r\n    total_loss_list = []\r\n    total_loss_epoch = []\r\n\r\n    if not os.path.exists(opt.log_dir):\r\n        os.makedirs(opt.log_dir)\r\n    writer = SummaryWriter(opt.log_dir)\r\n\r\n    if opt.resume:\r\n        # for resume_pth in opt.resume:\r\n        #     if opt.dataset_name in resume_pth and opt.model_name in resume_pth:\r\n        ckpt = torch.load('XX\\\\UCT04_best.pth.tar')\r\n        net.load_state_dict(ckpt['state_dict'])\r\n        epoch_state = ckpt['epoch']\r\n        total_loss_list = ckpt['total_loss']\r\n        # for i in range(len(opt.scheduler_settings['step'])):\r\n        #     opt.scheduler_settings['step'][i] = opt.scheduler_settings['step'][i] - ckpt['epoch']\r\n\r\n 
   ### Default settings of SCTransNet\r\n    if opt.optimizer_name == 'Adam':\r\n        opt.optimizer_settings = {'lr': 0.001}\r\n        opt.scheduler_name = 'CosineAnnealingLR'\r\n        opt.scheduler_settings = {'epochs': opt.epochs, 'eta_min': 1e-5, 'last_epoch': -1}\r\n\r\n    ### Default settings of DNANet\r\n    if opt.optimizer_name == 'Adagrad':\r\n        opt.optimizer_settings = {'lr': 0.05}\r\n        opt.scheduler_name = 'CosineAnnealingLR'\r\n        opt.scheduler_settings = {'epochs': opt.epochs, 'min_lr': 1e-5}\r\n\r\n    ### Default settings of EGEUNet\r\n    if opt.optimizer_name == 'AdamW':\r\n        opt.optimizer_settings = {'lr': 0.001, 'betas': (0.9, 0.999), \"eps\": 1e-8, \"weight_decay\": 1e-2,\r\n                                  \"amsgrad\": False}\r\n        opt.scheduler_name = 'CosineAnnealingLR'\r\n        opt.scheduler_settings = {'epochs': opt.epochs, 'T_max': 50, 'eta_min': 1e-5, 'last_epoch': -1}\r\n\r\n    opt.nEpochs = opt.scheduler_settings['epochs']\r\n\r\n    optimizer, scheduler = get_optimizer(net, opt.optimizer_name, opt.scheduler_name, opt.optimizer_settings,\r\n                                         opt.scheduler_settings)\r\n\r\n\r\n    for idx_epoch in range(epoch_state, opt.nEpochs):\r\n        net.train()\r\n        results1 = (0, 0)\r\n        results2 = (0, 0)\r\n        for idx_iter, (img, gt_mask) in enumerate(train_loader):\r\n            img, gt_mask = Variable(img).cuda(), Variable(gt_mask).cuda()\r\n            if img.shape[0] == 1:\r\n                continue\r\n            preds = net.forward(img)\r\n            loss = net.loss(preds, gt_mask)\r\n            total_loss_epoch.append(loss.detach().cpu())\r\n            optimizer.zero_grad()\r\n            loss.backward()\r\n            optimizer.step()\r\n\r\n        scheduler.step()\r\n\r\n        if (idx_epoch + 1) % opt.every_print == 0:\r\n            total_loss_list.append(float(np.array(total_loss_epoch).mean()))\r\n            
print(time.ctime()[4:-5] + ' Epoch---%d, total_loss---%f, lr---%f,'\r\n                  % (idx_epoch + 1, total_loss_list[-1], scheduler.get_last_lr()[0]))\r\n            opt.f.write(time.ctime()[4:-5] + ' Epoch---%d, total_loss---%f,\\n'\r\n                        % (idx_epoch + 1, total_loss_list[-1]))\r\n            total_loss_epoch = []\r\n            # Log the scalar values\r\n            writer.add_scalar('loss', total_loss_list[-1], idx_epoch + 1)\r\n            writer.add_scalar('lr', scheduler.get_last_lr()[0], idx_epoch + 1)\r\n\r\n        # 500\r\n        if (idx_epoch + 1) >= opt.begin_test and (idx_epoch + 1) % opt.every_test == 0:\r\n            test_set = TestSetLoader(opt.dataset_dir, opt.dataset_name, opt.dataset_name, img_norm_cfg=opt.img_norm_cfg)\r\n            test_loader = DataLoader(dataset=test_set, num_workers=1, batch_size=1, shuffle=False)\r\n            net.eval()\r\n            with torch.no_grad():\r\n                eval_mIoU = mIoU()\r\n                eval_PD_FA = PD_FA()\r\n                test_loss = []\r\n                for idx_iter, (img, gt_mask, size, _) in enumerate(test_loader):\r\n                    img = Variable(img).cuda()\r\n                    pred = net.forward(img)\r\n                    if isinstance(pred, tuple):\r\n                        pred = pred[-1]\r\n                    elif isinstance(pred, list):\r\n                        pred = pred[-1]\r\n                    else:\r\n                        pred = pred\r\n                    pred = pred[:, :, :size[0], :size[1]]\r\n                    gt_mask = gt_mask[:, :, :size[0], :size[1]]\r\n                    # if pred.size() != gt_mask.size():\r\n                    #     print('1111')\r\n                    loss = net.loss(pred, gt_mask.cuda())\r\n                    test_loss.append(loss.detach().cpu())\r\n                    eval_mIoU.update((pred > opt.threshold).cpu(), gt_mask.cpu())\r\n                    eval_PD_FA.update((pred[0, 0, :, :] > 
opt.threshold).cpu(), gt_mask[0, 0, :, :], size)\r\n                test_loss.append(float(np.array(test_loss).mean()))\r\n                results1 = eval_mIoU.get()\r\n                results2 = eval_PD_FA.get()\r\n                writer.add_scalar('mIOU', results1[-1], idx_epoch + 1)\r\n                writer.add_scalar('testloss', test_loss[-1], idx_epoch + 1)\r\n\r\n\r\n        if (idx_epoch + 1) % opt.every_save_pth == 0:\r\n            save_pth = opt.save + '/' + opt.dataset_name + '/' + opt.model_name + '_' + str(idx_epoch + 1) + '.pth.tar'\r\n            save_checkpoint({\r\n                'epoch': idx_epoch + 1,\r\n                'state_dict': net.state_dict(),\r\n                'total_loss': total_loss_list,\r\n            }, save_pth)\r\n            test(save_pth)\r\n\r\n        if idx_epoch == 0:\r\n            best_mIOU = results1\r\n            best_Pd = results2\r\n\r\n        if results1[1] > best_mIOU[1]:\r\n            best_mIOU = results1\r\n            best_Pd = results2\r\n            print('------save the best model epoch', opt.model_name,'_%d ------' % (idx_epoch + 1))\r\n            opt.f.write(\"the best model epoch \\t\" + str(idx_epoch + 1) + '\\n')\r\n            print(\"pixAcc, mIoU:\\t\" + str(best_mIOU))\r\n            print(\"testloss:\\t\" + str(test_loss[-1]))\r\n            print(\"PD, FA:\\t\" + str(best_Pd))\r\n\r\n            opt.f.write(\"pixAcc, mIoU:\\t\" + str(best_mIOU) + '\\n')\r\n            opt.f.write(\"PD, FA:\\t\" + str(best_Pd) + '\\n')\r\n            save_pth = opt.save + '/' + opt.dataset_name + '/' + opt.model_name + '_' + str(idx_epoch + 1) + '_' + 'best' + '.pth.tar'\r\n            save_checkpoint({\r\n                'epoch': idx_epoch + 1,\r\n                'state_dict': net.state_dict(),\r\n                'total_loss': total_loss_list,\r\n            }, save_pth)\r\n\r\n        # last epoch\r\n        if (idx_epoch + 1) == opt.nEpochs and (idx_epoch + 1) % opt.every_save_pth != 0:\r\n            
save_pth = opt.save + '/' + opt.dataset_name + '/' + opt.model_name + '_' + str(idx_epoch + 1) + '.pth.tar'\r\n            save_checkpoint({\r\n                'epoch': idx_epoch + 1,\r\n                'state_dict': net.state_dict(),\r\n                'total_loss': total_loss_list,\r\n            }, save_pth)\r\n            test(save_pth)\r\n\r\n\r\ndef test(save_pth):\r\n    test_set = TestSetLoader(opt.dataset_dir, opt.dataset_name, opt.dataset_name, img_norm_cfg=opt.img_norm_cfg)\r\n    test_loader = DataLoader(dataset=test_set, num_workers=1, batch_size=1, shuffle=False)\r\n\r\n    net = Net(model_name=opt.model_name, mode='test').cuda()\r\n    ckpt = torch.load(save_pth)\r\n    net.load_state_dict(ckpt['state_dict'])\r\n    net.eval()\r\n    with torch.no_grad():\r\n        eval_mIoU = mIoU()\r\n        eval_PD_FA = PD_FA()\r\n        test_loss_a = []\r\n        for idx_iter, (img, gt_mask, size, _) in enumerate(test_loader):\r\n            img = Variable(img).cuda()\r\n            pred = net.forward(img)\r\n            if pred.size() != gt_mask.size():\r\n                print('1111')\r\n            pred = pred[:, :, :size[0], :size[1]]\r\n            gt_mask = gt_mask[:, :, :size[0], :size[1]]\r\n            loss = net.loss(pred, gt_mask.cuda())\r\n            test_loss_a.append(loss.detach().cpu())\r\n            eval_mIoU.update((pred > opt.threshold).cpu(), gt_mask.cpu())\r\n            eval_PD_FA.update((pred[0, 0, :, :] > opt.threshold).cpu(), gt_mask[0, 0, :, :], size)\r\n\r\n        test_loss_a.append(float(np.array(test_loss_a).mean()))\r\n        results1 = eval_mIoU.get()\r\n        results2 = eval_PD_FA.get()\r\n\r\n        print('== == == == == == == ', opt.model_name, ' == == == == == == ==')\r\n        print(\"pixAcc, mIoU:\\t\" + str(results1))\r\n        print(\"testloss:\\t\" + str(test_loss_a[-1]))\r\n        print(\"PD, FA:\\t\" + str(results2))\r\n        opt.f.write(\"pixAcc, mIoU:\\t\" + str(results1) + '\\n')\r\n        
opt.f.write(\"PD, FA:\\t\" + str(results2) + '\\n')\r\n\r\n\r\ndef save_checkpoint(state, save_path):\r\n    if not os.path.exists(os.path.dirname(save_path)):\r\n        os.makedirs(os.path.dirname(save_path))\r\n    torch.save(state, save_path)\r\n    return save_path\r\n\r\n\r\nclass Net(nn.Module):\r\n    def __init__(self, model_name, mode):\r\n        super(Net, self).__init__()\r\n        self.model_name = model_name\r\n        # ************************************************loss*************************************************#\r\n        self.cal_loss = nn.BCELoss(reduction='mean')\r\n        if model_name == 'SCTransNet':\r\n            if mode == 'train':\r\n                self.model = SCTransNet(config_vit, mode='train', deepsuper=True)\r\n            else:\r\n                self.model = SCTransNet(config_vit, mode='test', deepsuper=True)\r\n\r\n    def forward(self, img):\r\n        return self.model(img)\r\n\r\n    def loss(self, preds, gt_masks):\r\n        if isinstance(preds, list):\r\n            loss_total = 0\r\n            for i in range(len(preds)):\r\n                pred = preds[i]\r\n                gt_mask = gt_masks[i]\r\n                loss = self.cal_loss(pred, gt_mask)\r\n                loss_total = loss_total + loss\r\n            return loss_total / len(preds)\r\n\r\n        elif isinstance(preds, tuple):\r\n            # deep supervision: sum the BCE losses of all six outputs against the same mask\r\n            a = []\r\n            for i in range(len(preds)):\r\n                pred = preds[i]\r\n                loss = self.cal_loss(pred, gt_masks)\r\n                a.append(loss)\r\n            loss_total = a[0] + a[1] + a[2] + a[3] + a[4] + a[5]\r\n            return loss_total\r\n\r\n        else:\r\n            loss = self.cal_loss(preds, gt_masks)\r\n            return loss\r\n\r\n\r\nif __name__ == '__main__':\r\n    for dataset_name in opt.dataset_names:\r\n        opt.dataset_name = dataset_name\r\n        for model_name in opt.model_names:\r\n            opt.model_name = model_name\r\n            if not os.path.exists(opt.save):\r\n                os.makedirs(opt.save)\r\n            opt.f = open(opt.save + '/' + opt.dataset_name + '_' + opt.model_name + '_' + (time.ctime()).replace(' ', '_').replace(':', '_') + '.txt', 'w')\r\n            print(opt.dataset_name + '\\t' + opt.model_name)\r\n            train()\r\n            print('\\n')\r\n            opt.f.close()\r\n"
  },
  {
    "path": "utils.py",
    "content": "import torch\r\nimport numpy as np\r\nfrom PIL import Image\r\nfrom torchvision import transforms\r\nfrom torch.utils.data.dataset import Dataset\r\nimport random\r\nimport matplotlib.pyplot as plt\r\nimport cv2\r\nimport os\r\nimport math\r\nimport torch.nn as nn\r\nfrom skimage import measure\r\nfrom warmup_scheduler import GradualWarmupScheduler\r\nimport torch.nn.functional as F\r\nfrom torch.nn import init\r\n\r\nos.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'\r\n\r\n\r\ndef seed_pytorch(seed=42):\r\n    random.seed(seed)\r\n    os.environ['PYTHONHASHSEED'] = str(seed)\r\n    np.random.seed(seed)\r\n    torch.manual_seed(seed)\r\n    torch.cuda.manual_seed(seed)\r\n    torch.cuda.manual_seed_all(seed)\r\n\r\n\r\ndef weights_init_xavier(m):\r\n    classname = m.__class__.__name__\r\n    if classname.find('Conv2d') != -1 and classname.find('SplAtConv2d') == -1:\r\n        init.xavier_normal_(m.weight.data)\r\n\r\n\r\n# def weights_init_xavier(m):\r\n#     classname = m.__class__.__name__\r\n#     if classname.find('Conv2d') != -1:\r\n#         # init.kaiming_normal_(m.weight.data,a=0, mode='fan_in', nonlinearity='leaky_relu')\r\n#         init.xavier_normal(m.weight.data)\r\n\r\n\r\ndef weights_init_kaiming(m):\r\n    classname = m.__class__.__name__\r\n    if classname.find('Conv') != -1:\r\n        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\r\n    elif classname.find('Linear') != -1:\r\n        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\r\n    elif classname.find('BatchNorm') != -1:\r\n        init.normal_(m.weight.data, 1.0, 0.02)\r\n        init.constant_(m.bias.data, 0.0)\r\n\r\n\r\nclass Get_gradient_nopadding(nn.Module):\r\n    def __init__(self):\r\n        super(Get_gradient_nopadding, self).__init__()\r\n        kernel_v = [[0, -1, 0],\r\n                    [0, 0, 0],\r\n                    [0, 1, 0]]\r\n        kernel_h = [[0, 0, 0],\r\n                    [-1, 0, 1],\r\n         
           [0, 0, 0]]\r\n        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)\r\n        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)\r\n        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False).cuda()\r\n        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False).cuda()\r\n\r\n    def forward(self, x):\r\n        x0 = x[:, 0]\r\n        x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=1)\r\n        x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=1)\r\n\r\n        x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)\r\n\r\n        return x0\r\n\r\n\r\ndef random_crop(img, mask, patch_size, pos_prob=None):\r\n    h, w = img.shape\r\n    if min(h, w) < patch_size:\r\n        img = np.pad(img, ((0, max(h, patch_size) - h), (0, max(w, patch_size) - w)),\r\n                     mode='constant')  # zero-pad any side shorter than patch_size up to patch_size\r\n        mask = np.pad(mask, ((0, max(h, patch_size) - h), (0, max(w, patch_size) - w)),\r\n                      mode='constant')  # apply the same padding to the mask\r\n        h, w = img.shape\r\n\r\n    while True:\r\n        h_start = random.randint(0, h - patch_size)\r\n        h_end = h_start + patch_size\r\n        w_start = random.randint(0, w - patch_size)\r\n        w_end = w_start + patch_size\r\n\r\n        img_patch = img[h_start:h_end, w_start:w_end]\r\n        mask_patch = mask[h_start:h_end, w_start:w_end]\r\n\r\n        # with probability pos_prob, keep resampling until the patch contains at least one target pixel\r\n        if pos_prob is None or random.random() > pos_prob:\r\n            break\r\n        elif mask_patch.sum() > 0:\r\n            break\r\n\r\n    return img_patch, mask_patch\r\n\r\n\r\ndef Normalized(img, img_norm_cfg):\r\n    return (img - img_norm_cfg['mean']) / img_norm_cfg['std']\r\n\r\n\r\ndef Denormalization(img, img_norm_cfg):\r\n    return img * img_norm_cfg['std'] + img_norm_cfg['mean']\r\n\r\n\r\ndef get_img_norm_cfg(dataset_name, dataset_dir):\r\n    if dataset_name == 'NUAA-SIRST':\r\n        img_norm_cfg = dict(mean=101.06385040283203, 
std=34.619606018066406)\r\n    elif dataset_name == 'NUDT-SIRST':\r\n        img_norm_cfg = dict(mean=107.80905151367188, std=33.02274703979492)\r\n    elif dataset_name == 'IRSTD-1K':\r\n        img_norm_cfg = dict(mean=87.4661865234375, std=39.71953201293945)\r\n    elif dataset_name == 'SIRST2':\r\n        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)\r\n    elif dataset_name == 'SIRST3':\r\n        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)\r\n    elif dataset_name == 'NUDT-SIRST-Sea':\r\n        img_norm_cfg = dict(mean=43.62403869628906, std=18.91838264465332)\r\n    elif dataset_name == 'SIRST4':\r\n        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)\r\n    elif dataset_name == 'SIRST5':\r\n        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)\r\n    elif dataset_name == 'SIRST6':\r\n        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)\r\n    elif dataset_name == 'SIRST7':\r\n        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)\r\n    elif dataset_name == 'IRDST-real':\r\n        img_norm_cfg = {'mean': 101.54053497314453, 'std': 56.49856185913086}\r\n    else:\r\n        with open(dataset_dir + '/' + dataset_name + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:\r\n            train_list = f.read().splitlines()\r\n        with open(dataset_dir + '/' + dataset_name + '/img_idx/test_' + dataset_name + '.txt', 'r') as f:\r\n            test_list = f.read().splitlines()\r\n        img_list = train_list + test_list\r\n        img_dir = dataset_dir + '/' + dataset_name + '/images/'\r\n        mean_list = []\r\n        std_list = []\r\n        for img_pth in img_list:\r\n            try:\r\n                img = Image.open((img_dir + img_pth).replace('//', '/') + '.png').convert('I')\r\n            except:\r\n                try:\r\n                    img = Image.open((img_dir + img_pth).replace('//', '/') + 
'.jpg').convert('I')\r\n                except:\r\n                    img = Image.open((img_dir + img_pth).replace('//', '/') + '.bmp').convert('I')\r\n            img = np.array(img, dtype=np.float32)\r\n            mean_list.append(img.mean())\r\n            std_list.append(img.std())\r\n        img_norm_cfg = dict(mean=float(np.array(mean_list).mean()), std=float(np.array(std_list).mean()))\r\n    return img_norm_cfg\r\n\r\n\r\ndef get_optimizer(net, optimizer_name, scheduler_name, optimizer_settings, scheduler_settings):\r\n    if optimizer_name == 'Adam':\r\n        optimizer = torch.optim.Adam(net.parameters(), lr=optimizer_settings['lr'])\r\n    if optimizer_name == 'Adamweight':\r\n        optimizer = torch.optim.Adam(net.parameters(), lr=optimizer_settings['lr'], weight_decay=1e-3)\r\n\r\n    elif optimizer_name == 'Adagrad':\r\n        optimizer = torch.optim.Adagrad(net.parameters(), lr=optimizer_settings['lr'])\r\n    elif optimizer_name == 'SGD':\r\n        optimizer = torch.optim.SGD(net.parameters(), lr=optimizer_settings['lr'],\r\n                                    momentum=0.9,\r\n                                    weight_decay=scheduler_settings['weight_decay'])\r\n    # elif optimizer_name == 'AdamW':\r\n    #     optimizer = torch.optim.AdamW(net.parameters(), lr=optimizer_settings['lr'], betas=optimizer_settings['betas'],\r\n    #                                   eps=optimizer_settings['eps'], weight_decay=optimizer_settings['weight_decay'],\r\n    #                                   amsgrad=optimizer_settings['amsgrad'])\r\n\r\n    if scheduler_name == 'MultiStepLR':\r\n        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=scheduler_settings['step'],\r\n                                                         gamma=scheduler_settings['gamma'])\r\n    # elif scheduler_name == 'DNACosineAnnealingLR':\r\n    #     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'],\r\n 
   #                                                            eta_min=scheduler_settings['eta_min'])\r\n    elif scheduler_name == 'CosineAnnealingLR':\r\n        warmup_epochs = 10\r\n        scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'] - warmup_epochs,\r\n                                                                      eta_min=scheduler_settings['eta_min'])\r\n        scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,\r\n                                           after_scheduler=scheduler_cosine)\r\n    elif scheduler_name == 'CosineAnnealingLRw50':\r\n        warmup_epochs = 50\r\n        scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'] - warmup_epochs,\r\n                                                                      eta_min=scheduler_settings['eta_min'])\r\n        scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,\r\n                                           after_scheduler=scheduler_cosine)\r\n\r\n    elif scheduler_name == 'CosineAnnealingLRw0':\r\n        # warmup_epochs = 0\r\n        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'], eta_min=scheduler_settings['eta_min'])\r\n        # scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'] - warmup_epochs,\r\n        #                                                               eta_min=1e-5)\r\n        # scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,\r\n        #                                    after_scheduler=scheduler_cosine)\r\n\r\n        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['T_max'],\r\n        #                                                        eta_min=scheduler_settings['eta_min'],\r\n   
     #                                                        last_epoch=scheduler_settings['last_epoch'])\r\n        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'], eta_min=scheduler_settings['eta_min'])\r\n\r\n    return optimizer, scheduler\r\n\r\n\r\ndef PadImg(img, times=32):\r\n    # Zero-pad a 2-D (H, W) image at the bottom and right so that H and W both become multiples of `times`.\r\n    h, w = img.shape\r\n    if not h % times == 0:\r\n        img = np.pad(img, ((0, (h // times + 1) * times - h), (0, 0)), mode='constant')\r\n    if not w % times == 0:\r\n        img = np.pad(img, ((0, 0), (0, (w // times + 1) * times - w)), mode='constant')\r\n    return img\r\n\r\n\r\n\r\n\r\n"
  },
  {
    "path": "warmup_scheduler.py",
    "content": "from torch.optim.lr_scheduler import _LRScheduler\r\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\r\n\r\nclass GradualWarmupScheduler(_LRScheduler):\r\n    \"\"\" Gradually warm up (increase) the learning rate in the optimizer.\r\n    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.\r\n    The optimizer holds a base learning rate (base lr).\r\n    When multiplier > 1, warm-up raises the learning rate from base lr to multiplier * base lr over total_epoch epochs, then hands over to the regular scheduler.\r\n    When multiplier == 1.0, warm-up raises the learning rate from 0 to base lr over total_epoch epochs, then hands over to the regular scheduler.\r\n    Args:\r\n        optimizer (Optimizer): Wrapped optimizer.\r\n        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. If multiplier == 1.0, the lr starts from 0 and ends at base_lr.\r\n        total_epoch: the target learning rate is reached gradually at total_epoch.\r\n        after_scheduler: scheduler to use after total_epoch (e.g. ReduceLROnPlateau).\r\n    \"\"\"\r\n\r\n    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):\r\n        self.multiplier = multiplier\r\n        if self.multiplier < 1.:\r\n            raise ValueError('multiplier should be greater than or equal to 1.')\r\n        self.total_epoch = total_epoch\r\n        self.after_scheduler = after_scheduler\r\n        self.finished = False\r\n        super(GradualWarmupScheduler, self).__init__(optimizer)\r\n\r\n    def get_lr(self):\r\n        if self.last_epoch > self.total_epoch:\r\n            if self.after_scheduler:\r\n                if not self.finished:\r\n                    # Pass the warmed-up learning rate to the follow-up scheduler as its new base lr.\r\n                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]\r\n                    self.finished = True\r\n                # Key step: the new base lr must be returned directly.\r\n                return [base_lr for base_lr in self.after_scheduler.base_lrs]\r\n            # No follow-up scheduler: hold the target learning rate after warm-up.\r\n            return [base_lr * self.multiplier for base_lr in self.base_lrs]\r\n        if self.multiplier == 1.0:\r\n            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]\r\n        else:\r\n            return 
[base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]\r\n\r\n    def step_ReduceLROnPlateau(self, metrics, epoch=None):\r\n        if epoch is None:\r\n            epoch = self.last_epoch + 1\r\n        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning\r\n        if self.last_epoch <= self.total_epoch:\r\n            print('warming up...')\r\n            if self.multiplier == 1.0:\r\n                warmup_lr = [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]\r\n            else:\r\n                warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]\r\n            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):\r\n                param_group['lr'] = lr\r\n        else:\r\n            # epoch was filled in above and is never None here, so pass the post-warm-up epoch index.\r\n            self.after_scheduler.step(metrics, epoch - self.total_epoch)\r\n\r\n    def step(self, epoch=None, metrics=None):\r\n        if not isinstance(self.after_scheduler, ReduceLROnPlateau):\r\n            if self.finished and self.after_scheduler:\r\n                if epoch is None:\r\n                    self.after_scheduler.step(None)\r\n                else:\r\n                    self.after_scheduler.step(epoch - self.total_epoch)\r\n                self._last_lr = self.after_scheduler.get_last_lr()\r\n            else:\r\n                return super(GradualWarmupScheduler, self).step(epoch)\r\n        else:\r\n            self.step_ReduceLROnPlateau(metrics, epoch)\r\n"
  }
]