Repository: xdFai/SCTransNet
Branch: main
Commit: e0276283794e
Files: 11
Total size: 124.0 KB

Directory structure:
gitextract_03y9khf8/
├── README.md
├── dataset.py
├── datasets/
│   └── SIRST3/
│       └── img_idx/
│           ├── test_SIRST3.txt
│           └── train_SIRST3.txt
├── metrics.py
├── model/
│   ├── Config.py
│   └── SCTransNet.py
├── test.py
├── train.py
├── utils.py
└── warmup_scheduler.py


================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# SCTransNet: Spatial-channel Cross Transformer Network for Infrared Small Target Detection

[[Paper]](https://ieeexplore.ieee.org/document/10486932) [[Weight]](https://drive.google.com/file/d/1Kxs2wKG2uq2YiGJOBGWoVz7B1-8DJoz3/view?usp=sharing)

Shuai Yuan, Hanlin Qin, Xiang Yan, Naveed Akhtar, Ajmal Mian, IEEE Transactions on Geoscience and Remote Sensing 2024.

# SCTransNet is the baseline of the winning solutions for three competitions (PRCV 2024, ICPR 2024 Track 1, and ICPR 2024 Track 2), and it also serves as the baseline for several other award-winning methods. [[Paper]](https://arxiv.org/abs/2408.09615)

# Bilibili video presentation
https://www.bilibili.com/video/BV1kr421M7wx/

# WeChat article on the Jishi platform (极市平台)
https://mp.weixin.qq.com/s/H7KLmtFX7j09f-Xc6X1FRw

# If this implementation is helpful to you, please star the repo! ⭐⭐⭐

# Challenges and inspiration
![Image text](https://github.com/xdFai/SCTransNet/blob/main/Fig/picture01.png)

# Structure
![Image text](https://github.com/xdFai/SCTransNet/blob/main/Fig/picture2.png)

![Image text](https://github.com/xdFai/SCTransNet/blob/main/Fig/picture03.png)

# Introduction
We present a Spatial-channel Cross Transformer Network (SCTransNet) for the infrared small target detection (IRSTD) task. Experiments on public datasets (SIRST, NUDT-SIRST, and IRSTD-1K) demonstrate the effectiveness of our method. Our main contributions are as follows:

1. We propose SCTransNet, which leverages spatial-channel cross transformer blocks (SCTB) to predict the context of targets and backgrounds in the deeper network layers.
2. A spatial-embedded single-head channel-cross attention (SSCA) module fosters semantic interactions across all feature levels and learns long-range context.
3. We devise a novel complementary feed-forward network (CFN) that crosses spatial and channel information to enhance the semantic difference between targets and the background.

## Usage

#### 1. Data

The **SIRST3** dataset, which combines **IRSTD-1K**, **NUDT-SIRST**, and **SIRST-v1**, is used to train SCTransNet.

* **SIRST-v1** [[download]](https://github.com/YimianDai/sirst) [[paper]](https://arxiv.org/pdf/2009.14530.pdf)
* **NUDT-SIRST** [[download]](https://github.com/YeRen123455/Infrared-Small-Target-Detection) [[paper]](https://ieeexplore.ieee.org/abstract/document/9864119)
* **IRSTD-1K** [[download dir]](https://github.com/RuiZhang97/ISNet) [[paper]](https://ieeexplore.ieee.org/document/9880295)
* Apologies for misnaming the **SIRST-v1** dataset as **NUAA-SIRST** in both the article and the code. We will follow the original authors' naming convention in future work.
* **Our project has the following structure:**

```
├── ./datasets/
│   ├── IRSTD-1K
│   │   ├── images
│   │   │   ├── XDU0.png
│   │   │   ├── XDU1.png
│   │   │   ├── ...
│   │   ├── masks
│   │   │   ├── XDU0.png
│   │   │   ├── XDU1.png
│   │   │   ├── ...
│   │   ├── img_idx
│   │   │   ├── train_IRSTD-1K.txt
│   │   │   ├── test_IRSTD-1K.txt
│   ├── NUDT-SIRST
│   │   ├── images
│   │   │   ├── 000001.png
│   │   │   ├── 000002.png
│   │   │   ├── ...
│   │   ├── masks
│   │   │   ├── 000001.png
│   │   │   ├── 000002.png
│   │   │   ├── ...
│   │   ├── img_idx
│   │   │   ├── train_NUDT-SIRST.txt
│   │   │   ├── test_NUDT-SIRST.txt
│   ├── SIRSTv1 (misnamed as NUAA-SIRST)
│   │   ├── images
│   │   │   ├── Misc_1.png
│   │   │   ├── Misc_2.png
│   │   │   ├── ...
│   │   ├── masks
│   │   │   ├── Misc_1.png
│   │   │   ├── Misc_2.png
│   │   │   ├── ...
│   │   ├── img_idx
│   │   │   ├── train_NUAA-SIRST.txt
│   │   │   ├── test_NUAA-SIRST.txt
│   ├── SIRST3 (union of SIRSTv1, NUDT-SIRST, and IRSTD-1K)
│   │   ├── images
│   │   │   ├── XDU0.png
│   │   │   ├── XDU1.png
│   │   │   ├── ...
│   │   ├── masks
│   │   │   ├── XDU0.png
│   │   │   ├── XDU1.png
│   │   │   ├── ...
│   │   ├── img_idx
│   │   │   ├── train_SIRST3.txt
│   │   │   ├── test_SIRST3.txt
```

#### 2. Train

```bash
python train.py
```

#### 3. Test and demo

Baidu Netdisk link for the weights: [https://pan.baidu.com/s/1_hlEaqnJI246GWN4N8k8wA?pwd=t28j](https://pan.baidu.com/s/1B0mANHXSfJaQjHr00XIwgQ?pwd=s7nh)

Google Drive link for the weights: https://drive.google.com/file/d/1Kxs2wKG2uq2YiGJOBGWoVz7B1-8DJoz3/view?usp=sharing

```bash
python test.py
```

## Results and Trained Models

#### Qualitative Results

![Image text](https://github.com/xdFai/SCTransNet/blob/main/Fig/picture06.png)

#### Quantitative Results on the mixed SIRSTv1, NUDT-SIRST, and IRSTD-1K datasets (i.e., one set of weights for all three datasets)

| Dataset | mIoU (×10⁻²) | nIoU (×10⁻²) | F-measure (×10⁻²) | Pd (×10⁻²) | Fa (×10⁻⁶) |
| ------------- |:-------------:|:-----:|:-----:|:-----:|:-----:|
| SIRSTv1 | 77.50 | 81.08 | 87.32 | 96.95 | 13.92 |
| NUDT-SIRST | 94.09 | 94.38 | 96.95 | 98.62 | 4.29 |
| IRSTD-1K | 68.03 | 68.15 | 80.96 | 93.27 | 10.74 |

[[Weights]](https://drive.google.com/file/d/1Kxs2wKG2uq2YiGJOBGWoVz7B1-8DJoz3/view?usp=sharing)

* This code borrows heavily from [IRSTD-Toolbox](https://github.com/XinyiYing/BasicIRSTD). Thanks to Xinyi Ying.
* This code borrows heavily from [UCTransNet](https://github.com/McGregorWwww/UCTransNet). Thanks to Haonan Wang.
* The overall repository style borrows heavily from [DNA-Net](https://github.com/YeRen123455/Infrared-Small-Target-Detection). Thanks to Boyang Li.

## Citation

If you find the code useful, please consider citing our papers using the following BibTeX entries.
```
@ARTICLE{SCTransNet,
  author={Yuan, Shuai and Qin, Hanlin and Yan, Xiang and Akhtar, Naveed and Mian, Ajmal},
  journal={IEEE Transactions on Geoscience and Remote Sensing},
  title={SCTransNet: Spatial-Channel Cross Transformer Network for Infrared Small Target Detection},
  year={2024},
  volume={62},
  number={},
  pages={1-15},
  keywords={Semantics;Transformers;Decoding;Feature extraction;Task analysis;Object detection;Visualization;Convolutional neural network (CNN);cross-attention;deep learning;infrared small target detection (IRSTD);transformer},
  doi={10.1109/TGRS.2024.3383649}}

@article{SP-KAN,
  title = {SP-KAN: Sparse-sine perception Kolmogorov–Arnold networks for infrared small target detection},
  journal = {ISPRS Journal of Photogrammetry and Remote Sensing},
  volume = {234},
  pages = {1-19},
  year = {2026},
  issn = {0924-2716},
  doi = {10.1016/j.isprsjprs.2026.02.019},
  url = {https://www.sciencedirect.com/science/article/pii/S0924271626000705},
  author = {Shuai Yuan and Yu Liu and Xiaopei Zhang and Xiang Yan and Hanlin Qin and Naveed Akhtar},
}
```

## Contact

**Feel free to open an issue or email [yuansy@stu.xidian.edu.cn](mailto:yuansy@stu.xidian.edu.cn) with any questions about SCTransNet.**

================================================
FILE: dataset.py
================================================
import os
import random

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

from utils import *  # get_img_norm_cfg, Normalized, random_crop and PadImg come from utils

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


class TrainSetLoader(Dataset):
    """Training set with geometric augmentation only (random crop, flips, transpose)."""

    def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg=None):
        super().__init__()
        self.dataset_name = dataset_name
        self.dataset_dir = dataset_dir + '/' + dataset_name
        self.patch_size = patch_size
        with open(self.dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:
            self.train_list = f.read().splitlines()
        if img_norm_cfg is None:
            self.img_norm_cfg = get_img_norm_cfg(dataset_name, dataset_dir)
        else:
            self.img_norm_cfg = img_norm_cfg
        self.tranform = augumentation()

    def __getitem__(self, idx):
        try:
            # Read the image in PIL mode 'I' (32-bit signed integer pixels)
            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//', '/')).convert('I')
            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.png').replace('//', '/'))
        except OSError:
            # Fall back to .bmp files when the .png files cannot be opened
            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.bmp').replace('//', '/')).convert('I')
            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.bmp').replace('//', '/'))

        img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg)  # PIL -> numpy, then normalize
        mask = np.array(mask, dtype=np.float32) / 255.0
        if len(mask.shape) > 2:
            mask = mask[:, :, 0]  # keep only the first channel of multi-channel masks

        # Pad the shorter side to patch_size and randomly crop the longer side (e.g., 256 x 256 patches)
        img_patch, mask_patch = random_crop(img, mask, self.patch_size, pos_prob=0.5)
        img_patch, mask_patch = self.tranform(img_patch, mask_patch)  # random flip/transpose augmentation
        img_patch, mask_patch = img_patch[np.newaxis, :], mask_patch[np.newaxis, :]  # add channel dim: (1, H, W)
        img_patch = torch.from_numpy(np.ascontiguousarray(img_patch))  # numpy -> tensor
        mask_patch = torch.from_numpy(np.ascontiguousarray(mask_patch))  # numpy -> tensor
        return img_patch, mask_patch

    def __len__(self):
        return len(self.train_list)


class TrainSetLoader02(Dataset):
    """Same as TrainSetLoader, plus a random global brightness offset."""

    def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg=None):
        super().__init__()
        self.dataset_name = dataset_name
        self.dataset_dir = dataset_dir + '/' + dataset_name
        self.patch_size = patch_size
        with open(self.dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:
            self.train_list = f.read().splitlines()
        if img_norm_cfg is None:
            self.img_norm_cfg = get_img_norm_cfg(dataset_name, dataset_dir)
        else:
            self.img_norm_cfg = img_norm_cfg
        self.tranform = augumentation()

    def __getitem__(self, idx):
        try:
            # Read the image in PIL mode 'I' (32-bit signed integer pixels)
            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//', '/')).convert('I')
            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.png').replace('//', '/'))
        except OSError:
            # Fall back to .bmp files when the .png files cannot be opened
            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.bmp').replace('//', '/')).convert('I')
            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.bmp').replace('//', '/'))

        img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg)  # PIL -> numpy, then normalize
        mask = np.array(mask, dtype=np.float32) / 255.0
        if len(mask.shape) > 2:
            mask = mask[:, :, 0]  # keep only the first channel of multi-channel masks

        # Brightness offset: one scalar drawn from N(0, 0.03) is added to every pixel
        rnd_bn = np.random.normal(0, 0.03)
        img += rnd_bn

        # Pad the shorter side to patch_size and randomly crop the longer side (e.g., 256 x 256 patches)
        img_patch, mask_patch = random_crop(img, mask, self.patch_size, pos_prob=0.5)
        img_patch, mask_patch = self.tranform(img_patch, mask_patch)  # random flip/transpose augmentation
        img_patch, mask_patch = img_patch[np.newaxis, :], mask_patch[np.newaxis, :]  # add channel dim: (1, H, W)
        img_patch = torch.from_numpy(np.ascontiguousarray(img_patch))  # numpy -> tensor
        mask_patch = torch.from_numpy(np.ascontiguousarray(mask_patch))  # numpy -> tensor
        return img_patch, mask_patch

    def __len__(self):
        return len(self.train_list)


class TrainSetLoader03(Dataset):
    """Same as TrainSetLoader, plus random gamma correction."""

    def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg=None):
        super().__init__()
        self.dataset_name = dataset_name
        self.dataset_dir = dataset_dir + '/' + dataset_name
        self.patch_size = patch_size
        with open(self.dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:
            self.train_list = f.read().splitlines()
        if img_norm_cfg is None:
            self.img_norm_cfg = get_img_norm_cfg(dataset_name, dataset_dir)
        else:
            self.img_norm_cfg = img_norm_cfg
        self.tranform = augumentation()

    def __getitem__(self, idx):
        try:
            # Read the image in PIL mode 'I' (32-bit signed integer pixels)
            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//', '/')).convert('I')
            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.png').replace('//', '/'))
        except OSError:
            # Fall back to .bmp files when the .png files cannot be opened
            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.bmp').replace('//', '/')).convert('I')
            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.bmp').replace('//', '/'))

        img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg)  # PIL -> numpy, then normalize
        mask = np.array(mask, dtype=np.float32) / 255.0
        if len(mask.shape) > 2:
            mask = mask[:, :, 0]  # keep only the first channel of multi-channel masks

        # Gamma correction: rescale to [0, 1], apply x ** gamma with gamma ~ U(0.5, 1.6), rescale back
        minm = img.min()
        rng = img.max() - minm
        if rng > 0:  # skip constant images to avoid division by zero
            gamma = np.random.uniform(0.5, 1.6)
            x = (img - minm) / rng
            img = np.power(x, gamma)
            img = img * rng + minm

        # Pad the shorter side to patch_size and randomly crop the longer side (e.g., 256 x 256 patches)
        img_patch, mask_patch = random_crop(img, mask, self.patch_size, pos_prob=0.5)
        img_patch, mask_patch = self.tranform(img_patch, mask_patch)  # random flip/transpose augmentation
        img_patch, mask_patch = img_patch[np.newaxis, :], mask_patch[np.newaxis, :]  # add channel dim: (1, H, W)
        img_patch = torch.from_numpy(np.ascontiguousarray(img_patch))  # numpy -> tensor
        mask_patch = torch.from_numpy(np.ascontiguousarray(mask_patch))  # numpy -> tensor
        return img_patch, mask_patch

    def __len__(self):
        return len(self.train_list)


class TrainSetLoader04(Dataset):
    """Same as TrainSetLoader, plus a random brightness offset followed by random gamma correction."""

    def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg=None):
        super().__init__()
        self.dataset_name = dataset_name
        self.dataset_dir = dataset_dir + '/' + dataset_name
        self.patch_size = patch_size
        with open(self.dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:
            self.train_list = f.read().splitlines()
        if img_norm_cfg is None:
            self.img_norm_cfg = get_img_norm_cfg(dataset_name, dataset_dir)
        else:
            self.img_norm_cfg = img_norm_cfg
        self.tranform = augumentation()

    def __getitem__(self, idx):
        try:
            # Read the image in PIL mode 'I' (32-bit signed integer pixels)
            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//', '/')).convert('I')
            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.png').replace('//', '/'))
        except OSError:
            # Fall back to .bmp files when the .png files cannot be opened
            img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.bmp').replace('//', '/')).convert('I')
            mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.bmp').replace('//', '/'))

        img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg)  # PIL -> numpy, then normalize
        mask = np.array(mask, dtype=np.float32) / 255.0
        if len(mask.shape) > 2:
            mask = mask[:, :, 0]  # keep only the first channel of multi-channel masks

        # Brightness offset: one scalar drawn from N(0, 0.03) is added to every pixel
        rnd_bn = np.random.normal(0, 0.03)
        img += rnd_bn

        # Gamma correction: rescale to [0, 1], apply x ** gamma with gamma ~ U(0.5, 1.6), rescale back
        minm = img.min()
        rng = img.max() - minm
        if rng > 0:  # skip constant images to avoid division by zero
            gamma = np.random.uniform(0.5, 1.6)
            x = (img - minm) / rng
            img = np.power(x, gamma)
            img = img * rng + minm

        # Pad the shorter side to patch_size and randomly crop the longer side (e.g., 256 x 256 patches)
        img_patch, mask_patch = random_crop(img, mask, self.patch_size, pos_prob=0.5)
        img_patch, mask_patch = self.tranform(img_patch, mask_patch)  # random flip/transpose augmentation
        img_patch, mask_patch = img_patch[np.newaxis, :], mask_patch[np.newaxis, :]  # add channel dim: (1, H, W)
        img_patch = torch.from_numpy(np.ascontiguousarray(img_patch))  # numpy -> tensor
        mask_patch = torch.from_numpy(np.ascontiguousarray(mask_patch))  # numpy -> tensor
        return img_patch, mask_patch

    def __len__(self):
        return len(self.train_list)


class TestSetLoader(Dataset):
    def __init__(self, dataset_dir, train_dataset_name, test_dataset_name, img_norm_cfg=None):
        super().__init__()
        self.dataset_dir = dataset_dir + '/' + test_dataset_name
        with open(self.dataset_dir + '/img_idx/test_' + test_dataset_name + '.txt', 'r') as f:
            self.test_list = f.read().splitlines()
        if img_norm_cfg is None:
            # Test images use the normalization config of the *training* dataset
            self.img_norm_cfg = get_img_norm_cfg(train_dataset_name, dataset_dir)
        else:
            self.img_norm_cfg = img_norm_cfg

    def __getitem__(self, idx):
        try:
            img = Image.open((self.dataset_dir + '/images/' + self.test_list[idx] + '.png').replace('//', '/')).convert('I')
            mask = Image.open((self.dataset_dir + '/masks/' + self.test_list[idx] + '.png').replace('//', '/'))
        except OSError:
            # Fall back to .bmp files when the .png files cannot be opened
            img = Image.open((self.dataset_dir + '/images/' + self.test_list[idx] + '.bmp').replace('//', '/')).convert('I')
            mask = Image.open((self.dataset_dir + '/masks/' + self.test_list[idx] + '.bmp').replace('//', '/'))

        img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg)
        mask = np.array(mask, dtype=np.float32) / 255.0
        if len(mask.shape) > 2:
            mask = mask[:, :, 0]  # keep only the first channel of multi-channel masks

        h, w = img.shape  # original (unpadded) size, returned alongside the padded tensors
        img = PadImg(img)
        mask = PadImg(mask)

        img, mask = img[np.newaxis, :], mask[np.newaxis, :]  # add channel dim: (1, H, W)
        img = torch.from_numpy(np.ascontiguousarray(img))
        mask = torch.from_numpy(np.ascontiguousarray(mask))
        if img.size() != mask.size():
            print('Image/mask size mismatch:', self.test_list[idx])
        return img, mask, [h, w], self.test_list[idx]

    def __len__(self):
        return len(self.test_list)


class EvalSetLoader(Dataset):
    def __init__(self, dataset_dir, mask_pred_dir, test_dataset_name, model_name):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.mask_pred_dir = mask_pred_dir
        self.test_dataset_name = test_dataset_name
        self.model_name = model_name
        with open(self.dataset_dir + '/img_idx/test_' + test_dataset_name + '.txt', 'r') as f:
            self.test_list = f.read().splitlines()

    def __getitem__(self, idx):
        # Predictions are read from <mask_pred_dir><test_dataset_name>/<model_name>/<name>.png
        mask_pred = Image.open((self.mask_pred_dir + self.test_dataset_name + '/' + self.model_name + '/' + self.test_list[idx] + '.png').replace('//', '/'))
        mask_gt = Image.open(self.dataset_dir + '/masks/' + self.test_list[idx] + '.png')

        mask_pred = np.array(mask_pred, dtype=np.float32) / 255.0
        mask_gt = np.array(mask_gt, dtype=np.float32) / 255.0
        if len(mask_pred.shape) == 3:
            mask_pred = mask_pred[:, :, 0]
        if len(mask_gt.shape) == 3:
            mask_gt = mask_gt[:, :, 0]

        h, w = mask_pred.shape
        mask_pred, mask_gt = mask_pred[np.newaxis, :], mask_gt[np.newaxis, :]  # add channel dim: (1, H, W)
        mask_pred = torch.from_numpy(np.ascontiguousarray(mask_pred))
        mask_gt = torch.from_numpy(np.ascontiguousarray(mask_gt))
        return mask_pred, mask_gt, [h, w]

    def __len__(self):
        return len(self.test_list)


class augumentation(object):
    """Random flips and transpose, applied identically to an image and its mask (2-D H x W arrays)."""

    def __call__(self, input, target):
        if random.random() < 0.5:  # flip along axis 0 (up-down)
            input = input[::-1, :]
            target = target[::-1, :]
        if random.random() < 0.5:  # flip along axis 1 (left-right)
            input = input[:, ::-1]
            target = target[:, ::-1]
        if random.random() < 0.5:  # transpose (swap rows and columns)
            input = input.transpose(1, 0)
            target = target.transpose(1, 0)
        return input, target

================================================
FILE: datasets/SIRST3/img_idx/test_SIRST3.txt
================================================
Misc_338 Misc_379 Misc_422 Misc_73 Misc_321 Misc_162 Misc_372 Misc_185 Misc_420 Misc_143 Misc_137 Misc_224 Misc_274 Misc_156 Misc_121 Misc_240 Misc_82 Misc_230 Misc_51 Misc_203 Misc_210 Misc_219 Misc_178 Misc_416 Misc_366 Misc_306 Misc_94 Misc_5 Misc_413 Misc_279 Misc_110 Misc_298 Misc_257 Misc_199 Misc_368 Misc_34 Misc_248 Misc_303 Misc_99 Misc_225 Misc_317 Misc_308 Misc_66 Misc_266 Misc_123 Misc_288 Misc_348 Misc_59 Misc_299 Misc_28 Misc_227 Misc_54 Misc_287 Misc_252 Misc_21 Misc_130 Misc_228 Misc_177 Misc_24 Misc_395 Misc_332 Misc_207 Misc_30 Misc_173 Misc_49 Misc_48 Misc_75 Misc_427 Misc_411 Misc_205 Misc_57 Misc_273 Misc_272 Misc_125 Misc_262 Misc_14 Misc_56 Misc_234 Misc_412 Misc_275 Misc_350 Misc_418 Misc_113 Misc_148 Misc_357 Misc_239 Misc_385 Misc_154 Misc_83 Misc_222 Misc_111 Misc_153 Misc_309 Misc_374 Misc_312 Misc_117 Misc_197 Misc_292 Misc_145 Misc_376 Misc_122 Misc_101 Misc_394 Misc_136 Misc_289 Misc_131 Misc_25 Misc_74 Misc_347 Misc_15 Misc_1 Misc_361 Misc_373 Misc_151 Misc_97 Misc_3 Misc_381 Misc_400 Misc_277 Misc_326 Misc_38 Misc_241 Misc_334 Misc_397 Misc_164
Misc_300 Misc_297 Misc_414 Misc_249 Misc_91 Misc_107 Misc_68 Misc_259 Misc_342 Misc_189 Misc_76 Misc_235 Misc_335 Misc_22 Misc_319 Misc_314 Misc_79 Misc_346 Misc_251 Misc_387 Misc_410 Misc_345 Misc_236 Misc_41 Misc_35 Misc_9 Misc_396 Misc_375 Misc_176 Misc_188 Misc_44 Misc_356 Misc_80 Misc_60 Misc_386 Misc_126 Misc_206 Misc_307 Misc_208 Misc_276 Misc_286 Misc_193 Misc_270 Misc_377 Misc_37 Misc_390 Misc_106 Misc_255 Misc_328 Misc_268 Misc_301 Misc_181 Misc_371 Misc_204 Misc_323 Misc_87 Misc_53 Misc_293 Misc_20 Misc_85 Misc_327 Misc_216 Misc_365 Misc_408 Misc_62 Misc_271 Misc_265 Misc_325 Misc_340 Misc_155 Misc_10 Misc_13 Misc_124 Misc_140 Misc_139 Misc_46 Misc_19 Misc_135 Misc_221 Misc_212 Misc_359 Misc_237 Misc_331 Misc_69 Misc_186 Misc_294 Misc_196 Misc_84 Misc_380 000541 000832 000710 000316 001142 000409 000226 001133 001255 000558 001000 000763 000385 001319 000591 000484 001236 000334 000915 000523 000765 000215 000045 001303 000644 001190 001026 001089 000720 000195 000519 001111 000922 000965 000080 000751 000281 000849 000495 001257 000966 001267 000091 000935 000314 000128 001169 000701 000903 000144 000613 000313 000552 000787 000536 000423 000389 001235 000615 000366 000399 001155 000625 000610 001050 001148 000440 000032 000812 000830 000771 000324 001161 000209 000645 000897 001036 000200 000708 000474 000840 000554 000011 000603 000560 000550 000498 000243 000433 000106 000036 000574 000837 000608 000341 000863 000907 000600 000361 000555 000636 000509 000799 000062 000297 000933 000374 000928 000910 000131 000191 000216 000373 000556 000417 000485 000557 000153 000955 000511 001038 001239 000986 000740 001018 001068 000902 001221 000510 000212 000026 001320 001074 000189 000335 000295 001265 001206 000637 000349 001122 001178 000944 001135 000113 000041 000575 000757 000882 000917 000909 000569 001167 001212 000801 000466 000449 000913 000607 001218 001096 000850 001151 000426 000908 000286 001280 000167 000332 000499 001066 000611 001188 001238 
001195 000980 000194 000315 000391 000704 000930 000150 000687 000394 000988 000535 001278 000867 000371 000355 000722 000178 001215 000747 001162 000001 000405 001097 000506 001102 000326 000737 000302 000699 000715 000874 000713 001244 000589 001144 001213 000689 000266 000563 000395 000415 000749 000330 000300 000071 001264 000237 000973 001076 000666 001079 000532 000820 000207 000182 001176 001272 000047 000268 000906 000351 000633 000025 000957 000176 000220 000188 001006 000926 001129 000599 000684 000155 000703 001143 000620 000075 000839 001300 001186 000197 000811 000601 001040 000735 001029 000623 000436 001242 000813 000493 001053 000410 000936 000473 000478 000458 000709 001307 000781 001023 001004 000359 000287 000482 000941 000067 000284 000447 000894 000163 000059 001222 001268 000744 000921 000173 000378 000806 000629 000870 000860 000660 000664 000533 000617 001145 000961 000940 000365 001003 000305 001317 001314 000352 000898 000883 001309 000732 000336 000622 001232 000166 000392 000231 001281 001199 000893 000879 000084 001174 000377 001063 000628 000537 000772 000100 000043 000291 000918 000548 000729 000475 000731 000974 001047 000595 000508 001101 001248 001107 001127 001253 000920 000021 000742 000140 000320 001209 000434 000120 000711 000137 000249 000995 001084 001273 000845 000452 001220 000553 000003 000984 000086 000967 000788 000416 000138 000112 001114 001308 000053 000762 000333 001081 001287 000821 001302 000480 000775 000525 001262 000124 001069 000782 000490 000010 000465 001286 001060 000685 000717 000939 000027 000971 000727 000170 000187 000671 000976 000562 000651 000831 000397 000157 000682 001227 001010 000294 000159 001088 001241 001229 000289 000169 000135 000090 000456 000691 000382 000483 000528 000889 000060 001326 000520 000269 000706 001020 000502 000492 001153 000521 000496 000538 000009 001294 000542 000127 000192 001237 000406 001055 001251 001315 000151 001180 001250 000901 000816 000273 000545 000912 000583 
000285 000678 000992 000386 000022 000573 001075 000927 001304 000219 001291 000786 000348 000861 000785 000875 001191 000853 001008 000817 000721 000807 000122 001083 000261 000588 000587 000880 001311 000835 001021 000714 000081 000420 001103 000142 000270 001160 001224 000792 000019 000470 000529 001141 000646 000656 000815 000885 000924 001254 000111 000598 001120 001087 000020 000214 000724 000582 001234 000983 000759 001149 001044 000665 000107 001322 001210 000448 000260 000579 000996 001246 000698 000783 000240 000862 000616 000051 000400 000455 000809 001013 000110 000634 000838 001184 001077 000694 000412 000836 000344 001240 000259 000439 000766 001202 000093 000507 000303 000262 000092 000890 000789 000784 000272 000186 000500 000734 000887 000476 000960 000202 001034 001214 000516 000808 000346 000680 001289 001042 000802 001139 000640 000803 000329 001123 000761 001028 000650 000133 000592 000148 000866 000230 001012 000055 000252 000421 000425 000596 000296 001325 000255 001092 000362 000450 001183 000046 000648 000370 001128 001175 000730 001154 000963 001098 000654 001112 000319 000398 001298 001057 001125 000012 000756 000141 000931 000581 000463 001259 000985 000076 000632 000162 000954 000779 000233 000614 001126 001288 001283 000790 000023 000380 000661 000923 000225 000156 000390 000597 000468 000606 000307 000630 001223 000038 000224 000673 001193 000016 000015 001110 001204 000683 000353 XDU189 XDU935 XDU672 XDU231 XDU818 XDU888 XDU146 XDU48 XDU492 XDU241 XDU195 XDU801 XDU104 XDU637 XDU996 XDU482 XDU406 XDU889 XDU558 XDU117 XDU777 XDU134 XDU223 XDU943 XDU762 XDU662 XDU54 XDU685 XDU167 XDU489 XDU505 XDU527 XDU817 XDU253 XDU193 XDU597 XDU151 XDU404 XDU596 XDU97 XDU321 XDU279 XDU93 XDU205 XDU9 XDU219 XDU674 XDU501 XDU316 XDU343 XDU885 XDU426 XDU485 XDU850 XDU516 XDU216 XDU160 XDU176 XDU504 XDU883 XDU244 XDU919 XDU781 XDU369 XDU398 XDU441 XDU75 XDU240 XDU805 XDU108 XDU709 XDU352 XDU747 XDU209 XDU845 XDU557 XDU775 XDU56 XDU657 XDU753 XDU788 
XDU682 XDU794 XDU877 XDU421 XDU733 XDU546 XDU999 XDU5 XDU63 XDU966 XDU922 XDU789 XDU295 XDU863 XDU578 XDU743 XDU46 XDU115 XDU876 XDU932 XDU289 XDU855 XDU933 XDU517 XDU329 XDU3 XDU451 XDU694 XDU878 XDU259 XDU708 XDU442 XDU829 XDU833 XDU648 XDU381 XDU868 XDU803 XDU673 XDU415 XDU997 XDU667 XDU968 XDU169 XDU525 XDU164 XDU704 XDU711 XDU111 XDU354 XDU927 XDU758 XDU87 XDU697 XDU957 XDU49 XDU563 XDU954 XDU45 XDU429 XDU902 XDU302 XDU523 XDU41 XDU816 XDU785 XDU759 XDU872 XDU185 XDU881 XDU447 XDU129 XDU614 XDU920 XDU334 XDU257 XDU892 XDU103 XDU698 XDU862 XDU33 XDU416 XDU40 XDU715 XDU203 XDU589 XDU142 XDU50 XDU455 XDU620 XDU67 XDU371 XDU192 XDU28 XDU43 XDU661 XDU692 XDU463 XDU745 XDU258 XDU842 XDU459 XDU147 XDU319 XDU225 XDU178 XDU567 XDU925 XDU394 XDU110 XDU663 XDU376 XDU450 XDU10 XDU955 XDU374 XDU278 XDU393 XDU570 XDU217 ================================================ FILE: datasets/SIRST3/img_idx/train_SIRST3.txt ================================================ Misc_119 Misc_64 Misc_90 Misc_364 Misc_250 Misc_351 Misc_39 Misc_313 Misc_179 Misc_344 Misc_421 Misc_398 Misc_417 Misc_95 Misc_339 Misc_426 Misc_269 Misc_316 Misc_419 Misc_144 Misc_149 Misc_146 Misc_31 Misc_58 Misc_4 Misc_264 Misc_283 Misc_284 Misc_150 Misc_220 Misc_133 Misc_77 Misc_70 Misc_425 Misc_195 Misc_304 Misc_329 Misc_65 Misc_167 Misc_174 Misc_202 Misc_157 Misc_96 Misc_320 Misc_369 Misc_109 Misc_16 Misc_40 Misc_295 Misc_147 Misc_247 Misc_423 Misc_152 Misc_100 Misc_263 Misc_352 Misc_233 Misc_190 Misc_392 Misc_281 Misc_358 Misc_163 Misc_132 Misc_405 Misc_159 Misc_12 Misc_367 Misc_172 Misc_401 Misc_138 Misc_104 Misc_86 Misc_160 Misc_242 Misc_7 Misc_305 Misc_243 Misc_399 Misc_363 Misc_61 Misc_129 Misc_330 Misc_134 Misc_315 Misc_180 Misc_244 Misc_63 Misc_391 Misc_42 Misc_404 Misc_29 Misc_238 Misc_285 Misc_214 Misc_93 Misc_253 Misc_402 Misc_50 Misc_291 Misc_128 Misc_267 Misc_115 Misc_337 Misc_370 Misc_158 Misc_114 Misc_388 Misc_170 Misc_354 Misc_36 Misc_424 Misc_336 Misc_393 Misc_229 Misc_108 Misc_105 Misc_406 
Misc_2 Misc_324 Misc_47 Misc_200 Misc_187 Misc_33 Misc_72 Misc_384 Misc_120 Misc_322 Misc_360 Misc_192 Misc_112 Misc_142 Misc_403 Misc_169 Misc_223 Misc_213 Misc_161 Misc_256 Misc_141 Misc_78 Misc_296 Misc_6 Misc_258 Misc_231 Misc_52 Misc_183 Misc_362 Misc_102 Misc_88 Misc_343 Misc_341 Misc_118 Misc_165 Misc_280 Misc_17 Misc_290 Misc_67 Misc_382 Misc_191 Misc_166 Misc_8 Misc_45 Misc_415 Misc_349 Misc_98 Misc_127 Misc_184 Misc_310 Misc_198 Misc_254 Misc_211 Misc_103 Misc_232 Misc_218 Misc_89 Misc_201 Misc_11 Misc_168 Misc_215 Misc_383 Misc_333 Misc_245 Misc_55 Misc_27 Misc_226 Misc_116 Misc_378 Misc_355 Misc_302 Misc_209 Misc_32 Misc_23 Misc_261 Misc_182 Misc_282 Misc_409 Misc_260 Misc_194 Misc_407 Misc_175 Misc_278 Misc_26 Misc_246 Misc_217 Misc_311 Misc_43 Misc_353 Misc_81 Misc_18 Misc_318 Misc_389 Misc_171 Misc_71 Misc_92 001137 000345 000774 000593 001002 001024 001285 000649 000282 000037 001150 001124 001185 000718 000968 001216 000494 001015 000945 000446 000193 000746 000733 000621 001131 000085 000064 000312 000213 001297 001039 000066 000002 000843 000659 000298 000227 000951 000299 000937 000547 000948 000859 000058 000422 001140 000873 001249 000822 000745 000609 000881 001271 000851 000848 001016 001100 000004 001208 000099 000073 000049 000158 001279 000063 001181 000515 001054 000844 001177 000934 000663 000841 000205 000911 000211 000267 000379 000168 000301 001061 001299 000818 001070 001305 001095 000695 000375 000061 000669 000325 000864 000174 000347 000748 000946 001274 001059 000253 000679 001041 000999 000794 000576 001094 000635 000825 000013 000627 000487 001201 000457 000916 000030 000109 000693 000183 000467 001164 001258 000970 000263 000858 000668 000526 000736 001146 000247 000367 000050 000773 000318 001163 001159 001306 001310 000571 001031 000461 000652 000444 001194 001156 000014 000250 000134 001301 000768 000823 000040 000891 000561 001200 001233 000119 001324 000793 000293 000143 001011 000834 000095 000017 001116 000653 000755 
000705 000723 000738 001266 000364 000791 000048 000228 000956 000570 000116 000350 000276 001022 000369 001225 000129 000947 000800 000154 001168 000602 000210 001284 000534 000814 000846 001132 000257 000688 001056 000430 000900 000471 000833 000171 000567 001043 000798 001085 000895 001121 000979 000363 000311 000229 000234 001073 000716 001005 001119 001327 000754 000245 000780 000277 001052 000856 000871 000531 000472 000083 000274 000147 000096 000152 001158 000453 000057 000411 000388 000082 000655 000962 001086 000539 000381 000527 000309 000065 000418 000306 000469 000088 001007 000690 000686 000428 000605 001067 001025 001245 001045 000767 000728 001065 000540 000658 000549 000522 000978 000317 000117 000184 000631 000459 000804 000819 001048 000271 000358 000670 000145 000118 000239 000126 000115 001171 000497 000847 000462 000604 000223 000810 000938 000460 001001 001108 000826 001252 000981 000429 001277 001080 000707 000221 000265 000672 000752 000481 000384 000236 000125 000892 000238 000639 000914 000432 000489 000513 000114 000719 000643 000994 000712 000242 001207 000201 001312 000524 000198 000543 000028 001323 000331 000222 000869 000692 000842 000795 001078 000896 000052 001256 000292 000070 000443 000343 000504 000203 000204 000943 000886 000337 000146 000778 001051 000121 000196 000647 001182 000972 000750 000568 000585 000925 001295 000280 000741 000360 000969 000042 000087 000975 000877 000942 000998 000356 000697 000403 000208 000514 000401 000248 000758 001192 001282 000876 000764 000445 000865 000564 000149 000584 001019 001017 000308 001318 000018 000056 000232 000777 000328 001217 001219 000577 000383 001276 000929 000354 000029 001118 000006 000105 000486 000074 000275 001189 000612 000008 001093 000034 000949 000039 001138 000677 001198 000696 001230 000590 001243 000323 001032 000488 000413 000700 000854 000217 000101 000518 000578 000950 000031 000888 000256 000419 000035 000872 001147 001033 000991 001113 000884 001292 000235 
000185 000404 001071 000340 000387 000990 000254 000953 000770 000241 000559 000357 000905 000068 000097 000108 000618 000982 000743 000279 000824 001058 000681 001196 001275 000760 000089 000054 000624 000987 001179 000338 000206 000177 000565 001134 000130 000952 001082 000136 000551 000288 000619 000852 001165 000321 001231 000501 000393 000372 000512 000424 001211 000977 000580 000626 000218 000451 000566 000165 000964 000407 000517 000572 001027 001106 000662 000102 000776 000160 000072 001062 000899 001091 000024 000304 000753 000594 000769 000464 000044 000007 000438 001313 000069 001187 000993 000161 000505 001290 000546 000258 001064 001228 001115 000674 001293 001269 001172 001247 000667 000641 000172 000503 000181 000726 000033 000437 001263 000989 000376 000959 000491 001090 000530 001104 000796 000290 000179 000868 000264 000327 000657 000442 001035 001136 000251 000342 000098 001166 000829 001049 000104 000402 000339 000427 000078 001014 000139 001152 000479 000175 000435 001072 000246 000414 000638 000904 001321 000396 000094 000805 000005 001316 001046 000586 001226 000079 000725 001296 000827 000103 001099 001205 000368 000278 000190 000544 001260 000997 000431 000919 001197 001030 001173 000454 001157 000164 001037 000077 000642 000828 000675 000702 000310 000797 000739 000123 001270 001170 001117 000857 000477 000180 001130 000958 000855 001109 001009 000199 001203 000676 000132 000441 001261 000878 000283 000408 000244 000322 001105 000932 XDU514 XDU646 XDU904 XDU660 XDU347 XDU962 XDU92 XDU838 XDU907 XDU496 XDU83 XDU606 XDU307 XDU138 XDU357 XDU993 XDU693 XDU493 XDU891 XDU410 XDU288 XDU562 XDU849 XDU23 XDU199 XDU370 XDU537 XDU871 XDU656 XDU331 XDU328 XDU403 XDU230 XDU529 XDU229 XDU476 XDU792 XDU412 XDU689 XDU51 XDU532 XDU356 XDU303 XDU161 XDU879 XDU867 XDU773 XDU323 XDU836 XDU236 XDU749 XDU807 XDU72 XDU128 XDU822 XDU480 XDU270 XDU651 XDU815 XDU30 XDU548 XDU386 XDU555 XDU122 XDU798 XDU264 XDU725 XDU806 XDU440 XDU332 XDU875 XDU325 XDU486 XDU659 
XDU835 XDU335 XDU330 XDU132 XDU89 XDU580 XDU8 XDU4 XDU280 XDU895 XDU869 XDU799 XDU419 XDU772 XDU896 XDU604 XDU116 XDU85 XDU91 XDU158 XDU130 XDU611 XDU98 XDU299 XDU114 XDU923 XDU94 XDU800 XDU934 XDU998 XDU338 XDU959 XDU712 XDU754 XDU636 XDU624 XDU918 XDU26 XDU559 XDU761 XDU909 XDU340 XDU262 XDU964 XDU29 XDU443 XDU929 XDU739 XDU654 XDU858 XDU551 XDU365 XDU665 XDU705 XDU434 XDU105 XDU887 XDU910 XDU261 XDU360 XDU912 XDU894 XDU727 XDU42 XDU556 XDU613 XDU285 XDU668 XDU592 XDU341 XDU942 XDU982 XDU191 XDU528 XDU720 XDU601 XDU531 XDU21 XDU275 XDU232 XDU344 XDU301 XDU977 XDU390 XDU938 XDU975 XDU688 XDU979 XDU58 XDU494 XDU988 XDU826 XDU571 XDU1 XDU779 XDU633 XDU987 XDU449 XDU313 XDU153 XDU718 XDU978 XDU538 XDU950 XDU740 XDU986 XDU780 XDU965 XDU900 XDU423 XDU194 XDU490 XDU73 XDU675 XDU254 XDU547 XDU983 XDU227 XDU68 XDU445 XDU903 XDU652 XDU642 XDU599 XDU64 XDU221 XDU291 XDU460 XDU623 XDU766 XDU680 XDU714 XDU25 XDU19 XDU65 XDU671 XDU333 XDU375 XDU550 XDU790 XDU397 XDU497 XDU645 XDU470 XDU913 XDU956 XDU218 XDU380 XDU487 XDU66 XDU324 XDU612 XDU384 XDU544 XDU144 XDU542 XDU248 XDU461 XDU148 XDU653 XDU336 XDU866 XDU456 XDU540 XDU351 XDU112 XDU272 XDU53 XDU515 XDU699 XDU417 XDU639 XDU342 XDU31 XDU448 XDU292 XDU728 XDU180 XDU149 XDU706 XDU973 XDU320 XDU625 XDU811 XDU462 XDU388 XDU638 XDU90 XDU109 XDU722 XDU162 XDU972 XDU767 XDU349 XDU263 XDU465 XDU576 XDU507 XDU644 XDU587 XDU255 XDU326 XDU500 XDU586 XDU524 XDU765 XDU890 XDU960 XDU594 XDU80 XDU14 XDU569 XDU953 XDU884 XDU282 XDU387 XDU579 XDU260 XDU252 XDU971 XDU905 XDU994 XDU36 XDU266 XDU405 XDU208 XDU207 XDU723 XDU35 XDU590 XDU678 XDU629 XDU939 XDU804 XDU948 XDU24 XDU967 XDU293 XDU545 XDU901 XDU290 XDU989 XDU322 XDU707 XDU188 XDU582 XDU810 XDU439 XDU300 XDU237 XDU457 XDU433 XDU831 XDU917 XDU677 XDU82 XDU561 XDU413 XDU834 XDU368 XDU658 XDU898 XDU173 XDU34 XDU467 XDU2 XDU265 XDU735 XDU454 XDU163 XDU81 XDU74 XDU140 XDU166 XDU478 XDU61 XDU880 XDU841 XDU565 XDU377 XDU839 XDU691 XDU607 XDU530 XDU844 XDU847 XDU472 XDU882 XDU17 XDU619 XDU12 
XDU736 XDU859 XDU186 XDU985 XDU389 XDU921 XDU355 XDU539 XDU141 XDU916 XDU135 XDU519 XDU621 XDU435 XDU760 XDU783 XDU591 XDU628 XDU464 XDU649 XDU198 XDU560 XDU372 XDU553 XDU458 XDU183 XDU650 XDU622 XDU825 XDU471 XDU190 XDU414 XDU44 XDU741 XDU635 XDU647 XDU573 XDU864 XDU618 XDU364 XDU543 XDU437 XDU502 XDU824 XDU952 XDU125 XDU641 XDU491 XDU201 XDU947 XDU444 XDU79 XDU518 XDU32 XDU283 XDU802 XDU483 XDU59 XDU477 XDU670 XDU234 XDU577 XDU969 XDU669 XDU995 XDU425 XDU506 XDU970 XDU681 XDU353 XDU676 XDU958 XDU782 XDU121 XDU363 XDU411 XDU690 XDU392 XDU536 XDU752 XDU210 XDU774 XDU350 XDU731 XDU930 XDU479 XDU602 XDU856 XDU96 XDU520 XDU13 XDU915 XDU106 XDU853 XDU308 XDU47 XDU769 XDU242 XDU899 XDU484 XDU581 XDU643 XDU827 XDU246 XDU119 XDU746 XDU100 XDU311 XDU512 XDU852 XDU509 XDU874 XDU686 XDU139 XDU928 XDU719 XDU296 XDU750 XDU713 XDU156 XDU634 XDU6 XDU821 XDU488 XDU420 XDU748 XDU126 XDU474 XDU273 XDU742 XDU438 XDU20 XDU27 XDU452 XDU541 XDU598 XDU949 XDU992 XDU513 XDU155 XDU228 XDU206 XDU69 XDU666 XDU860 XDU136 XDU617 XDU436 XDU716 XDU823 XDU38 XDU1000 XDU851 XDU627 XDU974 XDU717 XDU734 XDU796 XDU481 XDU990 XDU679 XDU764 XDU238 XDU848 XDU908 XDU418 XDU696 XDU378 XDU724 XDU632 XDU182 XDU76 XDU791 XDU830 XDU814 XDU306 XDU931 XDU154 XDU564 XDU383 XDU473 XDU84 XDU143 XDU18 XDU683 XDU495 XDU78 XDU840 XDU382 XDU385 XDU233 XDU220 XDU616 XDU655 XDU797 XDU854 XDU312 XDU924 XDU593 XDU150 XDU60 XDU951 XDU812 XDU608 XDU408 XDU184 XDU552 XDU172 XDU820 XDU584 XDU314 XDU702 XDU174 XDU379 XDU511 XDU914 XDU214 XDU315 XDU535 XDU305 XDU294 XDU71 XDU534 XDU133 XDU204 XDU991 XDU475 XDU870 XDU793 XDU585 XDU431 XDU786 XDU944 XDU102 XDU245 XDU857 XDU427 XDU738 XDU726 XDU568 XDU843 XDU131 XDU298 XDU498 XDU837 XDU243 XDU588 XDU832 XDU526 XDU710 XDU177 XDU310 XDU165 XDU603 XDU99 XDU22 XDU615 XDU566 XDU286 XDU703 XDU361 XDU795 XDU277 XDU508 XDU521 XDU297 XDU317 XDU861 XDU271 XDU318 XDU572 XDU247 XDU202 XDU946 XDU466 XDU732 XDU226 XDU583 XDU120 XDU926 XDU401 XDU687 XDU984 XDU819 XDU664 XDU400 XDU446 XDU453 
XDU730 XDU362 XDU62 XDU175 XDU809 XDU430 XDU124 XDU346 XDU776 XDU605 XDU187 XDU337 XDU211 XDU684 XDU179 XDU981 XDU784 XDU701 XDU358 XDU768 XDU911 XDU235 XDU215 XDU145 XDU609 XDU281 XDU432 XDU196 XDU499 XDU250 XDU304 XDU600 XDU309 XDU171 XDU787 XDU595 XDU808 XDU7 XDU846 XDU428 XDU287 XDU729 XDU213 XDU828 XDU941 XDU395 XDU756 XDU897 XDU239 XDU610 XDU251 XDU373 XDU533 XDU95 XDU57 XDU945 XDU222 XDU168 XDU137 XDU961 XDU906 XDU937 XDU0 XDU770 XDU268 XDU963 XDU113 XDU771 XDU763 XDU339 XDU52 XDU737 XDU755 XDU159 XDU626 XDU16 XDU118 XDU77 XDU574 XDU402 XDU407 XDU88 XDU778 XDU391 XDU15 XDU940 XDU886 XDU359 XDU424 XDU721 XDU399 XDU345 XDU157 XDU107 XDU873 XDU865 XDU39 XDU893 XDU976 XDU695 XDU367 XDU700 XDU422 XDU936 XDU123 XDU503 XDU366 XDU101 XDU631 XDU276 XDU549 XDU212 XDU197 XDU640 XDU200 XDU37 XDU469 XDU522 XDU575 XDU256 XDU409 XDU152 XDU224 XDU86 XDU630 XDU980 XDU813 XDU70 XDU249 XDU396 XDU11 XDU327 XDU269 XDU284 XDU757 XDU348 XDU554 XDU127 XDU267 XDU751 XDU181 XDU468 XDU274 XDU55 XDU510 XDU170 XDU744

================================================
FILE: metrics.py
================================================
import numpy as np
import torch
from skimage import measure


class ROCMetric():
    """Accumulates per-threshold TP / FP statistics for ROC and precision-recall curves.

    NOTE: cal_tp_pos_fp_neg() is defined in test.py, not in this module.
    """

    def __init__(self, nclass, bins):
        # bins: number of discrete score thresholds sampled along the ROC curve
        # nclass: number of classes (infrared small target detection has only one)
        super(ROCMetric, self).__init__()
        self.nclass = nclass
        self.bins = bins
        self.tp_arr = np.zeros(self.bins + 1)
        self.pos_arr = np.zeros(self.bins + 1)
        self.fp_arr = np.zeros(self.bins + 1)
        self.neg_arr = np.zeros(self.bins + 1)
        self.class_pos = np.zeros(self.bins + 1)
        # self.reset()

    # Accumulate statistics between the network predictions and the labels
    def update(self, preds, labels):
        for iBin in range(self.bins + 1):
            # score_thresh = (iBin + 0.0) / self.bins
            score_thresh = -30 + iBin * (255 / self.bins)
            # print(iBin, "-th, score_thresh: ", score_thresh)
            i_tp, i_pos, i_fp, i_neg, i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass, score_thresh)
            self.tp_arr[iBin] += i_tp
            self.pos_arr[iBin] += i_pos
            self.fp_arr[iBin] += i_fp
            self.neg_arr[iBin] += i_neg
            self.class_pos[iBin] += i_class_pos

    def get(self):
        tp_rates = self.tp_arr / (self.pos_arr + 0.001)  # tp_rates = recall = TP / (TP + FN)
        fp_rates = self.fp_arr / (self.neg_arr + 0.001)  # fp_rates = FP / (FP + TN)
        FP = self.fp_arr / (self.neg_arr + self.pos_arr)
        recall = self.tp_arr / (self.pos_arr + 0.001)  # recall = TP / (TP + FN)
        precision = self.tp_arr / (self.class_pos + 0.001)  # precision = TP / (TP + FP)
        return tp_rates, fp_rates, recall, precision, FP

    def reset(self):
        # Use the same shapes as __init__ (one slot per threshold)
        self.tp_arr = np.zeros(self.bins + 1)
        self.pos_arr = np.zeros(self.bins + 1)
        self.fp_arr = np.zeros(self.bins + 1)
        self.neg_arr = np.zeros(self.bins + 1)
        self.class_pos = np.zeros(self.bins + 1)


class mIoU():

    def __init__(self):
        super(mIoU, self).__init__()
        self.reset()

    def update(self, preds, labels):
        correct, labeled = batch_pix_accuracy(preds, labels)
        inter, union = batch_intersection_union(preds, labels)
        self.total_correct += correct  # number of correctly predicted pixels
        self.total_label += labeled  # number of ground-truth target pixels
        self.total_inter += inter  # intersection
        self.total_union += union  # union

    def get(self):
        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        mIoU = IoU.mean()
        return float(pixAcc), mIoU

    def reset(self):
        self.total_inter = 0
        self.total_union = 0
        self.total_correct = 0
        self.total_label = 0


class PD_FA():
    def __init__(self):
        super(PD_FA, self).__init__()
        self.image_area_total = []
        self.image_area_match = []
        self.dismatch_pixel = 0
        self.all_pixel = 0
        self.PD = 0
        self.target = 0

    def update(self, preds, labels, size):
        predits = np.array((preds).cpu()).astype('int64')
        labelss = np.array((labels).cpu()).astype('int64')

        image = measure.label(predits, connectivity=2)
        coord_image = measure.regionprops(image)
        label = measure.label(labelss, connectivity=2)
        coord_label = measure.regionprops(label)

        self.target += len(coord_label)  # total number of targets = number of connected components in the GT
        self.image_area_total = []  # areas of the predicted regions in this image
        self.image_area_match = []
        self.distance_match = []
        self.dismatch = []

        for K in range(len(coord_image)):
            area_image = np.array(coord_image[K].area)
            self.image_area_total.append(area_image)

        # Match predicted and GT connected components by centroid distance
        for i in range(len(coord_label)):
            centroid_label = np.array(list(coord_label[i].centroid))
            for m in range(len(coord_image)):
                centroid_image = np.array(list(coord_image[m].centroid))
                distance = np.linalg.norm(centroid_image - centroid_label)
                area_image = np.array(coord_image[m].area)
                if distance < 3:
                    self.distance_match.append(distance)
                    self.image_area_match.append(area_image)

                    del coord_image[m]  # remove the matched prediction so it cannot be matched again
                    break

        # Predicted regions that were not matched to any GT target
        self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match]
        self.dismatch_pixel += np.sum(self.dismatch)  # false-alarm (Fa) pixels
        self.all_pixel += size[0] * size[1]
        # A target counts as detected when a prediction centroid lies within 3 px of it,
        # so PD accumulates the number of matched targets
        self.PD += len(self.distance_match)

    def get(self):
        Final_FA = self.dismatch_pixel / self.all_pixel
        Final_PD = self.PD / self.target
        return Final_PD, float(Final_FA.cpu().detach().numpy())

    def reset(self):
        # Reset the same accumulators that __init__ creates
        self.image_area_total = []
        self.image_area_match = []
        self.dismatch_pixel = 0
        self.all_pixel = 0
        self.PD = 0
        self.target = 0


def batch_pix_accuracy(output, target):
    if len(target.shape) == 3:
        target = target.unsqueeze(1).float()  # (B, H, W) -> (B, 1, H, W)
    elif len(target.shape) == 4:
        target = target.float()
    else:
        raise ValueError("Unknown target dimension")

    assert output.shape == target.shape, "Predict and Label Shape Don't Match"
    predict = (output > 0).float()  # convert the boolean mask to 1 / 0
    pixel_labeled = (target > 0).float().sum()  # number of positive pixels in the GT
    pixel_correct = (((predict == target).float()) * ((target > 0)).float()).sum()  # correctly predicted positive pixels
    assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled


def batch_intersection_union(output, target):
    mini = 1
    maxi = 1
    nbins = 1
    predict = (output > 0).float()
    if len(target.shape) == 3:
        target = target.unsqueeze(1).float()  # (B, H, W) -> (B, 1, H, W)
    elif len(target.shape) == 4:
        target = target.float()
    else:
        raise ValueError("Unknown target dimension")
    intersection = predict * ((predict == target).float())

    area_inter, _ = np.histogram(intersection.cpu(), bins=nbins, range=(mini, maxi))
    area_pred, _ = np.histogram(predict.cpu(), bins=nbins, range=(mini, maxi))
    area_lab, _ = np.histogram(target.cpu(), bins=nbins, range=(mini, maxi))
    area_union = area_pred + area_lab - area_inter

    assert (area_inter <= area_union).all(), \
        "Error: Intersection area should be smaller than Union area"

    return area_inter, area_union

================================================
FILE: model/Config.py
================================================
# -*- coding: utf-8 -*-
# @Author  : Shuai Yuan
# @File    : Config.py
# @Software: PyCharm
# coding=utf-8
import os
import torch
import time
import ml_collections


##########################################################################
# SCTrans configs
##########################################################################
def get_SCTrans_config():
    config = ml_collections.ConfigDict()
    config.transformer = ml_collections.ConfigDict()
    config.KV_size = 480  # KV_size = C1 + C2 + C3 + C4 = 32 + 64 + 128 + 256
    config.transformer.num_heads = 4
    config.transformer.num_layers = 4
    config.patch_sizes = [16, 8, 4, 2]
    config.base_channel = 32  # base channel of U-Net
    config.n_classes = 1

    # ********** unused **********
    config.transformer.embeddings_dropout_rate = 0.1
    config.transformer.attention_dropout_rate = 0.1
    config.transformer.dropout_rate = 0
    return config

================================================
FILE: model/SCTransNet.py
================================================
# -*- coding: utf-8 -*-
# @Author  : Shuai Yuan
# @File    : SCTransNet.py
# @Software: PyCharm
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import math
from torch.nn import Dropout, Softmax, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
import torch.nn as nn
import torch
import torch.nn.functional as F
import ml_collections
from einops import rearrange
import numbers
from thop import profile


def get_CTranS_config():
    config = ml_collections.ConfigDict()
    config.transformer = ml_collections.ConfigDict()
    config.KV_size = 480  # KV_size = C1 + C2 + C3 + C4 = 32 + 64 + 128 + 256
    config.transformer.num_heads = 4
    config.transformer.num_layers = 4
    config.patch_sizes = [16, 8, 4, 2]
    config.base_channel = 32  # base channel of U-Net
    config.n_classes = 1

    # ********** unused **********
    config.transformer.embeddings_dropout_rate = 0.1
    config.transformer.attention_dropout_rate = 0.1
    config.transformer.dropout_rate = 0
    return config


class Channel_Embeddings(nn.Module):
    def __init__(self, config, patchsize, img_size, in_channels):
        super().__init__()
        img_size = _pair(img_size)
        patch_size = _pair(patchsize)
        # e.g. (256 // 16) ** 2 = 256 patches at the first level with the default config
        n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])

        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=in_channels,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        # position_embeddings and dropout are not used in forward()
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels))
        self.dropout = Dropout(config.transformer["embeddings_dropout_rate"])

    def forward(self, x):
        if x is None:
            return None
        x = self.patch_embeddings(x)  # (B, C, H / p, W / p)
        return x


class Reconstruct(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
        super(Reconstruct, self).__init__()
        if kernel_size == 3:
            padding = 1
        else:
            padding = 0
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor

    # def forward(self, x, h, w):
    def forward(self, x):
        if x is None:
            return None

        x = nn.Upsample(scale_factor=self.scale_factor, mode='bilinear')(x)

        out = self.conv(x)
        out = self.norm(out)
        out = self.activation(out)
        return out


# spatial-embedded Single-head Channel-cross Attention (SSCA)
class Attention_org(nn.Module):
    def __init__(self, config,
                 vis, channel_num):
        super(Attention_org, self).__init__()
        self.vis = vis
        self.KV_size = config.KV_size
        self.channel_num = channel_num
        self.num_attention_heads = 1
        self.psi = nn.InstanceNorm2d(self.num_attention_heads)
        self.softmax = Softmax(dim=3)
        # self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.mhead1 = nn.Conv2d(channel_num[0], channel_num[0] * self.num_attention_heads, kernel_size=1, bias=False)
        self.mhead2 = nn.Conv2d(channel_num[1], channel_num[1] * self.num_attention_heads, kernel_size=1, bias=False)
        self.mhead3 = nn.Conv2d(channel_num[2], channel_num[2] * self.num_attention_heads, kernel_size=1, bias=False)
        self.mhead4 = nn.Conv2d(channel_num[3], channel_num[3] * self.num_attention_heads, kernel_size=1, bias=False)
        self.mheadk = nn.Conv2d(self.KV_size, self.KV_size * self.num_attention_heads, kernel_size=1, bias=False)
        self.mheadv = nn.Conv2d(self.KV_size, self.KV_size * self.num_attention_heads, kernel_size=1, bias=False)

        self.q1 = nn.Conv2d(channel_num[0] * self.num_attention_heads, channel_num[0] * self.num_attention_heads,
                            kernel_size=3, stride=1, padding=1,
                            groups=channel_num[0] * self.num_attention_heads // 2, bias=False)
        self.q2 = nn.Conv2d(channel_num[1] * self.num_attention_heads, channel_num[1] * self.num_attention_heads,
                            kernel_size=3, stride=1, padding=1,
                            groups=channel_num[1] * self.num_attention_heads // 2, bias=False)
        self.q3 = nn.Conv2d(channel_num[2] * self.num_attention_heads, channel_num[2] * self.num_attention_heads,
                            kernel_size=3, stride=1, padding=1,
                            groups=channel_num[2] * self.num_attention_heads // 2, bias=False)
        self.q4 = nn.Conv2d(channel_num[3] * self.num_attention_heads, channel_num[3] * self.num_attention_heads,
                            kernel_size=3, stride=1, padding=1,
                            groups=channel_num[3] * self.num_attention_heads // 2, bias=False)
        self.k = nn.Conv2d(self.KV_size * self.num_attention_heads, self.KV_size * self.num_attention_heads,
                           kernel_size=3, stride=1, padding=1,
                           groups=self.KV_size * self.num_attention_heads, bias=False)
        self.v = nn.Conv2d(self.KV_size * self.num_attention_heads, self.KV_size * self.num_attention_heads,
                           kernel_size=3, stride=1, padding=1,
                           groups=self.KV_size * self.num_attention_heads, bias=False)

        self.project_out1 = nn.Conv2d(channel_num[0], channel_num[0], kernel_size=1, bias=False)
        self.project_out2 = nn.Conv2d(channel_num[1], channel_num[1], kernel_size=1, bias=False)
        self.project_out3 = nn.Conv2d(channel_num[2], channel_num[2], kernel_size=1, bias=False)
        self.project_out4 = nn.Conv2d(channel_num[3], channel_num[3], kernel_size=1, bias=False)

        # ****************** unused ***************************************
        self.q1_attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q1_attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q1_attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q1_attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q2_attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q2_attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q2_attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q2_attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q3_attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q3_attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q3_attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q3_attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q4_attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q4_attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q4_attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.q4_attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)

    def forward(self, emb1, emb2, emb3, emb4, emb_all):
        b, c, h, w = emb1.shape
        q1 = self.q1(self.mhead1(emb1))
        q2 = self.q2(self.mhead2(emb2))
        q3 = self.q3(self.mhead3(emb3))
        q4 = self.q4(self.mhead4(emb4))
        k = self.k(self.mheadk(emb_all))
        v = self.v(self.mheadv(emb_all))
        # k, v = kv.chunk(2, dim=1)

        q1 = rearrange(q1, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)
        q2 = rearrange(q2, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)
        q3 = rearrange(q3, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)
        q4 = rearrange(q4, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)

        q1 = torch.nn.functional.normalize(q1, dim=-1)
        q2 = torch.nn.functional.normalize(q2, dim=-1)
        q3 = torch.nn.functional.normalize(q3, dim=-1)
        q4 = torch.nn.functional.normalize(q4, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        _, _, c1, _ = q1.shape
        _, _, c2, _ = q2.shape
        _, _, c3, _ = q3.shape
        _, _, c4, _ = q4.shape
        _, _, c, _ = k.shape

        # Channel-cross attention: (B, head, C_i, HW) @ (B, head, HW, KV_size) -> (B, head, C_i, KV_size)
        attn1 = (q1 @ k.transpose(-2, -1)) / math.sqrt(self.KV_size)
        attn2 = (q2 @ k.transpose(-2, -1)) / math.sqrt(self.KV_size)
        attn3 = (q3 @ k.transpose(-2, -1)) / math.sqrt(self.KV_size)
        attn4 = (q4 @ k.transpose(-2, -1)) / math.sqrt(self.KV_size)

        attention_probs1 = self.softmax(self.psi(attn1))
        attention_probs2 = self.softmax(self.psi(attn2))
        attention_probs3 = self.softmax(self.psi(attn3))
        attention_probs4 = self.softmax(self.psi(attn4))

        out1 = (attention_probs1 @ v)
        out2 = (attention_probs2 @ v)
        out3 = (attention_probs3 @ v)
        out4 = (attention_probs4 @ v)

        out_1 = out1.mean(dim=1)
        out_2 = out2.mean(dim=1)
        out_3 = out3.mean(dim=1)
        out_4 = out4.mean(dim=1)

        out_1 = rearrange(out_1, 'b c (h w) -> b c h w', h=h, w=w)
        out_2 = rearrange(out_2, 'b c (h w) -> b c h w', h=h, w=w)
        out_3 = rearrange(out_3, 'b c (h w) -> b c h w', h=h, w=w)
        out_4 = rearrange(out_4, 'b c (h w) -> b c h w', h=h, w=w)

        O1 = self.project_out1(out_1)
        O2 = self.project_out2(out_2)
        O3 = self.project_out3(out_3)
        O4 = self.project_out4(out_4)
        weights = None

        return O1, O2, O3, O4, weights


def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm3d(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm3d, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


class eca_layer_2d(nn.Module):
    def __init__(self, channel, k_size=3):
        super(eca_layer_2d, self).__init__()
        padding = k_size // 2
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=1, kernel_size=k_size, padding=padding, bias=False),
            nn.Sigmoid()
        )
        self.channel = channel
        self.k_size = k_size

    def forward(self, x):
        out = self.avg_pool(x)
        out = out.view(x.size(0), 1, x.size(1))
        out = self.conv(out)
        out = out.view(x.size(0), x.size(1), 1, 1)
        return out * x


# Complementary Feed-forward Network (CFN)
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv3x3 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1,
                                   groups=hidden_features, bias=bias)
        self.dwconv5x5 = nn.Conv2d(hidden_features, hidden_features, kernel_size=5, stride=1, padding=2,
                                   groups=hidden_features, bias=bias)
        self.relu3 = nn.ReLU()
        self.relu5 = nn.ReLU()
        self.project_out = nn.Conv2d(hidden_features * 2, dim, kernel_size=1, bias=bias)
        self.eca = eca_layer_2d(dim)

    def forward(self, x):
        x_3, x_5 = self.project_in(x).chunk(2, dim=1)
        x1_3 = self.relu3(self.dwconv3x3(x_3))
        x1_5 = self.relu5(self.dwconv5x5(x_5))
        x = torch.cat([x1_3, x1_5], dim=1)
        x = self.project_out(x)
        x = self.eca(x)
        return x


# Spatial-channel Cross Transformer Block (SCTB)
class Block_ViT(nn.Module):
    def __init__(self, config, vis, channel_num):
        super(Block_ViT, self).__init__()
        self.attn_norm1 = LayerNorm3d(channel_num[0], LayerNorm_type='WithBias')
        self.attn_norm2 = LayerNorm3d(channel_num[1], LayerNorm_type='WithBias')
        self.attn_norm3 = LayerNorm3d(channel_num[2], LayerNorm_type='WithBias')
        self.attn_norm4 = LayerNorm3d(channel_num[3], LayerNorm_type='WithBias')
        self.attn_norm = LayerNorm3d(config.KV_size, LayerNorm_type='WithBias')

        self.channel_attn = Attention_org(config, vis, channel_num)

        self.ffn_norm1 = LayerNorm3d(channel_num[0], LayerNorm_type='WithBias')
        self.ffn_norm2 = LayerNorm3d(channel_num[1], LayerNorm_type='WithBias')
        self.ffn_norm3 = LayerNorm3d(channel_num[2], LayerNorm_type='WithBias')
        self.ffn_norm4 = LayerNorm3d(channel_num[3], LayerNorm_type='WithBias')

        self.ffn1 = FeedForward(channel_num[0], ffn_expansion_factor=2.66, bias=False)
        self.ffn2 = FeedForward(channel_num[1], ffn_expansion_factor=2.66, bias=False)
        self.ffn3 = FeedForward(channel_num[2], ffn_expansion_factor=2.66, bias=False)
        self.ffn4 = FeedForward(channel_num[3], ffn_expansion_factor=2.66, bias=False)

    def forward(self, emb1, emb2, emb3, emb4):
        embcat = []
        org1 = emb1
        org2 = emb2
        org3 = emb3
        org4 = emb4
        # Concatenate the non-None scale embeddings along the channel axis
        for tmp_var in (emb1, emb2, emb3, emb4):
            if tmp_var is not None:
                embcat.append(tmp_var)

        emb_all = torch.cat(embcat, dim=1)
        cx1 = self.attn_norm1(emb1) if emb1 is not None else None
        cx2 = self.attn_norm2(emb2) if emb2 is not None else None
        cx3 = self.attn_norm3(emb3) if emb3 is not None else None
        cx4 = self.attn_norm4(emb4) if emb4 is not None else None
        emb_all = self.attn_norm(emb_all)  # (B, KV_size, h, w)
        cx1, cx2, cx3, cx4, weights = self.channel_attn(cx1, cx2, cx3, cx4, emb_all)
        cx1 = org1 + cx1 if emb1 is not None else None
        cx2 = org2 + cx2 if emb2 is not None else None
        cx3 = org3 + cx3 if emb3 is not None else None
        cx4 = org4 + cx4 if emb4 is not None else None

        org1 = cx1
        org2 = cx2
        org3 = cx3
        org4 = cx4
        x1 = self.ffn_norm1(cx1) if emb1 is not None else None
        x2 = self.ffn_norm2(cx2) if emb2 is not None else None
        x3 = self.ffn_norm3(cx3) if emb3 is not None else None
        x4 = self.ffn_norm4(cx4) if emb4 is not None else None
        x1 = self.ffn1(x1) if emb1 is not None else None
        x2 = self.ffn2(x2) if emb2 is not None else None
        x3 = self.ffn3(x3) if emb3 is not None else None
        x4 = self.ffn4(x4) if emb4 is not None else None
        x1 = x1 + org1 if emb1 is not None else None
        x2 = x2 + org2 if emb2 is not None else None
        x3 = x3 + org3 if emb3 is not None else None
        x4 = x4 + org4 if emb4 is not None else None

        return x1, x2, x3, x4, weights


class Encoder(nn.Module):
    def __init__(self, config, vis, channel_num):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm1 = LayerNorm3d(channel_num[0],
                                         LayerNorm_type='WithBias')
        self.encoder_norm2 = LayerNorm3d(channel_num[1], LayerNorm_type='WithBias')
        self.encoder_norm3 = LayerNorm3d(channel_num[2], LayerNorm_type='WithBias')
        self.encoder_norm4 = LayerNorm3d(channel_num[3], LayerNorm_type='WithBias')
        for _ in range(config.transformer["num_layers"]):
            layer = Block_ViT(config, vis, channel_num)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, emb1, emb2, emb3, emb4):
        attn_weights = []
        for layer_block in self.layer:
            emb1, emb2, emb3, emb4, weights = layer_block(emb1, emb2, emb3, emb4)
            if self.vis:
                attn_weights.append(weights)
        emb1 = self.encoder_norm1(emb1) if emb1 is not None else None
        emb2 = self.encoder_norm2(emb2) if emb2 is not None else None
        emb3 = self.encoder_norm3(emb3) if emb3 is not None else None
        emb4 = self.encoder_norm4(emb4) if emb4 is not None else None
        return emb1, emb2, emb3, emb4, attn_weights


class ChannelTransformer(nn.Module):
    def __init__(self, config, vis, img_size, channel_num=[64, 128, 256, 512], patchSize=[32, 16, 8, 4]):
        super().__init__()

        self.patchSize_1 = patchSize[0]
        self.patchSize_2 = patchSize[1]
        self.patchSize_3 = patchSize[2]
        self.patchSize_4 = patchSize[3]
        self.embeddings_1 = Channel_Embeddings(config, self.patchSize_1, img_size=img_size, in_channels=channel_num[0])
        self.embeddings_2 = Channel_Embeddings(config, self.patchSize_2, img_size=img_size // 2, in_channels=channel_num[1])
        self.embeddings_3 = Channel_Embeddings(config, self.patchSize_3, img_size=img_size // 4, in_channels=channel_num[2])
        self.embeddings_4 = Channel_Embeddings(config, self.patchSize_4, img_size=img_size // 8, in_channels=channel_num[3])
        self.encoder = Encoder(config, vis, channel_num)

        self.reconstruct_1 = Reconstruct(channel_num[0], channel_num[0], kernel_size=1,
                                         scale_factor=(self.patchSize_1, self.patchSize_1))
        self.reconstruct_2 = Reconstruct(channel_num[1], channel_num[1], kernel_size=1,
                                         scale_factor=(self.patchSize_2, self.patchSize_2))
        self.reconstruct_3 = Reconstruct(channel_num[2],
                                         channel_num[2], kernel_size=1,
                                         scale_factor=(self.patchSize_3, self.patchSize_3))
        self.reconstruct_4 = Reconstruct(channel_num[3], channel_num[3], kernel_size=1,
                                         scale_factor=(self.patchSize_4, self.patchSize_4))

    def forward(self, en1, en2, en3, en4):
        emb1 = self.embeddings_1(en1)
        emb2 = self.embeddings_2(en2)
        emb3 = self.embeddings_3(en3)
        emb4 = self.embeddings_4(en4)

        # each encoded_i keeps the patch-grid layout (B, C_i, H_i / p_i, W_i / p_i)
        encoded1, encoded2, encoded3, encoded4, attn_weights = self.encoder(emb1, emb2, emb3, emb4)
        x1 = self.reconstruct_1(encoded1) if en1 is not None else None
        x2 = self.reconstruct_2(encoded2) if en2 is not None else None
        x3 = self.reconstruct_3(encoded3) if en3 is not None else None
        x4 = self.reconstruct_4(encoded4) if en4 is not None else None

        x1 = x1 + en1 if en1 is not None else None
        x2 = x2 + en2 if en2 is not None else None
        x3 = x3 + en3 if en3 is not None else None
        x4 = x4 + en4 if en4 is not None else None

        return x1, x2, x3, x4, attn_weights


def get_activation(activation_type):
    # Look up the activation class by its exact torch.nn name (e.g. 'ReLU');
    # unknown names fall back to ReLU
    if hasattr(nn, activation_type):
        return getattr(nn, activation_type)()
    else:
        return nn.ReLU()


def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
    layers = []
    layers.append(CBN(in_channels, out_channels, activation))

    for _ in range(nb_Conv - 1):
        layers.append(CBN(out_channels, out_channels, activation))
    return nn.Sequential(*layers)


class CBN(nn.Module):
    def __init__(self, in_channels, out_channels, activation='ReLU'):
        super(CBN, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=3, padding=1)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = get_activation(activation)

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        return self.activation(out)


class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super(DownBlock, self).__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x):
        out = self.maxpool(x)
        return self.nConvs(out)


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class CCA(nn.Module):
    def __init__(self, F_g, F_x):
        super().__init__()
        self.mlp_x = nn.Sequential(
            Flatten(),
            nn.Linear(F_x, F_x))
        self.mlp_g = nn.Sequential(
            Flatten(),
            nn.Linear(F_g, F_x))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        avg_pool_x = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
        channel_att_x = self.mlp_x(avg_pool_x)
        avg_pool_g = F.avg_pool2d(g, (g.size(2), g.size(3)), stride=(g.size(2), g.size(3)))
        channel_att_g = self.mlp_g(avg_pool_g)
        channel_att_sum = (channel_att_x + channel_att_g) / 2.0
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        x_after_channel = x * scale
        out = self.relu(x_after_channel)
        return out


class UpBlock_attention(nn.Module):
    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2)
        self.coatt = CCA(F_g=in_channels // 2, F_x=in_channels // 2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x, skip_x):
        up = self.up(x)
        skip_x_att = self.coatt(g=up, x=skip_x)
        x = torch.cat([skip_x_att, up], dim=1)  # dim 1 is the channel dimension
        return self.nConvs(x)


class Res_block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(Res_block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # self.fca = FCA_Layer(out_channels)
        if stride != 1 or out_channels != in_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels))
        else:
            self.shortcut = None

    def forward(self, x):
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out


class SCTransNet(nn.Module):
    def __init__(self, config, n_channels=1, n_classes=1, img_size=256, vis=False, mode='train', deepsuper=True):
        super().__init__()
        self.vis = vis
        self.deepsuper = deepsuper
        print('Deep-Supervision:', deepsuper)
        self.mode = mode
        self.n_channels = n_channels
        self.n_classes = n_classes
        in_channels = config.base_channel  # base channel C (32 in the default config)
        block = Res_block
        self.pool = nn.MaxPool2d(2, 2)
        self.inc = self._make_layer(block, n_channels, in_channels)
        self.down_encoder1 = self._make_layer(block, in_channels, in_channels * 2, 1)  # C -> 2C
        self.down_encoder2 = self._make_layer(block, in_channels * 2, in_channels * 4, 1)  # 2C -> 4C
        self.down_encoder3 = self._make_layer(block, in_channels * 4, in_channels * 8, 1)  # 4C -> 8C
        self.down_encoder4 = self._make_layer(block, in_channels * 8, in_channels * 8, 1)  # 8C -> 8C
        self.mtc = ChannelTransformer(config, vis, img_size,
                                      channel_num=[in_channels, in_channels * 2, in_channels * 4, in_channels * 8],
                                      patchSize=config.patch_sizes)
        self.up_decoder4 = UpBlock_attention(in_channels * 16, in_channels * 4, nb_Conv=2)
        self.up_decoder3 = UpBlock_attention(in_channels * 8, in_channels * 2, nb_Conv=2)
        self.up_decoder2 = UpBlock_attention(in_channels * 4, in_channels, nb_Conv=2)
        self.up_decoder1 = UpBlock_attention(in_channels * 2, in_channels, nb_Conv=2)
        self.outc = nn.Conv2d(in_channels, n_classes, kernel_size=(1, 1), stride=(1, 1))

        if self.deepsuper:
            self.gt_conv5 = nn.Sequential(nn.Conv2d(in_channels * 8, 1, 1))
            self.gt_conv4 = nn.Sequential(nn.Conv2d(in_channels * 4, 1, 1))
            self.gt_conv3 = nn.Sequential(nn.Conv2d(in_channels * 2, 1, 1))
            self.gt_conv2 = nn.Sequential(nn.Conv2d(in_channels * 1, 1, 1))
            self.outconv = nn.Conv2d(5 * 1, 1, 1)

    def _make_layer(self, block, input_channels, output_channels,
                    num_blocks=1):
        layers = []
        layers.append(block(input_channels, output_channels))
        for i in range(num_blocks - 1):
            layers.append(block(output_channels, output_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder (C = config.base_channel; sizes relative to an H x W input)
        x1 = self.inc(x)  # C x H x W
        x2 = self.down_encoder1(self.pool(x1))  # 2C x H/2 x W/2
        x3 = self.down_encoder2(self.pool(x2))  # 4C x H/4 x W/4
        x4 = self.down_encoder3(self.pool(x3))  # 8C x H/8 x W/8
        d5 = self.down_encoder4(self.pool(x4))  # 8C x H/16 x W/16
        # Keep the encoder features for the residual connection around the transformer
        f1 = x1
        f2 = x2
        f3 = x3
        f4 = x4
        # Spatial-channel cross transformer over the four skip features
        x1, x2, x3, x4, att_weights = self.mtc(x1, x2, x3, x4)
        x1 = x1 + f1
        x2 = x2 + f2
        x3 = x3 + f3
        x4 = x4 + f4
        # Feature fusion (decoder)
        d4 = self.up_decoder4(d5, x4)
        d3 = self.up_decoder3(d4, x3)
        d2 = self.up_decoder2(d3, x2)
        out = self.outc(self.up_decoder1(d2, x1))

        # deep supervision
        if self.deepsuper:
            gt_5 = self.gt_conv5(d5)
            gt_4 = self.gt_conv4(d4)
            gt_3 = self.gt_conv3(d3)
            gt_2 = self.gt_conv2(d2)
            # Original deep supervision: upsample every side output to the input resolution
            gt5 = F.interpolate(gt_5, scale_factor=16, mode='bilinear', align_corners=True)
            gt4 = F.interpolate(gt_4, scale_factor=8, mode='bilinear', align_corners=True)
            gt3 = F.interpolate(gt_3, scale_factor=4, mode='bilinear', align_corners=True)
            gt2 = F.interpolate(gt_2, scale_factor=2, mode='bilinear', align_corners=True)
            d0 = self.outconv(torch.cat((gt2, gt3, gt4, gt5, out), 1))

            if self.mode == 'train':
                return (torch.sigmoid(gt5), torch.sigmoid(gt4), torch.sigmoid(gt3), torch.sigmoid(gt2),
                        torch.sigmoid(d0), torch.sigmoid(out))
            else:
                return torch.sigmoid(out)
        else:
            return torch.sigmoid(out)


if __name__ == '__main__':
    config_vit = get_CTranS_config()
    model = SCTransNet(config_vit, mode='train', deepsuper=True)
    inputs = torch.rand(1, 1, 256, 256)
    output = model(inputs)
    flops, params = profile(model, (inputs,))
    print("-" * 50)
    print('FLOPs = ' + str(flops / 1000 ** 3) + ' G')
    print('Params = ' + str(params / 1000 ** 2) + ' M')

================================================
FILE: test.py
================================================
import argparse
from torch.autograd \
import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import threading
from dataset import *
import time
from collections import OrderedDict
from model.SCTransNet import SCTransNet as SCTransNet
# from loss import *
import model.Config as config
import numpy as np
import torch
from skimage import measure


def cal_tp_pos_fp_neg(output, target, nclass, score_thresh):
    predict = (output > score_thresh).float()
    if len(target.shape) == 3:
        # Add a channel dimension so that target matches the size of output
        target = target.unsqueeze(dim=1).float()
    elif len(target.shape) == 4:
        target = target.float()
    else:
        raise ValueError("Unknown target dimension")

    # predict is now 1 wherever the score exceeds the threshold; target is the ground truth
    intersection = predict * ((predict == target).float())

    tp = intersection.sum()  # true positives: target pixels predicted as target
    fp = (predict * ((predict != target).float())).sum()  # false positives (false-alarm pixels)
    tn = ((1 - predict) * ((predict == target).float())).sum()  # true negatives: background predicted as background
    fn = (((predict != target).float()) * (1 - predict)).sum()  # false negatives: target pixels that were missed
    pos = tp + fn  # number of positive pixels in the label
    neg = fp + tn  # number of negative pixels in the label
    class_pos = tp + fp  # number of pixels predicted as positive

    return tp, pos, fp, neg, class_pos


class SamplewiseSigmoidMetric(object):
    """Computes the sample-wise nIoU metric score."""

    def __init__(self, nclass, score_thresh=0.5):
        self.nclass = nclass
        self.score_thresh = score_thresh
        self.lock = threading.Lock()
        self.reset()

    def update(self, preds, labels):
        """Updates the internal evaluation result.

        Parameters
        ----------
        labels : torch.Tensor or list of torch.Tensor
            The labels of the data.
        preds : torch.Tensor or list of torch.Tensor
            Predicted values.
        """

        def evaluate_worker(self, label, pred):
            inter_arr, union_arr = batch_intersection_union_n(
                pred, label, self.nclass, self.score_thresh)
            with self.lock:
                self.total_inter = np.append(self.total_inter, inter_arr)
                self.total_union = np.append(self.total_union, union_arr)

        if isinstance(preds, torch.Tensor):
            evaluate_worker(self, labels, preds)
        elif isinstance(preds, (list, tuple)):
            threads = [threading.Thread(target=evaluate_worker,
                                        args=(self, label, pred),
                                        )
                       for (label, pred) in zip(labels, preds)]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

    def get(self):
        """Gets the current evaluation result.

        Returns
        -------
        nIoU : float
            Mean of the per-sample IoU values.
        """
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        nIoU = IoU.mean()
        return nIoU

    def reset(self):
        """Resets the internal evaluation result to initial state."""
        self.total_inter = np.array([])
        self.total_union = np.array([])
        self.total_correct = np.array([])
        self.total_label = np.array([])


def batch_intersection_union_n(output, target, nclass, score_thresh):
    """Per-sample intersection and union areas, used for nIoU."""
    mini = 1
    maxi = 1  # nclass
    nbins = 1  # nclass

    outputnp = output.detach().cpu().numpy()
    # outputsig = F.sigmoid(output).detach().cpu().numpy()
    # outputsig = nd.sigmoid(output).asnumpy()
    # Note: the binarization threshold is fixed at 0.5 here; score_thresh is not used
    predict = (outputnp > 0.5).astype('int64')
    # predict = predict.detach().cpu().numpy()
    # predict = (output.asnumpy() > 0).astype('int64')  # P
    if len(target.shape) == 3:
        target = np.expand_dims(target.cpu().numpy(), axis=1).astype('int64')  # T
    elif len(target.shape) == 4:
        target = target.cpu().numpy().astype('int64')  # T
    else:
        raise ValueError("Unknown target dimension")
    intersection = predict * (predict == target)  # TP (intersection)

    num_sample = intersection.shape[0]
    area_inter_arr = np.zeros(num_sample)
    area_pred_arr = np.zeros(num_sample)
    area_lab_arr = np.zeros(num_sample)
    area_union_arr = np.zeros(num_sample)

    for b in range(num_sample):
        # areas of intersection and union
        area_inter, _ = np.histogram(intersection[b], bins=nbins,
                                     range=(mini, maxi))
        area_inter_arr[b] = area_inter

        area_pred, _ = np.histogram(predict[b], bins=nbins, range=(mini, maxi))
        area_pred_arr[b] = area_pred

        area_lab, _ = np.histogram(target[b], bins=nbins, range=(mini, maxi))
        area_lab_arr[b] = area_lab

        area_union = area_pred + area_lab - area_inter
        area_union_arr[b] = area_union

        assert (area_inter <= area_union).all(), \
            "Intersection area should be smaller than Union area"

    return area_inter_arr, area_union_arr


class ROCMetric05():
    """Accumulates TP/FP/TN/FN counts over bins + 1 score thresholds to compute
    ROC points, precision, recall and F1."""

    def __init__(self, nclass, bins):
        # bins sets how many discrete thresholds are sampled along the ROC curve
        # nclass: number of classes (infrared small target detection has only one)
        super(ROCMetric05, self).__init__()
        self.nclass = nclass
        self.bins = bins
        self.tp_arr = np.zeros(self.bins + 1)
        self.pos_arr = np.zeros(self.bins + 1)
        self.fp_arr = np.zeros(self.bins + 1)
        self.neg_arr = np.zeros(self.bins + 1)
        self.class_pos = np.zeros(self.bins + 1)
        # self.reset()

    # Accumulate statistics between the network outputs and the labels
    def update(self, preds, labels):
        for iBin in range(self.bins + 1):
            score_thresh = (0.0 + iBin) / self.bins
            # print(iBin, "-th, score_thresh: ", score_thresh)
            i_tp, i_pos, i_fp, i_neg, i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass, score_thresh)
            self.tp_arr[iBin] += i_tp
            self.pos_arr[iBin] += i_pos
            self.fp_arr[iBin] += i_fp  # false-alarm pixels
            self.neg_arr[iBin] += i_neg
            self.class_pos[iBin] += i_class_pos

    def get(self):
        tp_rates = self.tp_arr / (self.pos_arr + 0.001)  # TPR = recall = TP / (TP + FN)
        fp_rates = self.fp_arr / (self.neg_arr + 0.001)  # FPR = FP / (FP + TN)
        FP = self.fp_arr / (self.neg_arr + self.pos_arr)
        recall = self.tp_arr / (self.pos_arr + 0.001)  # recall = TP / (TP + FN)
        precision = self.tp_arr / (self.class_pos + 0.001)  # precision = TP / (TP + FP)
        # F1 is taken at bin index 5, i.e. threshold 0.5 when bins=10
        f1_score = (2.0 * recall[5] * precision[5]) / (recall[5] + precision[5] + 0.00001)

        return tp_rates, fp_rates, recall, precision, FP, f1_score

    def reset(self):
        self.tp_arr = np.zeros(self.bins + 1)
        self.pos_arr = np.zeros(self.bins + 1)
        self.fp_arr = np.zeros(self.bins + 1)
        self.neg_arr = np.zeros(self.bins + 1)
        self.class_pos = np.zeros(self.bins + 1)


class mIoU():

    def __init__(self):
        super(mIoU, self).__init__()
        self.reset()

    def update(self, preds, labels):
        # labeled: number of target pixels in the GT; correct: number of correctly predicted target pixels
        correct, labeled = batch_pix_accuracy(preds, labels)
        inter, union = batch_intersection_union(preds, labels)
        self.total_correct += correct
        self.total_label += labeled
        self.total_inter += inter
        self.total_union += union

    def get(self):
        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        mIoU = IoU.mean()
        return float(pixAcc), mIoU

    def reset(self):
        self.total_inter = 0
        self.total_union = 0
        self.total_correct = 0
        self.total_label = 0


class PDFA():
    def __init__(self, ):
        super(PDFA, self).__init__()
        self.image_area_total = []
        self.image_area_match = []
        self.dismatch_pixel = 0
        self.all_pixel = 0
        self.PD = 0
        self.target = 0

    def update(self, preds, labels, size):
        predits = np.array((preds).cpu()).astype('int64')
        labelss = np.array((labels).cpu()).astype('int64')

        image = measure.label(predits, connectivity=2)
        coord_image = measure.regionprops(image)
        label = measure.label(labelss, connectivity=2)
        coord_label = measure.regionprops(label)

        self.target += len(coord_label)
        self.image_area_total = []
        self.image_area_match = []
        self.distance_match = []
        self.dismatch = []

        for K in range(len(coord_image)):
            area_image = np.array(coord_image[K].area)
            self.image_area_total.append(area_image)

        for i in range(len(coord_label)):
            centroid_label = np.array(list(coord_label[i].centroid))
            for m in range(len(coord_image)):
                centroid_image = np.array(list(coord_image[m].centroid))
                distance = np.linalg.norm(centroid_image - centroid_label)
                area_image = np.array(coord_image[m].area)
                if distance < 3:
                    self.distance_match.append(distance)
                    self.image_area_match.append(area_image)

                    del coord_image[m]
                    break

        self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match]
        self.dismatch_pixel += \
            np.sum(self.dismatch)
        self.all_pixel += size[0] * size[1]
        self.PD += len(self.distance_match)

    def get(self):
        Final_FA = self.dismatch_pixel / self.all_pixel
        Final_PD = self.PD / self.target
        return Final_PD, float(Final_FA.cpu().detach().numpy())

    def reset(self):
        # Restore the initial accumulator state (this class has no `bins` attribute)
        self.image_area_total = []
        self.image_area_match = []
        self.dismatch_pixel = 0
        self.all_pixel = 0
        self.PD = 0
        self.target = 0


def batch_pix_accuracy(output, target):
    if len(target.shape) == 3:
        target = target.unsqueeze(dim=1).float()
    elif len(target.shape) == 4:
        target = target.float()
    else:
        raise ValueError("Unknown target dimension")

    assert output.shape == target.shape, "Predict and Label Shape Don't Match"
    predict = (output > 0).float()
    pixel_labeled = (target > 0).float().sum()
    pixel_correct = (((predict == target).float()) * ((target > 0)).float()).sum()
    assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled


def batch_intersection_union(output, target):
    mini = 1
    maxi = 1
    nbins = 1
    predict = (output > 0).float()
    if len(target.shape) == 3:
        target = target.unsqueeze(dim=1).float()
    elif len(target.shape) == 4:
        target = target.float()
    else:
        raise ValueError("Unknown target dimension")
    intersection = predict * ((predict == target).float())

    area_inter, _ = np.histogram(intersection.cpu(), bins=nbins, range=(mini, maxi))
    area_pred, _ = np.histogram(predict.cpu(), bins=nbins, range=(mini, maxi))
    area_lab, _ = np.histogram(target.cpu(), bins=nbins, range=(mini, maxi))
    area_union = area_pred + area_lab - area_inter

    assert (area_inter <= area_union).all(), \
        "Error: Intersection area should be smaller than Union area"
    return area_inter, area_union


class PD_FA():
    def __init__(self, ):
        super(PD_FA, self).__init__()
        self.image_area_total = []
        self.image_area_match = []
        self.dismatch_pixel = 0
        self.all_pixel = 0
        self.PD = 0
        self.target = 0

    def update(self, preds, labels, size):
        predits = np.array((preds).cpu()).astype('int64')
        labelss = np.array((labels).cpu()).astype('int64')

        image = \
            measure.label(predits, connectivity=2)
        coord_image = measure.regionprops(image)
        label = measure.label(labelss, connectivity=2)
        coord_label = measure.regionprops(label)

        self.target += len(coord_label)  # total number of targets = number of connected components in the GT
        self.image_area_total = []  # areas of the predicted regions in this image
        self.image_area_match = []
        self.distance_match = []
        self.dismatch = []

        for K in range(len(coord_image)):
            area_image = np.array(coord_image[K].area)
            self.image_area_total.append(area_image)

        for i in range(len(coord_label)):
            # Match predicted and GT connected components by their centroids
            centroid_label = np.array(list(coord_label[i].centroid))
            for m in range(len(coord_image)):
                centroid_image = np.array(list(coord_image[m].centroid))
                distance = np.linalg.norm(centroid_image - centroid_label)
                area_image = np.array(coord_image[m].area)
                if distance < 3:
                    self.distance_match.append(distance)
                    self.image_area_match.append(area_image)

                    del coord_image[m]  # once matched, remove this predicted region
                    break

        # Predicted regions not matched to any GT target (compared by area value)
        self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match]
        self.dismatch_pixel += np.sum(self.dismatch)  # Fa: number of false-alarm pixels
        # print(self.dismatch_pixel)
        self.all_pixel += size[0] * size[1]
        # A GT target counts as detected when a predicted centroid lies closer than 3 pixels,
        # so PD accumulates the number of matched targets
        self.PD += len(self.distance_match)

    def get(self):
        Final_FA = self.dismatch_pixel / self.all_pixel
        Final_PD = self.PD / self.target
        return Final_PD, float(Final_FA.cpu().detach().numpy())

    def reset(self):
        # Restore the initial accumulator state (this class has no `bins` attribute)
        self.image_area_total = []
        self.image_area_match = []
        self.dismatch_pixel = 0
        self.all_pixel = 0
        self.PD = 0
        self.target = 0


os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
parser = argparse.ArgumentParser(description="PyTorch BasicIRSTD test")
parser.add_argument('--ROC_thr', type=int, default=10, help='num')
parser.add_argument("--model_names", default=['SCTrans'], nargs='+',
                    help="model_name: 'ACM', 'Ours01', 'DNANet', 'ISNet', 'ACMNet', 'Ours01', 'ISTDU-Net', 'U-Net', 'RISTDnet'")
parser.add_argument("--pth_dirs", default=['SIRST3/SCTransNet_NUAA_NUDT_IRSTD1K.pth.tar'], nargs='+')
parser.add_argument("--dataset_dir", default=r'D:\05TGARS\Upload\datasets',
                    type=str, help="train_dataset_dir")
parser.add_argument("--dataset_names", default=['NUAA-SIRST', 'NUDT-SIRST', 'IRSTD-1K'], nargs='+',
                    help="dataset_name: 'NUAA-SIRST', 'NUDT-SIRST', 'IRSTD-1K', 'SIRST3', 'NUDT-SIRST-Sea'")
parser.add_argument("--img_norm_cfg", default=None, type=dict,
                    help="specify an img_norm_cfg, default=None (using img_norm_cfg values of each dataset)")
parser.add_argument("--save_img", action='store_true', help="whether to save the predicted images")
parser.add_argument("--save_img_dir", type=str, default=r'D:\SCI\01_02_SCI\Result/', help="path of saved image")
parser.add_argument("--save_log", type=str, default=r'D:\05TGARS\Upload\log/',
                    help="directory containing the .pth checkpoints; the test log is also written here")
parser.add_argument("--threshold", type=float, default=0.5)

global opt
opt = parser.parse_args()


def test():
    test_set = TestSetLoader(opt.dataset_dir, opt.train_dataset_name, opt.test_dataset_name, opt.img_norm_cfg)
    test_loader = DataLoader(dataset=test_set, num_workers=1, batch_size=1, shuffle=False)

    # ************************* Fixed threshold **********************
    # mIoU
    IOU = mIoU()
    # nIoU
    nIoU_metric = SamplewiseSigmoidMetric(nclass=1, score_thresh=0)
    # Pd / Fa
    eval_05 = PD_FA()
    ROC_05 = ROCMetric05(nclass=1, bins=10)

    config_vit = config.get_SCTrans_config()
    # CPU
    net = SCTransNet(config_vit, mode='test', deepsuper=True)
    state_dict = torch.load(opt.pth_dir, map_location='cpu')
    # # CUDA
    # net = SCTransNet(config_vit, mode='test', deepsuper=True).cuda()
    # state_dict = torch.load(opt.pth_dir)
    new_state_dict = OrderedDict()
    for k, v in state_dict['state_dict'].items():
        name = k[6:]  # strip the 6-character `model.` prefix added by the Net wrapper in train.py
        new_state_dict[name] = v  # keep each value paired with its renamed key
    net.load_state_dict(new_state_dict)
    net.eval()

    tbar = tqdm(test_loader)
    with torch.no_grad():
        for idx_iter, (img, gt_mask, size, img_dir) in enumerate(tbar):
            # img = Variable(img)
            # CPU
            pred = net.forward(img)
            pred = pred[:, :, :size[0], :size[1]]
            gt_mask = gt_mask[:, :, :size[0], :size[1]]
            # # CUDA:
            # pred = net.forward(img).cuda()
            # pred = pred[:, :, :size[0], :size[1]].cuda()
            # gt_mask = gt_mask[:, :, :size[0], :size[1]].cuda()

            # Fixed threshold ##########################################################
            # IoU
            IOU.update((pred > 0.5), gt_mask)  # pixel-level
            # nIoU
            nIoU_metric.update(pred, gt_mask)  # pixel-level
            eval_05.update((pred[0, 0, :, :] > opt.threshold).cpu(), gt_mask[0, 0, :, :], size)  # target-level
            ROC_05.update(pred, gt_mask)

            # save img
            if opt.save_img:
                img_save = transforms.ToPILImage()((pred[0, 0, :, :]).cpu())
                if not os.path.exists(opt.save_img_dir + opt.test_dataset_name + '/' + opt.model_name):
                    os.makedirs(opt.save_img_dir + opt.test_dataset_name + '/' + opt.model_name)
                img_save.save(opt.save_img_dir + opt.test_dataset_name + '/' + opt.model_name + '/' + img_dir[0] + '.png')

    # Results at threshold 0.5
    # IoU
    pixAcc, mIOU = IOU.get()
    # nIoU
    nIoU = nIoU_metric.get()
    # Pd / Fa
    results2 = eval_05.get()
    # ROC-based metrics (precision, recall, F1)
    true_positive_rate, false_positive_rate, recall, precision, FP, F1_score = ROC_05.get()
    print('pixAcc: %.4f | mIoU: %.4f | nIoU: %.4f | Pd: %.4f | Fa: %.4f | F1: %.4f' % (
        pixAcc * 100, mIOU * 100, nIoU * 100, results2[0] * 100, results2[1] * 1e+6, F1_score * 100))


if __name__ == '__main__':
    opt.f = open(opt.save_log + 'test_' + (time.ctime()).replace(' ', '_').replace(':', '_') + '.txt', 'w')
    if opt.pth_dirs is None:
        for i in range(len(opt.model_names)):
            opt.model_name = opt.model_names[i]
            print(opt.model_name)
            opt.f.write(opt.model_name + '_400.pth.tar' + '\n')
            for dataset_name in opt.dataset_names:
                opt.dataset_name = dataset_name
                opt.train_dataset_name = opt.dataset_name
                opt.test_dataset_name = opt.dataset_name
                print(dataset_name)
                opt.f.write(opt.dataset_name + '\n')
                opt.pth_dir = opt.save_log + opt.dataset_name + '/' + opt.model_name + '_400.pth.tar'
                test()
            print('\n')
            opt.f.write('\n')
        opt.f.close()
    else:
        for model_name in opt.model_names:
            for dataset_name in opt.dataset_names:
                for pth_dir in opt.pth_dirs:
                    # if dataset_name in pth_dir and model_name in pth_dir:
                    opt.test_dataset_name = dataset_name
                    opt.model_name = model_name
                    opt.train_dataset_name = pth_dir.split('/')[0]
                    print(pth_dir)
                    opt.f.write(pth_dir)
                    print(opt.test_dataset_name)
                    opt.f.write(opt.test_dataset_name + '\n')
                    opt.pth_dir = opt.save_log + pth_dir
                    test()
                    print('\n')
                    opt.f.write('\n')
        opt.f.close()

================================================
FILE: train.py
================================================
import argparse
import time
import os
import cv2

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from torch.autograd import Variable
from torch.utils.data import DataLoader
from dataset import *
from metrics import *
from utils import *
import model.Config as config
from torch.utils.tensorboard import SummaryWriter
from model.SCTransNet import SCTransNet as SCTransNet

parser = \
    argparse.ArgumentParser(description="PyTorch BasicIRSTD train")
parser.add_argument("--model_names", default=['SCTransNet'], nargs='+',
                    help="'ACM', 'ALCNet', 'DNANet', 'ISNet', 'UIUNet', 'RDIAN', 'RISTDnet'")
parser.add_argument("--dataset_names", default=['SIRST3'], nargs='+')  # SIRST3: NUAA NUDT IRSTD-1K
parser.add_argument("--optimizer_name", default='Adam', type=str, help="optimizer name: AdamW, Adam, Adagrad, SGD")
parser.add_argument("--epochs", default=1000, type=int, help="number of training epochs")
parser.add_argument("--begin_test", default=500, type=int)
parser.add_argument("--every_test", default=1, type=int)
parser.add_argument("--every_save_pth", default=1000, type=int)
parser.add_argument("--every_print", default=10, type=int)
parser.add_argument("--dataset_dir", default=r'./datasets')
parser.add_argument("--batchSize", type=int, default=16, help="Training batch size")
parser.add_argument("--patchSize", type=int, default=256, help="Training patch size")
parser.add_argument("--save", default=r'./log', type=str, help="Save path of checkpoints")
parser.add_argument("--log_dir", type=str, default="./otherlogs/SCTransNet", help='path of log files')
parser.add_argument("--img_norm_cfg", default=None, type=dict)
parser.add_argument("--threads", type=int, default=0, help="Number of threads for data loader to use")
parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for test")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument("--resume", default=False, type=list, help="Resume from existing checkpoints (default: False)")

global opt
opt = parser.parse_args()
seed_pytorch(opt.seed)

config_vit = config.get_SCTrans_config()


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif \
            classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def train():
    train_set = TrainSetLoader(dataset_dir=opt.dataset_dir, dataset_name=opt.dataset_name, patch_size=opt.patchSize,
                               img_norm_cfg=opt.img_norm_cfg)
    train_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

    net = Net(model_name=opt.model_name, mode='train').cuda()
    net.apply(weights_init_kaiming)
    net.train()

    epoch_state = 0
    total_loss_list = []
    total_loss_epoch = []

    if not os.path.exists(opt.log_dir):
        os.makedirs(opt.log_dir)
    writer = SummaryWriter(opt.log_dir)

    if opt.resume:
        # for resume_pth in opt.resume:
        #     if opt.dataset_name in resume_pth and opt.model_name in resume_pth:
        ckpt = torch.load('XX\\UCT04_best.pth.tar')
        net.load_state_dict(ckpt['state_dict'])
        epoch_state = ckpt['epoch']
        total_loss_list = ckpt['total_loss']
        # for i in range(len(opt.scheduler_settings['step'])):
        #     opt.scheduler_settings['step'][i] = opt.scheduler_settings['step'][i] - ckpt['epoch']

    ### Default settings of SCTransNet
    if opt.optimizer_name == 'Adam':
        opt.optimizer_settings = {'lr': 0.001}
        opt.scheduler_name = 'CosineAnnealingLR'
        opt.scheduler_settings = {'epochs': opt.epochs, 'eta_min': 1e-5, 'last_epoch': -1}

    ### Default settings of DNANet
    if opt.optimizer_name == 'Adagrad':
        opt.optimizer_settings = {'lr': 0.05}
        opt.scheduler_name = 'CosineAnnealingLR'
        # get_optimizer() reads the minimum learning rate from the 'eta_min' key
        opt.scheduler_settings = {'epochs': opt.epochs, 'eta_min': 1e-5}

    ### Default settings of EGEUNet
    if opt.optimizer_name == 'AdamW':
        opt.optimizer_settings = {'lr': 0.001, 'betas': (0.9, 0.999), "eps": 1e-8, "weight_decay": 1e-2,
                                  "amsgrad": False}
        opt.scheduler_name = 'CosineAnnealingLR'
        opt.scheduler_settings = {'epochs': opt.epochs, 'T_max': 50, 'eta_min': 1e-5, 'last_epoch': -1}

    opt.nEpochs = opt.scheduler_settings['epochs']

    optimizer, scheduler = get_optimizer(net, opt.optimizer_name, opt.scheduler_name, opt.optimizer_settings,
                                         opt.scheduler_settings)

    for idx_epoch in \
            range(epoch_state, opt.nEpochs):
        net.train()
        results1 = (0, 0)
        results2 = (0, 0)
        for idx_iter, (img, gt_mask) in enumerate(train_loader):
            img, gt_mask = Variable(img).cuda(), Variable(gt_mask).cuda()
            if img.shape[0] == 1:
                continue
            preds = net.forward(img)
            loss = net.loss(preds, gt_mask)
            total_loss_epoch.append(loss.detach().cpu())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()
        if (idx_epoch + 1) % opt.every_print == 0:
            total_loss_list.append(float(np.array(total_loss_epoch).mean()))
            print(time.ctime()[4:-5] + ' Epoch---%d, total_loss---%f, lr---%f,'
                  % (idx_epoch + 1, total_loss_list[-1], scheduler.get_last_lr()[0]))
            opt.f.write(time.ctime()[4:-5] + ' Epoch---%d, total_loss---%f,\n'
                        % (idx_epoch + 1, total_loss_list[-1]))
            total_loss_epoch = []
            # Log the scalar values
            writer.add_scalar('loss', total_loss_list[-1], idx_epoch + 1)
            writer.add_scalar('lr', scheduler.get_last_lr()[0], idx_epoch + 1)

        # Evaluate on the test set from epoch opt.begin_test (default 500) onwards
        if (idx_epoch + 1) >= opt.begin_test and (idx_epoch + 1) % opt.every_test == 0:
            test_set = TestSetLoader(opt.dataset_dir, opt.dataset_name, opt.dataset_name,
                                     img_norm_cfg=opt.img_norm_cfg)
            test_loader = DataLoader(dataset=test_set, num_workers=1, batch_size=1, shuffle=False)
            net.eval()
            with torch.no_grad():
                eval_mIoU = mIoU()
                eval_PD_FA = PD_FA()
                test_loss = []
                for idx_iter, (img, gt_mask, size, _) in enumerate(test_loader):
                    img = Variable(img).cuda()
                    pred = net.forward(img)
                    if isinstance(pred, tuple):
                        pred = pred[-1]
                    elif isinstance(pred, list):
                        pred = pred[-1]
                    else:
                        pred = pred
                    pred = pred[:, :, :size[0], :size[1]]
                    gt_mask = gt_mask[:, :, :size[0], :size[1]]
                    # if pred.size() != gt_mask.size():
                    #     print('1111')
                    loss = net.loss(pred, gt_mask.cuda())
                    test_loss.append(loss.detach().cpu())
                    eval_mIoU.update((pred > opt.threshold).cpu(), gt_mask.cpu())
                    eval_PD_FA.update((pred[0, 0, :, :] > opt.threshold).cpu(), gt_mask[0, 0, :, :], size)
                test_loss.append(float(np.array(test_loss).mean()))
                results1 = eval_mIoU.get()
                results2 = eval_PD_FA.get()
                writer.add_scalar('mIOU', results1[-1], idx_epoch + 1)
                writer.add_scalar('testloss', test_loss[-1], idx_epoch + 1)

        if (idx_epoch + 1) % opt.every_save_pth == 0:
            save_pth = opt.save + '/' + opt.dataset_name + '/' + opt.model_name + '_' + str(idx_epoch + 1) + '.pth.tar'
            save_checkpoint({
                'epoch': idx_epoch + 1,
                'state_dict': net.state_dict(),
                'total_loss': total_loss_list,
            }, save_pth)
            test(save_pth)

        if idx_epoch == 0:
            best_mIOU = results1
            best_Pd = results2

        if results1[1] > best_mIOU[1]:
            best_mIOU = results1
            best_Pd = results2
            print('------save the best model epoch', opt.model_name, '_%d ------' % (idx_epoch + 1))
            opt.f.write("the best model epoch \t" + str(idx_epoch + 1) + '\n')
            print("pixAcc, mIoU:\t" + str(best_mIOU))
            print("testloss:\t" + str(test_loss[-1]))
            print("PD, FA:\t" + str(best_Pd))
            opt.f.write("pixAcc, mIoU:\t" + str(best_mIOU) + '\n')
            opt.f.write("PD, FA:\t" + str(best_Pd) + '\n')
            save_pth = opt.save + '/' + opt.dataset_name + '/' + opt.model_name + '_' + str(idx_epoch + 1) + '_' + 'best' + '.pth.tar'
            save_checkpoint({
                'epoch': idx_epoch + 1,
                'state_dict': net.state_dict(),
                'total_loss': total_loss_list,
            }, save_pth)

        # last epoch
        if (idx_epoch + 1) == opt.nEpochs and (idx_epoch + 1) % opt.every_save_pth != 0:
            save_pth = opt.save + '/' + opt.dataset_name + '/' + opt.model_name + '_' + str(idx_epoch + 1) + '.pth.tar'
            save_checkpoint({
                'epoch': idx_epoch + 1,
                'state_dict': net.state_dict(),
                'total_loss': total_loss_list,
            }, save_pth)
            test(save_pth)


def test(save_pth):
    test_set = TestSetLoader(opt.dataset_dir, opt.dataset_name, opt.dataset_name, img_norm_cfg=opt.img_norm_cfg)
    test_loader = DataLoader(dataset=test_set, num_workers=1, batch_size=1, shuffle=False)

    net = Net(model_name=opt.model_name, mode='test').cuda()
    ckpt = torch.load(save_pth)
    net.load_state_dict(ckpt['state_dict'])
    net.eval()

    with torch.no_grad():
        eval_mIoU = mIoU()
        eval_PD_FA = PD_FA()
        test_loss_a = []
        for idx_iter, (img, gt_mask, size, _) in enumerate(test_loader):
            img = \
                Variable(img).cuda()
            pred = net.forward(img)
            if pred.size() != gt_mask.size():
                print('1111')
            pred = pred[:, :, :size[0], :size[1]]
            gt_mask = gt_mask[:, :, :size[0], :size[1]]
            loss = net.loss(pred, gt_mask.cuda())
            test_loss_a.append(loss.detach().cpu())
            eval_mIoU.update((pred > opt.threshold).cpu(), gt_mask.cpu())
            eval_PD_FA.update((pred[0, 0, :, :] > opt.threshold).cpu(), gt_mask[0, 0, :, :], size)
        test_loss_a.append(float(np.array(test_loss_a).mean()))
        results1 = eval_mIoU.get()
        results2 = eval_PD_FA.get()
        print('== == == == == == == ', opt.model_name, ' == == == == == == ==')
        print("pixAcc, mIoU:\t" + str(results1))
        print("testloss:\t" + str(test_loss_a[-1]))
        print("PD, FA:\t" + str(results2))
        opt.f.write("pixAcc, mIoU:\t" + str(results1) + '\n')
        opt.f.write("PD, FA:\t" + str(results2) + '\n')


def save_checkpoint(state, save_path):
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))
    torch.save(state, save_path)
    return save_path


class Net(nn.Module):
    def __init__(self, model_name, mode):
        super(Net, self).__init__()
        self.model_name = model_name
        # ************************************************loss*************************************************#
        # `size_average=True` is deprecated; reduction='mean' is the equivalent setting
        self.cal_loss = nn.BCELoss(reduction='mean')
        if model_name == 'SCTransNet':
            if mode == 'train':
                self.model = SCTransNet(config_vit, mode='train', deepsuper=True)
            else:
                self.model = SCTransNet(config_vit, mode='test', deepsuper=True)

    def forward(self, img):
        return self.model(img)

    def loss(self, preds, gt_masks):
        if isinstance(preds, list):
            loss_total = 0
            for i in range(len(preds)):
                pred = preds[i]
                gt_mask = gt_masks[i]
                loss = self.cal_loss(pred, gt_mask)
                loss_total = loss_total + loss
            return loss_total / len(preds)

        elif isinstance(preds, tuple):
            # Deep supervision: sum the BCE losses of all six outputs against the same mask
            a = []
            for i in range(len(preds)):
                pred = preds[i]
                loss = self.cal_loss(pred, gt_masks)
                a.append(loss)
            loss_total = a[0] + a[1] + a[2] + a[3] + a[4] + a[5]
            return loss_total

        else:
            loss = self.cal_loss(preds, gt_masks)
            return loss


if __name__ \
        == '__main__':
    for dataset_name in opt.dataset_names:
        opt.dataset_name = dataset_name
        for model_name in opt.model_names:
            opt.model_name = model_name
            if not os.path.exists(opt.save):
                os.makedirs(opt.save)
            opt.f = open(opt.save + '/' + opt.dataset_name + '_' + opt.model_name + '_' +
                         (time.ctime()).replace(' ', '_').replace(':', '_') + '.txt', 'w')
            print(opt.dataset_name + '\t' + opt.model_name)
            train()
            print('\n')
            opt.f.close()

================================================
FILE: utils.py
================================================
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data.dataset import Dataset
import random
import matplotlib.pyplot as plt
import cv2
import os
import math
import torch.nn as nn
from skimage import measure
from warmup_scheduler import GradualWarmupScheduler
import torch.nn.functional as F
from torch.nn import init

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


def seed_pytorch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def weights_init_xavier(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1 and classname.find('SplAtConv2d') == -1:
        init.xavier_normal_(m.weight.data)


# def weights_init_xavier(m):
#     classname = m.__class__.__name__
#     if classname.find('Conv2d') != -1:
#         # init.kaiming_normal_(m.weight.data,a=0, mode='fan_in', nonlinearity='leaky_relu')
#         init.xavier_normal(m.weight.data)


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


class Get_gradient_nopadding(nn.Module):
    def __init__(self):
        super(Get_gradient_nopadding, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False).cuda()
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False).cuda()

    def forward(self, x):
        x0 = x[:, 0]
        x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=1)
        x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=1)

        x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)

        return x0


def random_crop(img, mask, patch_size, pos_prob=None):
    h, w = img.shape
    if min(h, w) < patch_size:
        # Zero-pad any side shorter than patch_size up to patch_size
        img = np.pad(img, ((0, max(h, patch_size) - h), (0, max(w, patch_size) - w)), mode='constant')
        # Apply the same padding to the mask as to the image
        mask = np.pad(mask, ((0, max(h, patch_size) - h), (0, max(w, patch_size) - w)), mode='constant')
        h, w = img.shape

    while 1:
        h_start = random.randint(0, h - patch_size)
        h_end = h_start + patch_size
        w_start = random.randint(0, w - patch_size)
        w_end = w_start + patch_size

        img_patch = img[h_start:h_end, w_start:w_end]
        mask_patch = mask[h_start:h_end, w_start:w_end]

        if pos_prob is None or random.random() > pos_prob:
            break
        elif mask_patch.sum() > 0:
            break

    return img_patch, mask_patch


def Normalized(img, img_norm_cfg):
    return (img - img_norm_cfg['mean']) / img_norm_cfg['std']


def Denormalization(img, img_norm_cfg):
    return img * img_norm_cfg['std'] + img_norm_cfg['mean']


def get_img_norm_cfg(dataset_name, dataset_dir):
    if dataset_name == 'NUAA-SIRST':
        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
    elif dataset_name == 'NUDT-SIRST':
        img_norm_cfg = dict(mean=107.80905151367188, std=33.02274703979492)
    elif dataset_name == 'IRSTD-1K':
        img_norm_cfg = dict(mean=87.4661865234375, std=39.71953201293945)
    elif dataset_name == 'SIRST2':
        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
    elif dataset_name == 'SIRST3':
        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
    elif dataset_name == 'NUDT-SIRST-Sea':
        img_norm_cfg = dict(mean=43.62403869628906, std=18.91838264465332)
    elif dataset_name == 'SIRST4':
        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
    elif dataset_name == 'SIRST5':
        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
    elif dataset_name == 'SIRST6':
        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
    elif dataset_name == 'SIRST7':
        img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
    elif dataset_name == 'IRDST-real':
        img_norm_cfg = {'mean': 101.54053497314453, 'std': 56.49856185913086}
    else:
        with open(dataset_dir + '/' + dataset_name + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:
            train_list = f.read().splitlines()
        with open(dataset_dir + '/' + dataset_name + '/img_idx/test_' + dataset_name + '.txt', 'r') as f:
            test_list = f.read().splitlines()
        img_list = train_list + test_list
        img_dir = dataset_dir + '/' + dataset_name + '/images/'
        mean_list = []
        std_list = []
        for img_pth in img_list:
            # try .png first, then fall back to .jpg and .bmp
            try:
                img = Image.open((img_dir + img_pth).replace('//', '/') + '.png').convert('I')
            except OSError:
                try:
                    img = Image.open((img_dir + img_pth).replace('//', '/') + '.jpg').convert('I')
                except OSError:
                    img = Image.open((img_dir + img_pth).replace('//', '/') + '.bmp').convert('I')
            img = np.array(img, dtype=np.float32)
            mean_list.append(img.mean())
            std_list.append(img.std())
        img_norm_cfg = dict(mean=float(np.array(mean_list).mean()), std=float(np.array(std_list).mean()))
    return img_norm_cfg


def get_optimizer(net, optimizer_name, scheduler_name, optimizer_settings, scheduler_settings):
    if optimizer_name == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=optimizer_settings['lr'])
    elif optimizer_name == 'Adamweight':
        optimizer = torch.optim.Adam(net.parameters(), lr=optimizer_settings['lr'], weight_decay=1e-3)
    elif optimizer_name == 'Adagrad':
        optimizer = torch.optim.Adagrad(net.parameters(), lr=optimizer_settings['lr'])
    elif optimizer_name == 'SGD':
        optimizer = torch.optim.SGD(net.parameters(), lr=optimizer_settings['lr'], momentum=0.9,
                                    weight_decay=scheduler_settings['weight_decay'])
    # elif optimizer_name == 'AdamW':
    #     optimizer = torch.optim.AdamW(net.parameters(), lr=optimizer_settings['lr'], betas=optimizer_settings['betas'],
    #                                   eps=optimizer_settings['eps'], weight_decay=optimizer_settings['weight_decay'],
    #                                   amsgrad=optimizer_settings['amsgrad'])

    if scheduler_name == 'MultiStepLR':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=scheduler_settings['step'],
                                                         gamma=scheduler_settings['gamma'])
    # elif scheduler_name == 'DNACosineAnnealingLR':
    #     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'],
    #                                                            eta_min=scheduler_settings['eta_min'])
    elif scheduler_name == 'CosineAnnealingLR':
        warmup_epochs = 10
        scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                      T_max=scheduler_settings['epochs'] - warmup_epochs,
                                                                      eta_min=scheduler_settings['eta_min'])
        scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,
                                           after_scheduler=scheduler_cosine)
    elif scheduler_name == 'CosineAnnealingLRw50':
        warmup_epochs = 50
        scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                      T_max=scheduler_settings['epochs'] - warmup_epochs,
                                                                      eta_min=scheduler_settings['eta_min'])
        scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,
                                           after_scheduler=scheduler_cosine)
    elif scheduler_name == 'CosineAnnealingLRw0':
        # warmup_epochs = 0
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'],
                                                               eta_min=scheduler_settings['eta_min'])
        # scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'] - warmup_epochs,
        #                                                               eta_min=1e-5)
        # scheduler = GradualWarmupScheduler(optimizer, multiplier=1,
        #                                    total_epoch=warmup_epochs,
        #                                    after_scheduler=scheduler_cosine)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['T_max'],
    #                                                        eta_min=scheduler_settings['eta_min'],
    #                                                        last_epoch=scheduler_settings['last_epoch'])
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'], eta_min=scheduler_settings['eta_min'])

    return optimizer, scheduler


def PadImg(img, times=32):
    h, w = img.shape
    if not h % times == 0:
        img = np.pad(img, ((0, (h // times + 1) * times - h), (0, 0)), mode='constant')
    if not w % times == 0:
        img = np.pad(img, ((0, 0), (0, (w // times + 1) * times - w)), mode='constant')
    return img

================================================
FILE: warmup_scheduler.py
================================================
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm up (increase) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    The optimizer is configured with a base learning rate, base_lr.
    When multiplier > 1, warmup gradually raises the learning rate from base_lr to multiplier * base_lr
    over total_epoch epochs, after which the regular scheduler takes over.
    When multiplier == 1.0, warmup gradually raises the learning rate from 0 to base_lr
    over total_epoch epochs, after which the regular scheduler takes over.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base_lr * multiplier if multiplier > 1.0;
            if multiplier == 1.0, lr starts from 0 and ends at base_lr.
        total_epoch: the target learning rate is reached gradually by total_epoch.
        after_scheduler: the scheduler used after total_epoch (e.g. ReduceLROnPlateau).
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler is None:
                # no follow-up scheduler: hold the warmed-up target learning rate
                return [base_lr * self.multiplier for base_lr in self.base_lrs]
            if not self.finished:
                self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                self.finished = True
            # ! Critical step: the new base lrs must be returned directly here
            return [base_lr for base_lr in self.after_scheduler.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                    for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        self.last_epoch = epoch if epoch != 0 else 1
        print('warmuping...')
        if self.last_epoch <= self.total_epoch:
            warmup_lr = None
            if self.multiplier == 1.0:
                warmup_lr = [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
            else:
                warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                             for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
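The `GradualWarmupScheduler` docstring and the `CosineAnnealingLR` branch of `get_optimizer` together describe a linear warmup (with `multiplier=1`, the learning rate rises from 0 to the base rate over `warmup_epochs`) followed by cosine annealing down to `eta_min` over the remaining `epochs - warmup_epochs` epochs. The sketch below computes that schedule in closed form, without PyTorch, so the curve can be inspected before training. It is a simplified illustration, not the repository's code: `warmup_cosine_lr` is a hypothetical helper, and the real scheduler's epoch indexing around the warmup-to-cosine hand-over may differ by one epoch.

```python
import math
from typing import List


def warmup_cosine_lr(base_lr: float, epochs: int, warmup_epochs: int = 10,
                     eta_min: float = 0.0) -> List[float]:
    """Hypothetical helper (not part of the repo): per-epoch learning rates for a
    linear warmup from 0 to base_lr (the multiplier == 1.0 case), followed by
    closed-form cosine annealing from base_lr down towards eta_min."""
    if warmup_epochs <= 0 or epochs <= warmup_epochs:
        raise ValueError("need 0 < warmup_epochs < epochs")
    t_max = epochs - warmup_epochs  # mirrors T_max = epochs - warmup_epochs in get_optimizer
    lrs = []
    for epoch in range(epochs):
        if epoch < warmup_epochs:
            # linear warmup: base_lr * epoch / warmup_epochs
            lrs.append(base_lr * epoch / warmup_epochs)
        else:
            # cosine annealing over the remaining t_max epochs
            t = epoch - warmup_epochs
            lrs.append(eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2)
    return lrs


if __name__ == "__main__":
    # illustrative values only, not necessarily the settings used for training
    for epoch, lr in enumerate(warmup_cosine_lr(1.0, epochs=20, warmup_epochs=10)):
        print(epoch, round(lr, 4))
```

With these illustrative values the rate climbs in steps of 0.1 to 1.0 at epoch 10 and then decays along a half cosine, which is the shape the warmup-then-cosine combination is meant to produce.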