Repository: xdFai/SCTransNet
Branch: main
Commit: e0276283794e
Files: 11
Total size: 124.0 KB
Directory structure:
gitextract_03y9khf8/
├── README.md
├── dataset.py
├── datasets/
│ └── SIRST3/
│ └── img_idx/
│ ├── test_SIRST3.txt
│ └── train_SIRST3.txt
├── metrics.py
├── model/
│ ├── Config.py
│ └── SCTransNet.py
├── test.py
├── train.py
├── utils.py
└── warmup_scheduler.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# SCTransNet: Spatial-channel Cross Transformer Network for Infrared Small Target Detection [[Paper]](https://ieeexplore.ieee.org/document/10486932) [[Weight]](https://drive.google.com/file/d/1Kxs2wKG2uq2YiGJOBGWoVz7B1-8DJoz3/view?usp=sharing)
Shuai Yuan, Hanlin Qin, Xiang Yan, Naveed Akhtar, Aimal Main, IEEE Transactions on Geoscience and Remote Sensing 2024.
# SCTransNet 是PRCV 2024、ICPR 2024 Track 1、ICPR 2024 Track 2 三项比赛冠军方案的 Baseline, 同时也是多个优胜算法的Baselines. [[Paper]](https://arxiv.org/abs/2408.09615)
# Bilibili 视频分享
https://www.bilibili.com/video/BV1kr421M7wx/
# 极市平台 推文分享
https://mp.weixin.qq.com/s/H7KLmtFX7j09f-Xc6X1FRw
# If the implementation of this repo is helpful to you, just star it!⭐⭐⭐
# Challenges and inspiration

# Structure


# Introduction
We present a Spatial-channel Cross Transformer Network (SCTransNet) to the IRSTD task. Experiments on both public (e.g., SIRST, NUDT-SIRST, IRSTD-1K) demonstrate the effectiveness of our method. Our main contributions are as follows:
1. We propose SCTransNet, leveraging spatial-channel cross transformer blocks (SCTB) to predict the context of targets and backgrounds in the deeper network layers.
2. A spatial-embedded single-head channel-cross attention (SSCA) module is utilized to foster semantic interactions across all feature levels and learn the long-range context.
3. We devise a novel complementary feed-forward network (CFN) by crossing spatial-channel information to enhance the semantic difference between the target and background.
## Usage
#### 1. Data
The **SIRST3** dataset, which combines **IRSTD-1K**, **NUDT-SIRST**, and **SIRST-v1**, is used to train SCTransNet.
* **SIRST-v1** [[download]](https://github.com/YimianDai/sirst) [[paper]](https://arxiv.org/pdf/2009.14530.pdf)
* **NUDT-SIRST** [[download]](https://github.com/YeRen123455/Infrared-Small-Target-Detection) [[paper]](https://ieeexplore.ieee.org/abstract/document/9864119)
* **IRSTD-1K** [[download dir]](https://github.com/RuiZhang97/ISNet) [[paper]](https://ieeexplore.ieee.org/document/9880295)
* Apologies for misnaming the **SIRST-v1** dataset as **NUAA-SIRST** in both the article and code. We will follow the original authors’ naming convention in future work.
* **Our project has the following structure:**
```
├──./datasets/
│ ├── IRSTD-1K
│ │ ├── images
│ │ │ ├── XDU0.png
│ │ │ ├── XDU1.png
│ │ │ ├── ...
│ │ ├── masks
│ │ │ ├── XDU0.png
│ │ │ ├── XDU1.png
│ │ │ ├── ...
│ │ ├── img_idx
│ │ │ ├── train_IRSTD-1K.txt
│ │ │ ├── test_IRSTD-1K.txt
│ ├── NUDT-SIRST
│ │ ├── images
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ ├── ...
│ │ ├── masks
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ ├── ...
│ │ ├── img_idx
│ │ │ ├── train_NUDT-SIRST.txt
│ │ │ ├── test_NUDT-SIRST.txt
│ ├── SIRSTv1 (~which is misnamed as NUAA-SIRST~)
│ │ ├── images
│ │ │ ├── Misc_1.png
│ │ │ ├── Misc_2.png
│ │ │ ├── ...
│ │ ├── masks
│ │ │ ├── Misc_1.png
│ │ │ ├── Misc_2.png
│ │ │ ├── ...
│ │ ├── img_idx
│ │ │ ├── train_NUAA-SIRST.txt
│ │ │ ├── test_NUAA-SIRST.txt
│ ├── SIRST3 (~The sum of SIRSTv1, NUDT-SIRST and IRSTD-1K~)
│ │ ├── images
│ │ │ ├── XDU0.png
│ │ │ ├── XDU1.png
│ │ │ ├── ...
│ │ ├── masks
│ │ │ ├── XDU0.png
│ │ │ ├── XDU1.png
│ │ │ ├── ...
│ │ ├── img_idx
│ │ │ ├── train_SIRST3.txt
│ │ │ ├── test_SIRST3.txt
```
##### 2. Train.
```bash
python train.py
```
#### 3. Test and demo.
权重文件的百度网盘链接:[https://pan.baidu.com/s/1_hlEaqnJI246GWN4N8k8wA?pwd=t28j](https://pan.baidu.com/s/1B0mANHXSfJaQjHr00XIwgQ?pwd=s7nh)
权重文件的谷歌云盘链接:https://drive.google.com/file/d/1Kxs2wKG2uq2YiGJOBGWoVz7B1-8DJoz3/view?usp=sharing
```bash
python test.py
```
## Results and Trained Models
#### Qualitative Results

#### Quantitative Results on Mixed SIRSTv1, NUDT-SIRST, and IRSTD-1K. i.e, one weight for three Datasets.
| Model | mIoU (x10(-2)) | nIoU (x10(-2)) | F-measure (x10(-2))| Pd (x10(-2))| Fa (x10(-6))|
| ------------- |:-------------:|:-----:|:-----:|:-----:|:-----:|
| SIRSTv1 | 77.50 | 81.08 | 87.32 | 96.95 | 13.92 |
| NUDT-SIRST | 94.09 | 94.38 | 96.95 | 98.62 | 4.29 |
| IRSTD-1K | 68.03 | 68.15 | 80.96 | 93.27 | 10.74 |
| [[Weights]](https://drive.google.com/file/d/1Kxs2wKG2uq2YiGJOBGWoVz7B1-8DJoz3/view?usp=sharing)|
*This code is highly borrowed from [IRSTD-Toolbox](https://github.com/XinyiYing/BasicIRSTD). Thanks to Xinyi Ying.
*This code is highly borrowed from [UCTransNet](https://github.com/McGregorWwww/UCTransNet). Thanks to Haonan Wang.
*The overall repository style is highly borrowed from [DNA-Net](https://github.com/YeRen123455/Infrared-Small-Target-Detection). Thanks to Boyang Li.
## Citation
If you find the code useful, please consider citing our paper using the following BibTeX entry.
```
@ARTICLE{SCTransNet,
author={Yuan, Shuai and Qin, Hanlin and Yan, Xiang and Akhtar, Naveed and Mian, Ajmal},
journal={IEEE Transactions on Geoscience and Remote Sensing},
title={SCTransNet: Spatial-Channel Cross Transformer Network for Infrared Small Target Detection},
year={2024},
volume={62},
number={},
pages={1-15},
keywords={Semantics;Transformers;Decoding;Feature extraction;Task analysis;Object detection;Visualization;Convolutional neural network (CNN);cross-attention;deep learning;infrared small target detection (IRSTD);transformer},
doi={10.1109/TGRS.2024.3383649}}
@article{SP-KAN,
title = {SP-KAN: Sparse-sine perception Kolmogorov–Arnold networks for infrared small target detection},
journal = {ISPRS Journal of Photogrammetry and Remote Sensing},
volume = {234},
pages = {1-19},
year = {2026},
issn = {0924-2716},
doi = {https://doi.org/10.1016/j.isprsjprs.2026.02.019},
url = {https://www.sciencedirect.com/science/article/pii/S0924271626000705},
author = {Shuai Yuan and Yu Liu and Xiaopei Zhang and Xiang Yan and Hanlin Qin and Naveed Akhtar},
}
```
## Contact
**Welcome to raise issues or email to [yuansy@stu.xidian.edu.cn](yuansy@stu.xidian.edu.cn) for any question regarding our SCTransNet.**
================================================
FILE: dataset.py
================================================
from utils import *
import matplotlib.pyplot as plt
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
class TrainSetLoader(Dataset):
def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg=None):
super(TrainSetLoader).__init__()
self.dataset_name = dataset_name
self.dataset_dir = dataset_dir + '/' + dataset_name
self.patch_size = patch_size
with open(self.dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:
self.train_list = f.read().splitlines()
if img_norm_cfg == None:
self.img_norm_cfg = get_img_norm_cfg(dataset_name, dataset_dir)
else:
self.img_norm_cfg = img_norm_cfg
self.tranform = augumentation()
def __getitem__(self, idx):
try:
img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//', '/')).convert(
'I') # read image base on version ”I“
# img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//','/'))
mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.png').replace('//', '/'))
except:
img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.bmp').replace('//', '/')).convert('I')
mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.bmp').replace('//', '/'))
img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg) # convert PIL to numpy and normalize
mask = np.array(mask, dtype=np.float32) / 255.0
if len(mask.shape) > 2:
mask = mask[:, :, 0]
# rnd_bn = np.random.normal(0, 0.03)#0.03
# img += rnd_bn
#
# minm = img.min()
# rng = img.max() - minm
# gamma = np.random.uniform(0.5, 1.6)
# x=((img - minm) / rng)
# img = np.power(x, gamma)
# img = img * rng + minm
img_patch, mask_patch = random_crop(img, mask, self.patch_size, pos_prob=0.5) # 把短的一边先pad至256 把长的一边 随机裁出256 输出 256 256
img_patch, mask_patch = self.tranform(img_patch, mask_patch) # 数据翻转增强
img_patch, mask_patch = img_patch[np.newaxis, :], mask_patch[np.newaxis, :] # 升维
img_patch = torch.from_numpy(np.ascontiguousarray(img_patch)) # numpy 转tensor
mask_patch = torch.from_numpy(np.ascontiguousarray(mask_patch)) # numpy 转tensor
return img_patch, mask_patch
def __len__(self):
return len(self.train_list)
class TrainSetLoader02(Dataset):
def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg=None):
super(TrainSetLoader).__init__()
self.dataset_name = dataset_name
self.dataset_dir = dataset_dir + '/' + dataset_name
self.patch_size = patch_size
with open(self.dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:
self.train_list = f.read().splitlines()
if img_norm_cfg == None:
self.img_norm_cfg = get_img_norm_cfg(dataset_name, dataset_dir)
else:
self.img_norm_cfg = img_norm_cfg
self.tranform = augumentation()
def __getitem__(self, idx):
try:
img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//', '/')).convert(
'I') # read image base on version ”I“
# img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//','/'))
mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.png').replace('//', '/'))
except:
img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.bmp').replace('//', '/')).convert('I')
mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.bmp').replace('//', '/'))
img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg) # convert PIL to numpy and normalize
mask = np.array(mask, dtype=np.float32) / 255.0
if len(mask.shape) > 2:
mask = mask[:, :, 0]
rnd_bn = np.random.normal(0, 0.03)#0.03
img += rnd_bn
#
# minm = img.min()
# rng = img.max() - minm
# gamma = np.random.uniform(0.5, 1.6)
# x=((img - minm) / rng)
# img = np.power(x, gamma)
# img = img * rng + minm
img_patch, mask_patch = random_crop(img, mask, self.patch_size, pos_prob=0.5) # 把短的一边先pad至256 把长的一边 随机裁出256 输出 256 256
img_patch, mask_patch = self.tranform(img_patch, mask_patch) # 数据翻转增强
img_patch, mask_patch = img_patch[np.newaxis, :], mask_patch[np.newaxis, :] # 升维
img_patch = torch.from_numpy(np.ascontiguousarray(img_patch)) # numpy 转tensor
mask_patch = torch.from_numpy(np.ascontiguousarray(mask_patch)) # numpy 转tensor
return img_patch, mask_patch
def __len__(self):
return len(self.train_list)
class TrainSetLoader03(Dataset):
def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg=None):
super(TrainSetLoader).__init__()
self.dataset_name = dataset_name
self.dataset_dir = dataset_dir + '/' + dataset_name
self.patch_size = patch_size
with open(self.dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:
self.train_list = f.read().splitlines()
if img_norm_cfg == None:
self.img_norm_cfg = get_img_norm_cfg(dataset_name, dataset_dir)
else:
self.img_norm_cfg = img_norm_cfg
self.tranform = augumentation()
def __getitem__(self, idx):
try:
img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//', '/')).convert(
'I') # read image base on version ”I“
# img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//','/'))
mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.png').replace('//', '/'))
except:
img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.bmp').replace('//', '/')).convert('I')
mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.bmp').replace('//', '/'))
img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg) # convert PIL to numpy and normalize
mask = np.array(mask, dtype=np.float32) / 255.0
if len(mask.shape) > 2:
mask = mask[:, :, 0]
# rnd_bn = np.random.normal(0, 0.03)#0.03
# img += rnd_bn
minm = img.min()
rng = img.max() - minm
gamma = np.random.uniform(0.5, 1.6)
x=((img - minm) / rng)
img = np.power(x, gamma)
img = img * rng + minm
img_patch, mask_patch = random_crop(img, mask, self.patch_size, pos_prob=0.5) # 把短的一边先pad至256 把长的一边 随机裁出256 输出 256 256
img_patch, mask_patch = self.tranform(img_patch, mask_patch) # 数据翻转增强
img_patch, mask_patch = img_patch[np.newaxis, :], mask_patch[np.newaxis, :] # 升维
img_patch = torch.from_numpy(np.ascontiguousarray(img_patch)) # numpy 转tensor
mask_patch = torch.from_numpy(np.ascontiguousarray(mask_patch)) # numpy 转tensor
return img_patch, mask_patch
def __len__(self):
return len(self.train_list)
class TrainSetLoader04(Dataset):
def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg=None):
super(TrainSetLoader).__init__()
self.dataset_name = dataset_name
self.dataset_dir = dataset_dir + '/' + dataset_name
self.patch_size = patch_size
with open(self.dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:
self.train_list = f.read().splitlines()
if img_norm_cfg == None:
self.img_norm_cfg = get_img_norm_cfg(dataset_name, dataset_dir)
else:
self.img_norm_cfg = img_norm_cfg
self.tranform = augumentation()
def __getitem__(self, idx):
try:
img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//', '/')).convert(
'I') # read image base on version ”I“
# img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.png').replace('//','/'))
mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.png').replace('//', '/'))
except:
img = Image.open((self.dataset_dir + '/images/' + self.train_list[idx] + '.bmp').replace('//', '/')).convert('I')
mask = Image.open((self.dataset_dir + '/masks/' + self.train_list[idx] + '.bmp').replace('//', '/'))
img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg) # convert PIL to numpy and normalize
mask = np.array(mask, dtype=np.float32) / 255.0
if len(mask.shape) > 2:
mask = mask[:, :, 0]
rnd_bn = np.random.normal(0, 0.03)#0.03
img += rnd_bn
minm = img.min()
rng = img.max() - minm
gamma = np.random.uniform(0.5, 1.6)
x=((img - minm) / rng)
img = np.power(x, gamma)
img = img * rng + minm
img_patch, mask_patch = random_crop(img, mask, self.patch_size, pos_prob=0.5) # 把短的一边先pad至256 把长的一边 随机裁出256 输出 256 256
img_patch, mask_patch = self.tranform(img_patch, mask_patch) # 数据翻转增强
img_patch, mask_patch = img_patch[np.newaxis, :], mask_patch[np.newaxis, :] # 升维
img_patch = torch.from_numpy(np.ascontiguousarray(img_patch)) # numpy 转tensor
mask_patch = torch.from_numpy(np.ascontiguousarray(mask_patch)) # numpy 转tensor
return img_patch, mask_patch
def __len__(self):
return len(self.train_list)
class TestSetLoader(Dataset):
def __init__(self, dataset_dir, train_dataset_name, test_dataset_name, img_norm_cfg=None):
super(TestSetLoader).__init__()
self.dataset_dir = dataset_dir + '/' + test_dataset_name
with open(self.dataset_dir + '/img_idx/test_' + test_dataset_name + '.txt', 'r') as f:
# with open(r'D:\05TGARS\Upload\datasets\SIRST3\img_idx\val.txt', 'r') as f:
self.test_list = f.read().splitlines()
if img_norm_cfg == None:
self.img_norm_cfg = get_img_norm_cfg(train_dataset_name, dataset_dir)
else:
self.img_norm_cfg = img_norm_cfg
def __getitem__(self, idx):
try:
img = Image.open((self.dataset_dir + '/images/' + self.test_list[idx] + '.png').replace('//', '/')).convert('I')
mask = Image.open((self.dataset_dir + '/masks/' + self.test_list[idx] + '.png').replace('//', '/'))
except:
img = Image.open((self.dataset_dir + '/images/' + self.test_list[idx] + '.bmp').replace('//', '/')).convert('I')
mask = Image.open((self.dataset_dir + '/masks/' + self.test_list[idx] + '.bmp').replace('//', '/'))
img = Normalized(np.array(img, dtype=np.float32), self.img_norm_cfg)
mask = np.array(mask, dtype=np.float32) / 255.0
# if mask.shape == (416,608):
# print('111')
if len(mask.shape) > 2:
mask = mask[:, :, 0]
h, w = img.shape
img = PadImg(img)
mask = PadImg(mask)
img, mask = img[np.newaxis, :], mask[np.newaxis, :]
img = torch.from_numpy(np.ascontiguousarray(img))
mask = torch.from_numpy(np.ascontiguousarray(mask))
if img.size() != mask.size():
print('111')
return img, mask, [h, w], self.test_list[idx]
def __len__(self):
return len(self.test_list)
class EvalSetLoader(Dataset):
def __init__(self, dataset_dir, mask_pred_dir, test_dataset_name, model_name):
super(EvalSetLoader).__init__()
self.dataset_dir = dataset_dir
self.mask_pred_dir = mask_pred_dir
self.test_dataset_name = test_dataset_name
self.model_name = model_name
with open(self.dataset_dir + '/img_idx/test_' + test_dataset_name + '.txt', 'r') as f:
self.test_list = f.read().splitlines()
def __getitem__(self, idx):
mask_pred = Image.open(
(self.mask_pred_dir + self.test_dataset_name + '/' + self.model_name + '/' + self.test_list[idx] + '.png').replace('//', '/'))
mask_gt = Image.open(self.dataset_dir + '/masks/' + self.test_list[idx] + '.png')
mask_pred = np.array(mask_pred, dtype=np.float32) / 255.0
mask_gt = np.array(mask_gt, dtype=np.float32) / 255.0
if len(mask_pred.shape) == 3:
mask_pred = mask_pred[:, :, 0]
h, w = mask_pred.shape
mask_pred, mask_gt = mask_pred[np.newaxis, :], mask_gt[np.newaxis, :]
mask_pred = torch.from_numpy(np.ascontiguousarray(mask_pred))
mask_gt = torch.from_numpy(np.ascontiguousarray(mask_gt))
return mask_pred, mask_gt, [h, w]
def __len__(self):
return len(self.test_list)
class augumentation(object):
def __call__(self, input, target):
if random.random() < 0.5: # 水平反转
input = input[::-1, :]
target = target[::-1, :]
if random.random() < 0.5: # 垂直反转
input = input[:, ::-1]
target = target[:, ::-1]
if random.random() < 0.5: # 转置反转
input = input.transpose(1, 0)
target = target.transpose(1, 0)
return input, target
================================================
FILE: datasets/SIRST3/img_idx/test_SIRST3.txt
================================================
Misc_338
Misc_379
Misc_422
Misc_73
Misc_321
Misc_162
Misc_372
Misc_185
Misc_420
Misc_143
Misc_137
Misc_224
Misc_274
Misc_156
Misc_121
Misc_240
Misc_82
Misc_230
Misc_51
Misc_203
Misc_210
Misc_219
Misc_178
Misc_416
Misc_366
Misc_306
Misc_94
Misc_5
Misc_413
Misc_279
Misc_110
Misc_298
Misc_257
Misc_199
Misc_368
Misc_34
Misc_248
Misc_303
Misc_99
Misc_225
Misc_317
Misc_308
Misc_66
Misc_266
Misc_123
Misc_288
Misc_348
Misc_59
Misc_299
Misc_28
Misc_227
Misc_54
Misc_287
Misc_252
Misc_21
Misc_130
Misc_228
Misc_177
Misc_24
Misc_395
Misc_332
Misc_207
Misc_30
Misc_173
Misc_49
Misc_48
Misc_75
Misc_427
Misc_411
Misc_205
Misc_57
Misc_273
Misc_272
Misc_125
Misc_262
Misc_14
Misc_56
Misc_234
Misc_412
Misc_275
Misc_350
Misc_418
Misc_113
Misc_148
Misc_357
Misc_239
Misc_385
Misc_154
Misc_83
Misc_222
Misc_111
Misc_153
Misc_309
Misc_374
Misc_312
Misc_117
Misc_197
Misc_292
Misc_145
Misc_376
Misc_122
Misc_101
Misc_394
Misc_136
Misc_289
Misc_131
Misc_25
Misc_74
Misc_347
Misc_15
Misc_1
Misc_361
Misc_373
Misc_151
Misc_97
Misc_3
Misc_381
Misc_400
Misc_277
Misc_326
Misc_38
Misc_241
Misc_334
Misc_397
Misc_164
Misc_300
Misc_297
Misc_414
Misc_249
Misc_91
Misc_107
Misc_68
Misc_259
Misc_342
Misc_189
Misc_76
Misc_235
Misc_335
Misc_22
Misc_319
Misc_314
Misc_79
Misc_346
Misc_251
Misc_387
Misc_410
Misc_345
Misc_236
Misc_41
Misc_35
Misc_9
Misc_396
Misc_375
Misc_176
Misc_188
Misc_44
Misc_356
Misc_80
Misc_60
Misc_386
Misc_126
Misc_206
Misc_307
Misc_208
Misc_276
Misc_286
Misc_193
Misc_270
Misc_377
Misc_37
Misc_390
Misc_106
Misc_255
Misc_328
Misc_268
Misc_301
Misc_181
Misc_371
Misc_204
Misc_323
Misc_87
Misc_53
Misc_293
Misc_20
Misc_85
Misc_327
Misc_216
Misc_365
Misc_408
Misc_62
Misc_271
Misc_265
Misc_325
Misc_340
Misc_155
Misc_10
Misc_13
Misc_124
Misc_140
Misc_139
Misc_46
Misc_19
Misc_135
Misc_221
Misc_212
Misc_359
Misc_237
Misc_331
Misc_69
Misc_186
Misc_294
Misc_196
Misc_84
Misc_380
000541
000832
000710
000316
001142
000409
000226
001133
001255
000558
001000
000763
000385
001319
000591
000484
001236
000334
000915
000523
000765
000215
000045
001303
000644
001190
001026
001089
000720
000195
000519
001111
000922
000965
000080
000751
000281
000849
000495
001257
000966
001267
000091
000935
000314
000128
001169
000701
000903
000144
000613
000313
000552
000787
000536
000423
000389
001235
000615
000366
000399
001155
000625
000610
001050
001148
000440
000032
000812
000830
000771
000324
001161
000209
000645
000897
001036
000200
000708
000474
000840
000554
000011
000603
000560
000550
000498
000243
000433
000106
000036
000574
000837
000608
000341
000863
000907
000600
000361
000555
000636
000509
000799
000062
000297
000933
000374
000928
000910
000131
000191
000216
000373
000556
000417
000485
000557
000153
000955
000511
001038
001239
000986
000740
001018
001068
000902
001221
000510
000212
000026
001320
001074
000189
000335
000295
001265
001206
000637
000349
001122
001178
000944
001135
000113
000041
000575
000757
000882
000917
000909
000569
001167
001212
000801
000466
000449
000913
000607
001218
001096
000850
001151
000426
000908
000286
001280
000167
000332
000499
001066
000611
001188
001238
001195
000980
000194
000315
000391
000704
000930
000150
000687
000394
000988
000535
001278
000867
000371
000355
000722
000178
001215
000747
001162
000001
000405
001097
000506
001102
000326
000737
000302
000699
000715
000874
000713
001244
000589
001144
001213
000689
000266
000563
000395
000415
000749
000330
000300
000071
001264
000237
000973
001076
000666
001079
000532
000820
000207
000182
001176
001272
000047
000268
000906
000351
000633
000025
000957
000176
000220
000188
001006
000926
001129
000599
000684
000155
000703
001143
000620
000075
000839
001300
001186
000197
000811
000601
001040
000735
001029
000623
000436
001242
000813
000493
001053
000410
000936
000473
000478
000458
000709
001307
000781
001023
001004
000359
000287
000482
000941
000067
000284
000447
000894
000163
000059
001222
001268
000744
000921
000173
000378
000806
000629
000870
000860
000660
000664
000533
000617
001145
000961
000940
000365
001003
000305
001317
001314
000352
000898
000883
001309
000732
000336
000622
001232
000166
000392
000231
001281
001199
000893
000879
000084
001174
000377
001063
000628
000537
000772
000100
000043
000291
000918
000548
000729
000475
000731
000974
001047
000595
000508
001101
001248
001107
001127
001253
000920
000021
000742
000140
000320
001209
000434
000120
000711
000137
000249
000995
001084
001273
000845
000452
001220
000553
000003
000984
000086
000967
000788
000416
000138
000112
001114
001308
000053
000762
000333
001081
001287
000821
001302
000480
000775
000525
001262
000124
001069
000782
000490
000010
000465
001286
001060
000685
000717
000939
000027
000971
000727
000170
000187
000671
000976
000562
000651
000831
000397
000157
000682
001227
001010
000294
000159
001088
001241
001229
000289
000169
000135
000090
000456
000691
000382
000483
000528
000889
000060
001326
000520
000269
000706
001020
000502
000492
001153
000521
000496
000538
000009
001294
000542
000127
000192
001237
000406
001055
001251
001315
000151
001180
001250
000901
000816
000273
000545
000912
000583
000285
000678
000992
000386
000022
000573
001075
000927
001304
000219
001291
000786
000348
000861
000785
000875
001191
000853
001008
000817
000721
000807
000122
001083
000261
000588
000587
000880
001311
000835
001021
000714
000081
000420
001103
000142
000270
001160
001224
000792
000019
000470
000529
001141
000646
000656
000815
000885
000924
001254
000111
000598
001120
001087
000020
000214
000724
000582
001234
000983
000759
001149
001044
000665
000107
001322
001210
000448
000260
000579
000996
001246
000698
000783
000240
000862
000616
000051
000400
000455
000809
001013
000110
000634
000838
001184
001077
000694
000412
000836
000344
001240
000259
000439
000766
001202
000093
000507
000303
000262
000092
000890
000789
000784
000272
000186
000500
000734
000887
000476
000960
000202
001034
001214
000516
000808
000346
000680
001289
001042
000802
001139
000640
000803
000329
001123
000761
001028
000650
000133
000592
000148
000866
000230
001012
000055
000252
000421
000425
000596
000296
001325
000255
001092
000362
000450
001183
000046
000648
000370
001128
001175
000730
001154
000963
001098
000654
001112
000319
000398
001298
001057
001125
000012
000756
000141
000931
000581
000463
001259
000985
000076
000632
000162
000954
000779
000233
000614
001126
001288
001283
000790
000023
000380
000661
000923
000225
000156
000390
000597
000468
000606
000307
000630
001223
000038
000224
000673
001193
000016
000015
001110
001204
000683
000353
XDU189
XDU935
XDU672
XDU231
XDU818
XDU888
XDU146
XDU48
XDU492
XDU241
XDU195
XDU801
XDU104
XDU637
XDU996
XDU482
XDU406
XDU889
XDU558
XDU117
XDU777
XDU134
XDU223
XDU943
XDU762
XDU662
XDU54
XDU685
XDU167
XDU489
XDU505
XDU527
XDU817
XDU253
XDU193
XDU597
XDU151
XDU404
XDU596
XDU97
XDU321
XDU279
XDU93
XDU205
XDU9
XDU219
XDU674
XDU501
XDU316
XDU343
XDU885
XDU426
XDU485
XDU850
XDU516
XDU216
XDU160
XDU176
XDU504
XDU883
XDU244
XDU919
XDU781
XDU369
XDU398
XDU441
XDU75
XDU240
XDU805
XDU108
XDU709
XDU352
XDU747
XDU209
XDU845
XDU557
XDU775
XDU56
XDU657
XDU753
XDU788
XDU682
XDU794
XDU877
XDU421
XDU733
XDU546
XDU999
XDU5
XDU63
XDU966
XDU922
XDU789
XDU295
XDU863
XDU578
XDU743
XDU46
XDU115
XDU876
XDU932
XDU289
XDU855
XDU933
XDU517
XDU329
XDU3
XDU451
XDU694
XDU878
XDU259
XDU708
XDU442
XDU829
XDU833
XDU648
XDU381
XDU868
XDU803
XDU673
XDU415
XDU997
XDU667
XDU968
XDU169
XDU525
XDU164
XDU704
XDU711
XDU111
XDU354
XDU927
XDU758
XDU87
XDU697
XDU957
XDU49
XDU563
XDU954
XDU45
XDU429
XDU902
XDU302
XDU523
XDU41
XDU816
XDU785
XDU759
XDU872
XDU185
XDU881
XDU447
XDU129
XDU614
XDU920
XDU334
XDU257
XDU892
XDU103
XDU698
XDU862
XDU33
XDU416
XDU40
XDU715
XDU203
XDU589
XDU142
XDU50
XDU455
XDU620
XDU67
XDU371
XDU192
XDU28
XDU43
XDU661
XDU692
XDU463
XDU745
XDU258
XDU842
XDU459
XDU147
XDU319
XDU225
XDU178
XDU567
XDU925
XDU394
XDU110
XDU663
XDU376
XDU450
XDU10
XDU955
XDU374
XDU278
XDU393
XDU570
XDU217
================================================
FILE: datasets/SIRST3/img_idx/train_SIRST3.txt
================================================
Misc_119
Misc_64
Misc_90
Misc_364
Misc_250
Misc_351
Misc_39
Misc_313
Misc_179
Misc_344
Misc_421
Misc_398
Misc_417
Misc_95
Misc_339
Misc_426
Misc_269
Misc_316
Misc_419
Misc_144
Misc_149
Misc_146
Misc_31
Misc_58
Misc_4
Misc_264
Misc_283
Misc_284
Misc_150
Misc_220
Misc_133
Misc_77
Misc_70
Misc_425
Misc_195
Misc_304
Misc_329
Misc_65
Misc_167
Misc_174
Misc_202
Misc_157
Misc_96
Misc_320
Misc_369
Misc_109
Misc_16
Misc_40
Misc_295
Misc_147
Misc_247
Misc_423
Misc_152
Misc_100
Misc_263
Misc_352
Misc_233
Misc_190
Misc_392
Misc_281
Misc_358
Misc_163
Misc_132
Misc_405
Misc_159
Misc_12
Misc_367
Misc_172
Misc_401
Misc_138
Misc_104
Misc_86
Misc_160
Misc_242
Misc_7
Misc_305
Misc_243
Misc_399
Misc_363
Misc_61
Misc_129
Misc_330
Misc_134
Misc_315
Misc_180
Misc_244
Misc_63
Misc_391
Misc_42
Misc_404
Misc_29
Misc_238
Misc_285
Misc_214
Misc_93
Misc_253
Misc_402
Misc_50
Misc_291
Misc_128
Misc_267
Misc_115
Misc_337
Misc_370
Misc_158
Misc_114
Misc_388
Misc_170
Misc_354
Misc_36
Misc_424
Misc_336
Misc_393
Misc_229
Misc_108
Misc_105
Misc_406
Misc_2
Misc_324
Misc_47
Misc_200
Misc_187
Misc_33
Misc_72
Misc_384
Misc_120
Misc_322
Misc_360
Misc_192
Misc_112
Misc_142
Misc_403
Misc_169
Misc_223
Misc_213
Misc_161
Misc_256
Misc_141
Misc_78
Misc_296
Misc_6
Misc_258
Misc_231
Misc_52
Misc_183
Misc_362
Misc_102
Misc_88
Misc_343
Misc_341
Misc_118
Misc_165
Misc_280
Misc_17
Misc_290
Misc_67
Misc_382
Misc_191
Misc_166
Misc_8
Misc_45
Misc_415
Misc_349
Misc_98
Misc_127
Misc_184
Misc_310
Misc_198
Misc_254
Misc_211
Misc_103
Misc_232
Misc_218
Misc_89
Misc_201
Misc_11
Misc_168
Misc_215
Misc_383
Misc_333
Misc_245
Misc_55
Misc_27
Misc_226
Misc_116
Misc_378
Misc_355
Misc_302
Misc_209
Misc_32
Misc_23
Misc_261
Misc_182
Misc_282
Misc_409
Misc_260
Misc_194
Misc_407
Misc_175
Misc_278
Misc_26
Misc_246
Misc_217
Misc_311
Misc_43
Misc_353
Misc_81
Misc_18
Misc_318
Misc_389
Misc_171
Misc_71
Misc_92
001137
000345
000774
000593
001002
001024
001285
000649
000282
000037
001150
001124
001185
000718
000968
001216
000494
001015
000945
000446
000193
000746
000733
000621
001131
000085
000064
000312
000213
001297
001039
000066
000002
000843
000659
000298
000227
000951
000299
000937
000547
000948
000859
000058
000422
001140
000873
001249
000822
000745
000609
000881
001271
000851
000848
001016
001100
000004
001208
000099
000073
000049
000158
001279
000063
001181
000515
001054
000844
001177
000934
000663
000841
000205
000911
000211
000267
000379
000168
000301
001061
001299
000818
001070
001305
001095
000695
000375
000061
000669
000325
000864
000174
000347
000748
000946
001274
001059
000253
000679
001041
000999
000794
000576
001094
000635
000825
000013
000627
000487
001201
000457
000916
000030
000109
000693
000183
000467
001164
001258
000970
000263
000858
000668
000526
000736
001146
000247
000367
000050
000773
000318
001163
001159
001306
001310
000571
001031
000461
000652
000444
001194
001156
000014
000250
000134
001301
000768
000823
000040
000891
000561
001200
001233
000119
001324
000793
000293
000143
001011
000834
000095
000017
001116
000653
000755
000705
000723
000738
001266
000364
000791
000048
000228
000956
000570
000116
000350
000276
001022
000369
001225
000129
000947
000800
000154
001168
000602
000210
001284
000534
000814
000846
001132
000257
000688
001056
000430
000900
000471
000833
000171
000567
001043
000798
001085
000895
001121
000979
000363
000311
000229
000234
001073
000716
001005
001119
001327
000754
000245
000780
000277
001052
000856
000871
000531
000472
000083
000274
000147
000096
000152
001158
000453
000057
000411
000388
000082
000655
000962
001086
000539
000381
000527
000309
000065
000418
000306
000469
000088
001007
000690
000686
000428
000605
001067
001025
001245
001045
000767
000728
001065
000540
000658
000549
000522
000978
000317
000117
000184
000631
000459
000804
000819
001048
000271
000358
000670
000145
000118
000239
000126
000115
001171
000497
000847
000462
000604
000223
000810
000938
000460
001001
001108
000826
001252
000981
000429
001277
001080
000707
000221
000265
000672
000752
000481
000384
000236
000125
000892
000238
000639
000914
000432
000489
000513
000114
000719
000643
000994
000712
000242
001207
000201
001312
000524
000198
000543
000028
001323
000331
000222
000869
000692
000842
000795
001078
000896
000052
001256
000292
000070
000443
000343
000504
000203
000204
000943
000886
000337
000146
000778
001051
000121
000196
000647
001182
000972
000750
000568
000585
000925
001295
000280
000741
000360
000969
000042
000087
000975
000877
000942
000998
000356
000697
000403
000208
000514
000401
000248
000758
001192
001282
000876
000764
000445
000865
000564
000149
000584
001019
001017
000308
001318
000018
000056
000232
000777
000328
001217
001219
000577
000383
001276
000929
000354
000029
001118
000006
000105
000486
000074
000275
001189
000612
000008
001093
000034
000949
000039
001138
000677
001198
000696
001230
000590
001243
000323
001032
000488
000413
000700
000854
000217
000101
000518
000578
000950
000031
000888
000256
000419
000035
000872
001147
001033
000991
001113
000884
001292
000235
000185
000404
001071
000340
000387
000990
000254
000953
000770
000241
000559
000357
000905
000068
000097
000108
000618
000982
000743
000279
000824
001058
000681
001196
001275
000760
000089
000054
000624
000987
001179
000338
000206
000177
000565
001134
000130
000952
001082
000136
000551
000288
000619
000852
001165
000321
001231
000501
000393
000372
000512
000424
001211
000977
000580
000626
000218
000451
000566
000165
000964
000407
000517
000572
001027
001106
000662
000102
000776
000160
000072
001062
000899
001091
000024
000304
000753
000594
000769
000464
000044
000007
000438
001313
000069
001187
000993
000161
000505
001290
000546
000258
001064
001228
001115
000674
001293
001269
001172
001247
000667
000641
000172
000503
000181
000726
000033
000437
001263
000989
000376
000959
000491
001090
000530
001104
000796
000290
000179
000868
000264
000327
000657
000442
001035
001136
000251
000342
000098
001166
000829
001049
000104
000402
000339
000427
000078
001014
000139
001152
000479
000175
000435
001072
000246
000414
000638
000904
001321
000396
000094
000805
000005
001316
001046
000586
001226
000079
000725
001296
000827
000103
001099
001205
000368
000278
000190
000544
001260
000997
000431
000919
001197
001030
001173
000454
001157
000164
001037
000077
000642
000828
000675
000702
000310
000797
000739
000123
001270
001170
001117
000857
000477
000180
001130
000958
000855
001109
001009
000199
001203
000676
000132
000441
001261
000878
000283
000408
000244
000322
001105
000932
XDU514
XDU646
XDU904
XDU660
XDU347
XDU962
XDU92
XDU838
XDU907
XDU496
XDU83
XDU606
XDU307
XDU138
XDU357
XDU993
XDU693
XDU493
XDU891
XDU410
XDU288
XDU562
XDU849
XDU23
XDU199
XDU370
XDU537
XDU871
XDU656
XDU331
XDU328
XDU403
XDU230
XDU529
XDU229
XDU476
XDU792
XDU412
XDU689
XDU51
XDU532
XDU356
XDU303
XDU161
XDU879
XDU867
XDU773
XDU323
XDU836
XDU236
XDU749
XDU807
XDU72
XDU128
XDU822
XDU480
XDU270
XDU651
XDU815
XDU30
XDU548
XDU386
XDU555
XDU122
XDU798
XDU264
XDU725
XDU806
XDU440
XDU332
XDU875
XDU325
XDU486
XDU659
XDU835
XDU335
XDU330
XDU132
XDU89
XDU580
XDU8
XDU4
XDU280
XDU895
XDU869
XDU799
XDU419
XDU772
XDU896
XDU604
XDU116
XDU85
XDU91
XDU158
XDU130
XDU611
XDU98
XDU299
XDU114
XDU923
XDU94
XDU800
XDU934
XDU998
XDU338
XDU959
XDU712
XDU754
XDU636
XDU624
XDU918
XDU26
XDU559
XDU761
XDU909
XDU340
XDU262
XDU964
XDU29
XDU443
XDU929
XDU739
XDU654
XDU858
XDU551
XDU365
XDU665
XDU705
XDU434
XDU105
XDU887
XDU910
XDU261
XDU360
XDU912
XDU894
XDU727
XDU42
XDU556
XDU613
XDU285
XDU668
XDU592
XDU341
XDU942
XDU982
XDU191
XDU528
XDU720
XDU601
XDU531
XDU21
XDU275
XDU232
XDU344
XDU301
XDU977
XDU390
XDU938
XDU975
XDU688
XDU979
XDU58
XDU494
XDU988
XDU826
XDU571
XDU1
XDU779
XDU633
XDU987
XDU449
XDU313
XDU153
XDU718
XDU978
XDU538
XDU950
XDU740
XDU986
XDU780
XDU965
XDU900
XDU423
XDU194
XDU490
XDU73
XDU675
XDU254
XDU547
XDU983
XDU227
XDU68
XDU445
XDU903
XDU652
XDU642
XDU599
XDU64
XDU221
XDU291
XDU460
XDU623
XDU766
XDU680
XDU714
XDU25
XDU19
XDU65
XDU671
XDU333
XDU375
XDU550
XDU790
XDU397
XDU497
XDU645
XDU470
XDU913
XDU956
XDU218
XDU380
XDU487
XDU66
XDU324
XDU612
XDU384
XDU544
XDU144
XDU542
XDU248
XDU461
XDU148
XDU653
XDU336
XDU866
XDU456
XDU540
XDU351
XDU112
XDU272
XDU53
XDU515
XDU699
XDU417
XDU639
XDU342
XDU31
XDU448
XDU292
XDU728
XDU180
XDU149
XDU706
XDU973
XDU320
XDU625
XDU811
XDU462
XDU388
XDU638
XDU90
XDU109
XDU722
XDU162
XDU972
XDU767
XDU349
XDU263
XDU465
XDU576
XDU507
XDU644
XDU587
XDU255
XDU326
XDU500
XDU586
XDU524
XDU765
XDU890
XDU960
XDU594
XDU80
XDU14
XDU569
XDU953
XDU884
XDU282
XDU387
XDU579
XDU260
XDU252
XDU971
XDU905
XDU994
XDU36
XDU266
XDU405
XDU208
XDU207
XDU723
XDU35
XDU590
XDU678
XDU629
XDU939
XDU804
XDU948
XDU24
XDU967
XDU293
XDU545
XDU901
XDU290
XDU989
XDU322
XDU707
XDU188
XDU582
XDU810
XDU439
XDU300
XDU237
XDU457
XDU433
XDU831
XDU917
XDU677
XDU82
XDU561
XDU413
XDU834
XDU368
XDU658
XDU898
XDU173
XDU34
XDU467
XDU2
XDU265
XDU735
XDU454
XDU163
XDU81
XDU74
XDU140
XDU166
XDU478
XDU61
XDU880
XDU841
XDU565
XDU377
XDU839
XDU691
XDU607
XDU530
XDU844
XDU847
XDU472
XDU882
XDU17
XDU619
XDU12
XDU736
XDU859
XDU186
XDU985
XDU389
XDU921
XDU355
XDU539
XDU141
XDU916
XDU135
XDU519
XDU621
XDU435
XDU760
XDU783
XDU591
XDU628
XDU464
XDU649
XDU198
XDU560
XDU372
XDU553
XDU458
XDU183
XDU650
XDU622
XDU825
XDU471
XDU190
XDU414
XDU44
XDU741
XDU635
XDU647
XDU573
XDU864
XDU618
XDU364
XDU543
XDU437
XDU502
XDU824
XDU952
XDU125
XDU641
XDU491
XDU201
XDU947
XDU444
XDU79
XDU518
XDU32
XDU283
XDU802
XDU483
XDU59
XDU477
XDU670
XDU234
XDU577
XDU969
XDU669
XDU995
XDU425
XDU506
XDU970
XDU681
XDU353
XDU676
XDU958
XDU782
XDU121
XDU363
XDU411
XDU690
XDU392
XDU536
XDU752
XDU210
XDU774
XDU350
XDU731
XDU930
XDU479
XDU602
XDU856
XDU96
XDU520
XDU13
XDU915
XDU106
XDU853
XDU308
XDU47
XDU769
XDU242
XDU899
XDU484
XDU581
XDU643
XDU827
XDU246
XDU119
XDU746
XDU100
XDU311
XDU512
XDU852
XDU509
XDU874
XDU686
XDU139
XDU928
XDU719
XDU296
XDU750
XDU713
XDU156
XDU634
XDU6
XDU821
XDU488
XDU420
XDU748
XDU126
XDU474
XDU273
XDU742
XDU438
XDU20
XDU27
XDU452
XDU541
XDU598
XDU949
XDU992
XDU513
XDU155
XDU228
XDU206
XDU69
XDU666
XDU860
XDU136
XDU617
XDU436
XDU716
XDU823
XDU38
XDU1000
XDU851
XDU627
XDU974
XDU717
XDU734
XDU796
XDU481
XDU990
XDU679
XDU764
XDU238
XDU848
XDU908
XDU418
XDU696
XDU378
XDU724
XDU632
XDU182
XDU76
XDU791
XDU830
XDU814
XDU306
XDU931
XDU154
XDU564
XDU383
XDU473
XDU84
XDU143
XDU18
XDU683
XDU495
XDU78
XDU840
XDU382
XDU385
XDU233
XDU220
XDU616
XDU655
XDU797
XDU854
XDU312
XDU924
XDU593
XDU150
XDU60
XDU951
XDU812
XDU608
XDU408
XDU184
XDU552
XDU172
XDU820
XDU584
XDU314
XDU702
XDU174
XDU379
XDU511
XDU914
XDU214
XDU315
XDU535
XDU305
XDU294
XDU71
XDU534
XDU133
XDU204
XDU991
XDU475
XDU870
XDU793
XDU585
XDU431
XDU786
XDU944
XDU102
XDU245
XDU857
XDU427
XDU738
XDU726
XDU568
XDU843
XDU131
XDU298
XDU498
XDU837
XDU243
XDU588
XDU832
XDU526
XDU710
XDU177
XDU310
XDU165
XDU603
XDU99
XDU22
XDU615
XDU566
XDU286
XDU703
XDU361
XDU795
XDU277
XDU508
XDU521
XDU297
XDU317
XDU861
XDU271
XDU318
XDU572
XDU247
XDU202
XDU946
XDU466
XDU732
XDU226
XDU583
XDU120
XDU926
XDU401
XDU687
XDU984
XDU819
XDU664
XDU400
XDU446
XDU453
XDU730
XDU362
XDU62
XDU175
XDU809
XDU430
XDU124
XDU346
XDU776
XDU605
XDU187
XDU337
XDU211
XDU684
XDU179
XDU981
XDU784
XDU701
XDU358
XDU768
XDU911
XDU235
XDU215
XDU145
XDU609
XDU281
XDU432
XDU196
XDU499
XDU250
XDU304
XDU600
XDU309
XDU171
XDU787
XDU595
XDU808
XDU7
XDU846
XDU428
XDU287
XDU729
XDU213
XDU828
XDU941
XDU395
XDU756
XDU897
XDU239
XDU610
XDU251
XDU373
XDU533
XDU95
XDU57
XDU945
XDU222
XDU168
XDU137
XDU961
XDU906
XDU937
XDU0
XDU770
XDU268
XDU963
XDU113
XDU771
XDU763
XDU339
XDU52
XDU737
XDU755
XDU159
XDU626
XDU16
XDU118
XDU77
XDU574
XDU402
XDU407
XDU88
XDU778
XDU391
XDU15
XDU940
XDU886
XDU359
XDU424
XDU721
XDU399
XDU345
XDU157
XDU107
XDU873
XDU865
XDU39
XDU893
XDU976
XDU695
XDU367
XDU700
XDU422
XDU936
XDU123
XDU503
XDU366
XDU101
XDU631
XDU276
XDU549
XDU212
XDU197
XDU640
XDU200
XDU37
XDU469
XDU522
XDU575
XDU256
XDU409
XDU152
XDU224
XDU86
XDU630
XDU980
XDU813
XDU70
XDU249
XDU396
XDU11
XDU327
XDU269
XDU284
XDU757
XDU348
XDU554
XDU127
XDU267
XDU751
XDU181
XDU468
XDU274
XDU55
XDU510
XDU170
XDU744
================================================
FILE: metrics.py
================================================
import numpy as np
import torch
from skimage import measure
class ROCMetric():
"""Computes pixAcc and mIoU metric scores
"""
def __init__(self, nclass, bins):
# bin的意义实际上是确定ROC曲线上的threshold取多少个离散值
# nclass :有几个类别 红外弱小目标检测只有一个类别
super(ROCMetric, self).__init__()
self.nclass = nclass
self.bins = bins
self.tp_arr = np.zeros(self.bins + 1)
self.pos_arr = np.zeros(self.bins + 1)
self.fp_arr = np.zeros(self.bins + 1)
self.neg_arr = np.zeros(self.bins + 1)
self.class_pos = np.zeros(self.bins + 1)
# self.reset()
# 网络输入的结果和标签 计算两者之前的东西
def update(self, preds, labels):
for iBin in range(self.bins + 1):
# score_thresh = (iBin + 0.0) / self.bins
score_thresh = -30 + iBin * (255 / self.bins)
# print(iBin, "-th, score_thresh: ", score_thresh)
i_tp, i_pos, i_fp, i_neg, i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass, score_thresh)
self.tp_arr[iBin] += i_tp
self.pos_arr[iBin] += i_pos
self.fp_arr[iBin] += i_fp
self.neg_arr[iBin] += i_neg
self.class_pos[iBin] += i_class_pos
def get(self):
tp_rates = self.tp_arr / (self.pos_arr + 0.001) # tp_rates = recall = TP/(TP+FN)
fp_rates = self.fp_arr / (self.neg_arr + 0.001) # fp_rates = FP/(FP+TN)
FP = self.fp_arr / (self.neg_arr + self.pos_arr)
recall = self.tp_arr / (self.pos_arr + 0.001) # recall = TP/(TP+FN)
precision = self.tp_arr / (self.class_pos + 0.001) # precision = TP/(TP+FP)
return tp_rates, fp_rates, recall, precision, FP
def reset(self):
self.tp_arr = np.zeros([11])
self.pos_arr = np.zeros([11])
self.fp_arr = np.zeros([11])
self.neg_arr = np.zeros([11])
self.class_pos = np.zeros([11])
class mIoU():
def __init__(self):
super(mIoU, self).__init__()
self.reset()
def update(self, preds, labels):
correct, labeled = batch_pix_accuracy(preds, labels)
inter, union = batch_intersection_union(preds, labels)
self.total_correct += correct # 预测正确的像素数
self.total_label += labeled # GT目标的像素数
self.total_inter += inter # 交集
self.total_union += union # 并集
def get(self):
pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
mIoU = IoU.mean()
return float(pixAcc), mIoU
def reset(self):
self.total_inter = 0
self.total_union = 0
self.total_correct = 0
self.total_label = 0
class PD_FA():
def __init__(self, ):
super(PD_FA, self).__init__()
self.image_area_total = []
self.image_area_match = []
self.dismatch_pixel = 0
self.all_pixel = 0
self.PD = 0
self.target = 0
def update(self, preds, labels, size):
predits = np.array((preds).cpu()).astype('int64')
labelss = np.array((labels).cpu()).astype('int64')
image = measure.label(predits, connectivity=2)
coord_image = measure.regionprops(image)
label = measure.label(labelss, connectivity=2)
coord_label = measure.regionprops(label)
self.target += len(coord_label) # 目标总数 直接就搞GT的连通域个数
self.image_area_total = [] # 图像中预测的区域列表
self.image_area_match = []
self.distance_match = []
self.dismatch = []
for K in range(len(coord_image)):
area_image = np.array(coord_image[K].area)
self.image_area_total.append(area_image)
for i in range(len(coord_label)): # image 与 label 之间 根据中心点 进行连通域的确定
centroid_label = np.array(list(coord_label[i].centroid))
for m in range(len(coord_image)):
centroid_image = np.array(list(coord_image[m].centroid))
distance = np.linalg.norm(centroid_image - centroid_label)
area_image = np.array(coord_image[m].area)
if distance < 3:
self.distance_match.append(distance)
self.image_area_match.append(area_image)
del coord_image[m] # 匹配上一个之后就 清除一个
break
self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match] # 在image里面 但是不在label里面
self.dismatch_pixel += np.sum(self.dismatch) # Fa 虚警
self.all_pixel += size[0] * size[1]
self.PD += len(self.distance_match) # 如果中心点之间距离在3一下 就算Pd 所以Pd 是匹配上了的目标的个数
def get(self):
Final_FA = self.dismatch_pixel / self.all_pixel
Final_PD = self.PD / self.target
return Final_PD, float(Final_FA.cpu().detach().numpy())
def reset(self):
self.FA = np.zeros([self.bins + 1])
self.PD = np.zeros([self.bins + 1])
def batch_pix_accuracy(output, target):
if len(target.shape) == 3:
target = np.expand_dims(target.float(), axis=1)
elif len(target.shape) == 4:
target = target.float()
else:
raise ValueError("Unknown target dimension")
assert output.shape == target.shape, "Predict and Label Shape Don't Match"
predict = (output > 0).float() # 将output 从 True Flase 转成 1 0
pixel_labeled = (target > 0).float().sum() # GF中 1的个数
pixel_correct = (((predict == target).float()) * ((target > 0)).float()).sum() # 预测对的个数
assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
return pixel_correct, pixel_labeled
def batch_intersection_union(output, target):
mini = 1
maxi = 1
nbins = 1
predict = (output > 0).float()
if len(target.shape) == 3:
target = np.expand_dims(target.float(), axis=1)
elif len(target.shape) == 4:
target = target.float()
else:
raise ValueError("Unknown target dimension")
intersection = predict * ((predict == target).float())
area_inter, _ = np.histogram(intersection.cpu(), bins=nbins, range=(mini, maxi))
area_pred, _ = np.histogram(predict.cpu(), bins=nbins, range=(mini, maxi))
area_lab, _ = np.histogram(target.cpu(), bins=nbins, range=(mini, maxi))
area_union = area_pred + area_lab - area_inter
assert (area_inter <= area_union).all(), \
"Error: Intersection area should be smaller than Union area"
return area_inter, area_union
================================================
FILE: model/Config.py
================================================
# -*- coding: utf-8 -*-
# @Author : Shuai Yuan
# @File : Config.py
# @Software: PyCharm
# coding=utf-8
import os
import torch
import time
import ml_collections
##########################################################################
# SCTrans configs
##########################################################################
def get_SCTrans_config():
config = ml_collections.ConfigDict()
config.transformer = ml_collections.ConfigDict()
config.KV_size = 480 # KV_size = Q1 + Q2 + Q3 + Q4
config.transformer.num_heads = 4
config.transformer.num_layers = 4
config.patch_sizes = [16, 8, 4, 2]
config.base_channel = 32 # base channel of U-Net
config.n_classes = 1
# ********** unused **********
config.transformer.embeddings_dropout_rate = 0.1
config.transformer.attention_dropout_rate = 0.1
config.transformer.dropout_rate = 0
return config
================================================
FILE: model/SCTransNet.py
================================================
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# @Author : Shuai Yuan
# @File : SCTransNet.py
# @Software: PyCharm
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import math
from torch.nn import Dropout, Softmax, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
import torch.nn as nn
import torch
import torch.nn.functional as F
import ml_collections
from einops import rearrange
import numbers
from thop import profile
def get_CTranS_config():
config = ml_collections.ConfigDict()
config.transformer = ml_collections.ConfigDict()
config.KV_size = 480 # KV_size = Q1 + Q2 + Q3 + Q4
config.transformer.num_heads = 4
config.transformer.num_layers = 4
config.patch_sizes = [16, 8, 4, 2]
config.base_channel = 32 # base channel of U-Net
config.n_classes = 1
# ********** useless **********
config.transformer.embeddings_dropout_rate = 0.1
config.transformer.attention_dropout_rate = 0.1
config.transformer.dropout_rate = 0
return config
class Channel_Embeddings(nn.Module):
def __init__(self, config, patchsize, img_size, in_channels):
super().__init__()
img_size = _pair(img_size)
patch_size = _pair(patchsize)
n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) # 14 * 14 = 196
self.patch_embeddings = Conv2d(in_channels=in_channels,
out_channels=in_channels,
kernel_size=patch_size,
stride=patch_size)
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels))
self.dropout = Dropout(config.transformer["embeddings_dropout_rate"])
def forward(self, x):
if x is None:
return None
x = self.patch_embeddings(x)
return x
class Reconstruct(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
super(Reconstruct, self).__init__()
if kernel_size == 3:
padding = 1
else:
padding = 0
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
self.norm = nn.BatchNorm2d(out_channels)
self.activation = nn.ReLU(inplace=True)
self.scale_factor = scale_factor
# def forward(self, x, h, w):
def forward(self, x):
if x is None:
return None
x = nn.Upsample(scale_factor=self.scale_factor, mode='bilinear')(x)
out = self.conv(x)
out = self.norm(out)
out = self.activation(out)
return out
# spatial-embedded Single-head Channel-cross Attention (SSCA)
class Attention_org(nn.Module):
def __init__(self, config, vis, channel_num):
super(Attention_org, self).__init__()
self.vis = vis
self.KV_size = config.KV_size
self.channel_num = channel_num
self.num_attention_heads = 1
self.psi = nn.InstanceNorm2d(self.num_attention_heads)
self.softmax = Softmax(dim=3)
# self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.mhead1 = nn.Conv2d(channel_num[0], channel_num[0] * self.num_attention_heads, kernel_size=1, bias=False)
self.mhead2 = nn.Conv2d(channel_num[1], channel_num[1] * self.num_attention_heads, kernel_size=1, bias=False)
self.mhead3 = nn.Conv2d(channel_num[2], channel_num[2] * self.num_attention_heads, kernel_size=1, bias=False)
self.mhead4 = nn.Conv2d(channel_num[3], channel_num[3] * self.num_attention_heads, kernel_size=1, bias=False)
self.mheadk = nn.Conv2d(self.KV_size, self.KV_size * self.num_attention_heads, kernel_size=1, bias=False)
self.mheadv = nn.Conv2d(self.KV_size, self.KV_size * self.num_attention_heads, kernel_size=1, bias=False)
self.q1 = nn.Conv2d(channel_num[0] * self.num_attention_heads, channel_num[0] * self.num_attention_heads, kernel_size=3, stride=1,
padding=1,
groups=channel_num[0] * self.num_attention_heads // 2, bias=False)
self.q2 = nn.Conv2d(channel_num[1] * self.num_attention_heads, channel_num[1] * self.num_attention_heads, kernel_size=3, stride=1,
padding=1,
groups=channel_num[1] * self.num_attention_heads // 2, bias=False)
self.q3 = nn.Conv2d(channel_num[2] * self.num_attention_heads, channel_num[2] * self.num_attention_heads, kernel_size=3, stride=1,
padding=1,
groups=channel_num[2] * self.num_attention_heads // 2, bias=False)
self.q4 = nn.Conv2d(channel_num[3] * self.num_attention_heads, channel_num[3] * self.num_attention_heads, kernel_size=3, stride=1,
padding=1,
groups=channel_num[3] * self.num_attention_heads // 2, bias=False)
self.k = nn.Conv2d(self.KV_size * self.num_attention_heads, self.KV_size * self.num_attention_heads, kernel_size=3, stride=1,
padding=1, groups=self.KV_size * self.num_attention_heads, bias=False)
self.v = nn.Conv2d(self.KV_size * self.num_attention_heads, self.KV_size * self.num_attention_heads, kernel_size=3, stride=1,
padding=1, groups=self.KV_size * self.num_attention_heads, bias=False)
self.project_out1 = nn.Conv2d(channel_num[0], channel_num[0], kernel_size=1, bias=False)
self.project_out2 = nn.Conv2d(channel_num[1], channel_num[1], kernel_size=1, bias=False)
self.project_out3 = nn.Conv2d(channel_num[2], channel_num[2], kernel_size=1, bias=False)
self.project_out4 = nn.Conv2d(channel_num[3], channel_num[3], kernel_size=1, bias=False)
# ****************** useless ***************************************
self.q1_attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q1_attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q1_attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q1_attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q2_attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q2_attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q2_attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q2_attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q3_attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q3_attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q3_attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q3_attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q4_attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q4_attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q4_attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.q4_attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
def forward(self, emb1, emb2, emb3, emb4, emb_all):
b, c, h, w = emb1.shape
q1 = self.q1(self.mhead1(emb1))
q2 = self.q2(self.mhead2(emb2))
q3 = self.q3(self.mhead3(emb3))
q4 = self.q4(self.mhead4(emb4))
k = self.k(self.mheadk(emb_all))
v = self.v(self.mheadv(emb_all))
# k, v = kv.chunk(2, dim=1)
q1 = rearrange(q1, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)
q2 = rearrange(q2, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)
q3 = rearrange(q3, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)
q4 = rearrange(q4, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_attention_heads)
q1 = torch.nn.functional.normalize(q1, dim=-1)
q2 = torch.nn.functional.normalize(q2, dim=-1)
q3 = torch.nn.functional.normalize(q3, dim=-1)
q4 = torch.nn.functional.normalize(q4, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
_, _, c1, _ = q1.shape
_, _, c2, _ = q2.shape
_, _, c3, _ = q3.shape
_, _, c4, _ = q4.shape
_, _, c, _ = k.shape
attn1 = (q1 @ k.transpose(-2, -1)) / math.sqrt(self.KV_size)
attn2 = (q2 @ k.transpose(-2, -1)) / math.sqrt(self.KV_size)
attn3 = (q3 @ k.transpose(-2, -1)) / math.sqrt(self.KV_size)
attn4 = (q4 @ k.transpose(-2, -1)) / math.sqrt(self.KV_size)
attention_probs1 = self.softmax(self.psi(attn1))
attention_probs2 = self.softmax(self.psi(attn2))
attention_probs3 = self.softmax(self.psi(attn3))
attention_probs4 = self.softmax(self.psi(attn4))
out1 = (attention_probs1 @ v)
out2 = (attention_probs2 @ v)
out3 = (attention_probs3 @ v)
out4 = (attention_probs4 @ v)
out_1 = out1.mean(dim=1)
out_2 = out2.mean(dim=1)
out_3 = out3.mean(dim=1)
out_4 = out4.mean(dim=1)
out_1 = rearrange(out_1, 'b c (h w) -> b c h w', h=h, w=w)
out_2 = rearrange(out_2, 'b c (h w) -> b c h w', h=h, w=w)
out_3 = rearrange(out_3, 'b c (h w) -> b c h w', h=h, w=w)
out_4 = rearrange(out_4, 'b c (h w) -> b c h w', h=h, w=w)
O1 = self.project_out1(out_1)
O2 = self.project_out2(out_2)
O3 = self.project_out3(out_3)
O4 = self.project_out4(out_4)
weights = None
return O1, O2, O3, O4, weights
def to_3d(x):
return rearrange(x, 'b c h w -> b (h w) c')
def to_4d(x, h, w):
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(BiasFree_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
sigma = x.var(-1, keepdim=True, unbiased=False)
return x / torch.sqrt(sigma + 1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(WithBias_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
mu = x.mean(-1, keepdim=True)
sigma = x.var(-1, keepdim=True, unbiased=False)
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
class LayerNorm3d(nn.Module):
def __init__(self, dim, LayerNorm_type):
super(LayerNorm3d, self).__init__()
if LayerNorm_type == 'BiasFree':
self.body = BiasFree_LayerNorm(dim)
else:
self.body = WithBias_LayerNorm(dim)
def forward(self, x):
h, w = x.shape[-2:]
return to_4d(self.body(to_3d(x)), h, w)
class eca_layer_2d(nn.Module):
def __init__(self, channel, k_size=3):
super(eca_layer_2d, self).__init__()
padding = k_size // 2
self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
self.conv = nn.Sequential(
nn.Conv1d(in_channels=1, out_channels=1, kernel_size=k_size, padding=padding, bias=False),
nn.Sigmoid()
)
self.channel = channel
self.k_size = k_size
def forward(self, x):
out = self.avg_pool(x)
out = out.view(x.size(0), 1, x.size(1))
out = self.conv(out)
out = out.view(x.size(0), x.size(1), 1, 1)
return out * x
# Complementary Feed-forward Network (CFN)
class FeedForward(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FeedForward, self).__init__()
hidden_features = int(dim * ffn_expansion_factor)
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
self.dwconv3x3 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features,
bias=bias)
self.dwconv5x5 = nn.Conv2d(hidden_features, hidden_features, kernel_size=5, stride=1, padding=2, groups=hidden_features,
bias=bias)
self.relu3 = nn.ReLU()
self.relu5 = nn.ReLU()
self.project_out = nn.Conv2d(hidden_features * 2, dim, kernel_size=1, bias=bias)
self.eca = eca_layer_2d(dim)
def forward(self, x):
x_3,x_5 = self.project_in(x).chunk(2, dim=1)
x1_3 = self.relu3(self.dwconv3x3(x_3))
x1_5 = self.relu5(self.dwconv5x5(x_5))
x = torch.cat([x1_3, x1_5], dim=1)
x = self.project_out(x)
x = self.eca(x)
return x
# Spatial-channel Cross Transformer Block (SCTB)
class Block_ViT(nn.Module):
def __init__(self, config, vis, channel_num):
super(Block_ViT, self).__init__()
self.attn_norm1 = LayerNorm3d(channel_num[0], LayerNorm_type='WithBias')
self.attn_norm2 = LayerNorm3d(channel_num[1], LayerNorm_type='WithBias')
self.attn_norm3 = LayerNorm3d(channel_num[2], LayerNorm_type='WithBias')
self.attn_norm4 = LayerNorm3d(channel_num[3], LayerNorm_type='WithBias')
self.attn_norm = LayerNorm3d(config.KV_size, LayerNorm_type='WithBias')
self.channel_attn = Attention_org(config, vis, channel_num)
self.ffn_norm1 = LayerNorm3d(channel_num[0], LayerNorm_type='WithBias')
self.ffn_norm2 = LayerNorm3d(channel_num[1], LayerNorm_type='WithBias')
self.ffn_norm3 = LayerNorm3d(channel_num[2], LayerNorm_type='WithBias')
self.ffn_norm4 = LayerNorm3d(channel_num[3], LayerNorm_type='WithBias')
self.ffn1 = FeedForward(channel_num[0], ffn_expansion_factor=2.66, bias=False)
self.ffn2 = FeedForward(channel_num[1], ffn_expansion_factor=2.66, bias=False)
self.ffn3 = FeedForward(channel_num[2], ffn_expansion_factor=2.66, bias=False)
self.ffn4 = FeedForward(channel_num[3], ffn_expansion_factor=2.66, bias=False)
def forward(self, emb1, emb2, emb3, emb4):
embcat = []
org1 = emb1
org2 = emb2
org3 = emb3
org4 = emb4
for i in range(4):
var_name = "emb" + str(i + 1)
tmp_var = locals()[var_name]
if tmp_var is not None:
embcat.append(tmp_var)
emb_all = torch.cat(embcat, dim=1)
cx1 = self.attn_norm1(emb1) if emb1 is not None else None
cx2 = self.attn_norm2(emb2) if emb2 is not None else None
cx3 = self.attn_norm3(emb3) if emb3 is not None else None
cx4 = self.attn_norm4(emb4) if emb4 is not None else None
emb_all = self.attn_norm(emb_all) # 1 196 960
cx1, cx2, cx3, cx4, weights = self.channel_attn(cx1, cx2, cx3, cx4, emb_all)
cx1 = org1 + cx1 if emb1 is not None else None
cx2 = org2 + cx2 if emb2 is not None else None
cx3 = org3 + cx3 if emb3 is not None else None
cx4 = org4 + cx4 if emb4 is not None else None
org1 = cx1
org2 = cx2
org3 = cx3
org4 = cx4
x1 = self.ffn_norm1(cx1) if emb1 is not None else None
x2 = self.ffn_norm2(cx2) if emb2 is not None else None
x3 = self.ffn_norm3(cx3) if emb3 is not None else None
x4 = self.ffn_norm4(cx4) if emb4 is not None else None
x1 = self.ffn1(x1) if emb1 is not None else None
x2 = self.ffn2(x2) if emb2 is not None else None
x3 = self.ffn3(x3) if emb3 is not None else None
x4 = self.ffn4(x4) if emb4 is not None else None
x1 = x1 + org1 if emb1 is not None else None
x2 = x2 + org2 if emb2 is not None else None
x3 = x3 + org3 if emb3 is not None else None
x4 = x4 + org4 if emb4 is not None else None
return x1, x2, x3, x4, weights
class Encoder(nn.Module):
def __init__(self, config, vis, channel_num):
super(Encoder, self).__init__()
self.vis = vis
self.layer = nn.ModuleList()
self.encoder_norm1 = LayerNorm3d(channel_num[0], LayerNorm_type='WithBias')
self.encoder_norm2 = LayerNorm3d(channel_num[1], LayerNorm_type='WithBias')
self.encoder_norm3 = LayerNorm3d(channel_num[2], LayerNorm_type='WithBias')
self.encoder_norm4 = LayerNorm3d(channel_num[3], LayerNorm_type='WithBias')
for _ in range(config.transformer["num_layers"]):
layer = Block_ViT(config, vis, channel_num)
self.layer.append(copy.deepcopy(layer))
def forward(self, emb1, emb2, emb3, emb4):
attn_weights = []
for layer_block in self.layer:
emb1, emb2, emb3, emb4, weights = layer_block(emb1, emb2, emb3, emb4)
if self.vis:
attn_weights.append(weights)
emb1 = self.encoder_norm1(emb1) if emb1 is not None else None
emb2 = self.encoder_norm2(emb2) if emb2 is not None else None
emb3 = self.encoder_norm3(emb3) if emb3 is not None else None
emb4 = self.encoder_norm4(emb4) if emb4 is not None else None
return emb1, emb2, emb3, emb4, attn_weights
class ChannelTransformer(nn.Module):
def __init__(self, config, vis, img_size, channel_num=[64, 128, 256, 512], patchSize=[32, 16, 8, 4]):
super().__init__()
self.patchSize_1 = patchSize[0]
self.patchSize_2 = patchSize[1]
self.patchSize_3 = patchSize[2]
self.patchSize_4 = patchSize[3]
self.embeddings_1 = Channel_Embeddings(config, self.patchSize_1, img_size=img_size, in_channels=channel_num[0])
self.embeddings_2 = Channel_Embeddings(config, self.patchSize_2, img_size=img_size // 2, in_channels=channel_num[1])
self.embeddings_3 = Channel_Embeddings(config, self.patchSize_3, img_size=img_size // 4, in_channels=channel_num[2])
self.embeddings_4 = Channel_Embeddings(config, self.patchSize_4, img_size=img_size // 8, in_channels=channel_num[3])
self.encoder = Encoder(config, vis, channel_num)
self.reconstruct_1 = Reconstruct(channel_num[0], channel_num[0], kernel_size=1, scale_factor=(self.patchSize_1, self.patchSize_1))
self.reconstruct_2 = Reconstruct(channel_num[1], channel_num[1], kernel_size=1, scale_factor=(self.patchSize_2, self.patchSize_2))
self.reconstruct_3 = Reconstruct(channel_num[2], channel_num[2], kernel_size=1, scale_factor=(self.patchSize_3, self.patchSize_3))
self.reconstruct_4 = Reconstruct(channel_num[3], channel_num[3], kernel_size=1, scale_factor=(self.patchSize_4, self.patchSize_4))
def forward(self, en1, en2, en3, en4):
emb1 = self.embeddings_1(en1)
emb2 = self.embeddings_2(en2)
emb3 = self.embeddings_3(en3)
emb4 = self.embeddings_4(en4)
encoded1, encoded2, encoded3, encoded4, attn_weights = self.encoder(emb1, emb2, emb3, emb4) # (B, n_patch, hidden)
x1 = self.reconstruct_1(encoded1) if en1 is not None else None
x2 = self.reconstruct_2(encoded2) if en2 is not None else None
x3 = self.reconstruct_3(encoded3) if en3 is not None else None
x4 = self.reconstruct_4(encoded4) if en4 is not None else None
x1 = x1 + en1 if en1 is not None else None
x2 = x2 + en2 if en2 is not None else None
x3 = x3 + en3 if en3 is not None else None
x4 = x4 + en4 if en4 is not None else None
return x1, x2, x3, x4, attn_weights
def get_activation(activation_type):
activation_type = activation_type.lower()
if hasattr(nn, activation_type):
return getattr(nn, activation_type)()
else:
return nn.ReLU()
def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
layers = []
layers.append(CBN(in_channels, out_channels, activation))
for _ in range(nb_Conv - 1):
layers.append(CBN(out_channels, out_channels, activation))
return nn.Sequential(*layers)
class CBN(nn.Module):
def __init__(self, in_channels, out_channels, activation='ReLU'):
super(CBN, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels,
kernel_size=3, padding=1)
self.norm = nn.BatchNorm2d(out_channels)
self.activation = get_activation(activation)
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
return self.activation(out)
class DownBlock(nn.Module):
def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
super(DownBlock, self).__init__()
self.maxpool = nn.MaxPool2d(2)
self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)
def forward(self, x):
out = self.maxpool(x)
return self.nConvs(out)
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
class CCA(nn.Module):
def __init__(self, F_g, F_x):
super().__init__()
self.mlp_x = nn.Sequential(
Flatten(),
nn.Linear(F_x, F_x))
self.mlp_g = nn.Sequential(
Flatten(),
nn.Linear(F_g, F_x))
self.relu = nn.ReLU(inplace=True)
def forward(self, g, x):
avg_pool_x = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_x = self.mlp_x(avg_pool_x)
avg_pool_g = F.avg_pool2d(g, (g.size(2), g.size(3)), stride=(g.size(2), g.size(3)))
channel_att_g = self.mlp_g(avg_pool_g)
channel_att_sum = (channel_att_x + channel_att_g) / 2.0
scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
x_after_channel = x * scale
out = self.relu(x_after_channel)
return out
class UpBlock_attention(nn.Module):
def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
super().__init__()
self.up = nn.Upsample(scale_factor=2)
self.coatt = CCA(F_g=in_channels // 2, F_x=in_channels // 2)
self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)
def forward(self, x, skip_x):
up = self.up(x)
skip_x_att = self.coatt(g=up, x=skip_x)
x = torch.cat([skip_x_att, up], dim=1) # dim 1 is the channel dimension
return self.nConvs(x)
class Res_block(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super(Res_block, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.LeakyReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
# self.fca = FCA_Layer(out_channels)
if stride != 1 or out_channels != in_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
nn.BatchNorm2d(out_channels))
else:
self.shortcut = None
def forward(self, x):
residual = x
if self.shortcut is not None:
residual = self.shortcut(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = self.relu(out)
return out
class SCTransNet(nn.Module):
def __init__(self, config, n_channels=1, n_classes=1, img_size=256, vis=False, mode='train', deepsuper=True):
super().__init__()
self.vis = vis
self.deepsuper = deepsuper
print('Deep-Supervision:', deepsuper)
self.mode = mode
self.n_channels = n_channels
self.n_classes = n_classes
in_channels = config.base_channel # basic channel 64
block = Res_block
self.pool = nn.MaxPool2d(2, 2)
self.inc = self._make_layer(block, n_channels, in_channels)
self.down_encoder1 = self._make_layer(block, in_channels, in_channels * 2, 1) # 64 128
self.down_encoder2 = self._make_layer(block, in_channels * 2, in_channels * 4, 1) # 64 128
self.down_encoder3 = self._make_layer(block, in_channels * 4, in_channels * 8, 1) # 64 128
self.down_encoder4 = self._make_layer(block, in_channels * 8, in_channels * 8, 1) # 64 128
self.mtc = ChannelTransformer(config, vis, img_size,
channel_num=[in_channels, in_channels * 2, in_channels * 4, in_channels * 8],
patchSize=config.patch_sizes)
self.up_decoder4 = UpBlock_attention(in_channels * 16, in_channels * 4, nb_Conv=2)
self.up_decoder3 = UpBlock_attention(in_channels * 8, in_channels * 2, nb_Conv=2)
self.up_decoder2 = UpBlock_attention(in_channels * 4, in_channels, nb_Conv=2)
self.up_decoder1 = UpBlock_attention(in_channels * 2, in_channels, nb_Conv=2)
self.outc = nn.Conv2d(in_channels, n_classes, kernel_size=(1, 1), stride=(1, 1))
if self.deepsuper:
self.gt_conv5 = nn.Sequential(nn.Conv2d(in_channels * 8, 1, 1))
self.gt_conv4 = nn.Sequential(nn.Conv2d(in_channels * 4, 1, 1))
self.gt_conv3 = nn.Sequential(nn.Conv2d(in_channels * 2, 1, 1))
self.gt_conv2 = nn.Sequential(nn.Conv2d(in_channels * 1, 1, 1))
self.outconv = nn.Conv2d(5 * 1, 1, 1)
def _make_layer(self, block, input_channels, output_channels, num_blocks=1):
layers = []
layers.append(block(input_channels, output_channels))
for i in range(num_blocks - 1):
layers.append(block(output_channels, output_channels))
return nn.Sequential(*layers)
def forward(self, x):
x1 = self.inc(x) # 64 224 224
x2 = self.down_encoder1(self.pool(x1)) # 128 112 112
x3 = self.down_encoder2(self.pool(x2)) # 256 56 56
x4 = self.down_encoder3(self.pool(x3)) # 512 28 28
d5 = self.down_encoder4(self.pool(x4)) # 512 14 14
# CCT
f1 = x1
f2 = x2
f3 = x3
f4 = x4
# CCT
x1, x2, x3, x4, att_weights = self.mtc(x1, x2, x3, x4)
x1 = x1 + f1
x2 = x2 + f2
x3 = x3 + f3
x4 = x4 + f4
# Feature fusion
d4 = self.up_decoder4(d5, x4)
d3 = self.up_decoder3(d4, x3)
d2 = self.up_decoder2(d3, x2)
out = self.outc(self.up_decoder1(d2, x1))
# deep supervision
if self.deepsuper:
gt_5 = self.gt_conv5(d5)
gt_4 = self.gt_conv4(d4)
gt_3 = self.gt_conv3(d3)
gt_2 = self.gt_conv2(d2)
# 原始深监督
gt5 = F.interpolate(gt_5, scale_factor=16, mode='bilinear', align_corners=True)
gt4 = F.interpolate(gt_4, scale_factor=8, mode='bilinear', align_corners=True)
gt3 = F.interpolate(gt_3, scale_factor=4, mode='bilinear', align_corners=True)
gt2 = F.interpolate(gt_2, scale_factor=2, mode='bilinear', align_corners=True)
d0 = self.outconv(torch.cat((gt2, gt3, gt4, gt5, out), 1))
if self.mode == 'train':
return (torch.sigmoid(gt5), torch.sigmoid(gt4), torch.sigmoid(gt3), torch.sigmoid(gt2), torch.sigmoid(d0), torch.sigmoid(out))
else:
return torch.sigmoid(out)
else:
return torch.sigmoid(out)
if __name__ == '__main__':
config_vit = get_CTranS_config()
model = SCTransNet(config_vit, mode='train', deepsuper=True)
model = model
inputs = torch.rand(1, 1, 256, 256)
output = model(inputs)
flops, params = profile(model, (inputs,))
print("-" * 50)
print('FLOPs = ' + str(flops / 1000 ** 3) + ' G')
print('Params = ' + str(params / 1000 ** 2) + ' M')
================================================
FILE: test.py
================================================
import argparse
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import threading
from dataset import *
import time
from collections import OrderedDict
from model.SCTransNet import SCTransNet as SCTransNet
# from loss import *
import model.Config as config
import numpy as np
import torch
from skimage import measure
def cal_tp_pos_fp_neg(output, target, nclass, score_thresh):
predict = (output > score_thresh).float()
if len(target.shape) == 3:
print('????') # 加一个维度 使得target与 output的size一致
target = target.unsqueeze(dim=0)
# target = np.expand_dims(target.float(), axis=1)
target.to('cuda', torch.float)
elif len(target.shape) == 4:
target = target.float()
else:
raise ValueError("Unknown target dimension")
# 现在predict中高于阈值的部分为全1矩阵 target是GT
intersection = predict * ((predict == target).float())
tp = intersection.sum() # 对的预测为对的
fp = (predict * ((predict != target).float())).sum() # 错的预测为对的 虚警像素数
tn = ((1 - predict) * ((predict == target).float())).sum() # 错的预测为错的
fn = (((predict != target).float()) * (1 - predict)).sum() # 对的预测为错的
pos = tp + fn # 标签中 阳性的个数
neg = fp + tn # 标签中 阴性的个数
class_pos = tp + fp # 检测出的个数
return tp, pos, fp, neg, class_pos
class SamplewiseSigmoidMetric(object):
"""Computes pixAcc and mIoU metric scores
"""
def __init__(self, nclass, score_thresh=0.5):
self.nclass = nclass
self.score_thresh = score_thresh
self.lock = threading.Lock()
self.reset()
def update(self, preds, labels):
"""Updates the internal evaluation result.
Parameters
----------
labels : 'NDArray' or list of `NDArray`
The labels of the data.
preds : 'NDArray' or list of `NDArray`
Predicted values.
"""
def evaluate_worker(self, label, pred):
inter_arr, union_arr = batch_intersection_union_n(
pred, label, self.nclass, self.score_thresh)
with self.lock:
self.total_inter = np.append(self.total_inter, inter_arr)
self.total_union = np.append(self.total_union, union_arr)
if isinstance(preds, torch.Tensor):
evaluate_worker(self, labels, preds)
elif isinstance(preds, (list, tuple)):
threads = [threading.Thread(target=evaluate_worker,
args=(self, label, pred),
)
for (label, pred) in zip(labels, preds)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
def get(self):
"""Gets the current evaluation result.
Returns
-------
metrics : tuple of float
pixAcc and mIoU
"""
IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
nIoU = IoU.mean()
return nIoU
def reset(self):
"""Resets the internal evaluation result to initial state."""
self.total_inter = np.array([])
self.total_union = np.array([])
self.total_correct = np.array([])
self.total_label = np.array([])
def batch_intersection_union_n(output, target, nclass, score_thresh):
"""nIoU"""
mini = 1
maxi = 1 # nclass
nbins = 1 # nclass
outputnp = output.detach().cpu().numpy()
# outputsig = F.sigmoid(output).detach().cpu().numpy()
# outputsig = nd.sigmoid(output).asnumpy()
predict = (outputnp > 0.5).astype('int64')
# predict = predict.detach().cpu().numpy()
# predict = (output.asnumpy() > 0).astype('int64') # P
if len(target.shape) == 3:
target = nd.expand_dims(target, axis=1).asnumpy().astype('int64') # T
elif len(target.shape) == 4:
target = target.cpu().numpy().astype('int64') # T
else:
raise ValueError("Unknown target dimension")
intersection = predict * (predict == target) # TP 交集
num_sample = intersection.shape[0]
area_inter_arr = np.zeros(num_sample)
area_pred_arr = np.zeros(num_sample)
area_lab_arr = np.zeros(num_sample)
area_union_arr = np.zeros(num_sample)
for b in range(num_sample):
# areas of intersection and union
area_inter, _ = np.histogram(intersection[b], bins=nbins, range=(mini, maxi))
area_inter_arr[b] = area_inter
area_pred, _ = np.histogram(predict[b], bins=nbins, range=(mini, maxi))
area_pred_arr[b] = area_pred
area_lab, _ = np.histogram(target[b], bins=nbins, range=(mini, maxi))
area_lab_arr[b] = area_lab
area_union = area_pred + area_lab - area_inter
area_union_arr[b] = area_union
assert (area_inter <= area_union).all(), \
"Intersection area should be smaller than Union area"
return area_inter_arr, area_union_arr
class ROCMetric05():
"""Computes pixAcc and mIoU metric scores
"""
def __init__(self, nclass, bins):
# bin的意义实际上是确定ROC曲线上的threshold取多少个离散值
# nclass :有几个类别 红外弱小目标检测只有一个类别
super(ROCMetric05, self).__init__()
self.nclass = nclass
self.bins = bins
self.tp_arr = np.zeros(self.bins + 1)
self.pos_arr = np.zeros(self.bins + 1)
self.fp_arr = np.zeros(self.bins + 1)
self.neg_arr = np.zeros(self.bins + 1)
self.class_pos = np.zeros(self.bins + 1)
# self.reset()
# 网络输入的结果和标签 计算两者之前的东西
def update(self, preds, labels):
for iBin in range(self.bins + 1):
# score_thresh = (iBin + 0.0) / self.bins
score_thresh = (0.0 + iBin) / self.bins
# print(iBin, "-th, score_thresh: ", score_thresh)
i_tp, i_pos, i_fp, i_neg, i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass, score_thresh)
self.tp_arr[iBin] += i_tp
self.pos_arr[iBin] += i_pos
self.fp_arr[iBin] += i_fp # 虚警像素数
self.neg_arr[iBin] += i_neg
self.class_pos[iBin] += i_class_pos
def get(self):
tp_rates = self.tp_arr / (self.pos_arr + 0.001) # tp_rates = recall = TP/(TP+FN)
fp_rates = self.fp_arr / (self.neg_arr + 0.001) # fp_rates = FP/(FP+TN)
FP = self.fp_arr / (self.neg_arr + self.pos_arr)
recall = self.tp_arr / (self.pos_arr + 0.001) # recall = TP/(TP+FN)
precision = self.tp_arr / (self.class_pos + 0.001) # precision = TP/(TP+FP)
f1_score = (2.0 * recall[5] * precision[5]) / (recall[5] + precision[5] + 0.00001)
return tp_rates, fp_rates, recall, precision, FP, f1_score
def reset(self):
self.tp_arr = np.zeros([11])
self.pos_arr = np.zeros([11])
self.fp_arr = np.zeros([11])
self.neg_arr = np.zeros([11])
self.class_pos = np.zeros([11])
class mIoU():
def __init__(self):
super(mIoU, self).__init__()
self.reset()
def update(self, preds, labels):
correct, labeled = batch_pix_accuracy(preds, labels) # labeled: GT中目标的像素数目 correct:预测正确的像素数
inter, union = batch_intersection_union(preds, labels)
self.total_correct += correct
self.total_label += labeled
self.total_inter += inter
self.total_union += union
def get(self):
pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
mIoU = IoU.mean()
return float(pixAcc), mIoU
def reset(self):
self.total_inter = 0
self.total_union = 0
self.total_correct = 0
self.total_label = 0
class PDFA():
def __init__(self, ):
super(PDFA, self).__init__()
self.image_area_total = []
self.image_area_match = []
self.dismatch_pixel = 0
self.all_pixel = 0
self.PD = 0
self.target = 0
def update(self, preds, labels, size):
predits = np.array((preds).cpu()).astype('int64')
labelss = np.array((labels).cpu()).astype('int64')
image = measure.label(predits, connectivity=2)
coord_image = measure.regionprops(image)
label = measure.label(labelss, connectivity=2)
coord_label = measure.regionprops(label)
self.target += len(coord_label)
self.image_area_total = []
self.image_area_match = []
self.distance_match = []
self.dismatch = []
for K in range(len(coord_image)):
area_image = np.array(coord_image[K].area)
self.image_area_total.append(area_image)
for i in range(len(coord_label)):
centroid_label = np.array(list(coord_label[i].centroid))
for m in range(len(coord_image)):
centroid_image = np.array(list(coord_image[m].centroid))
distance = np.linalg.norm(centroid_image - centroid_label)
area_image = np.array(coord_image[m].area)
if distance < 3:
self.distance_match.append(distance)
self.image_area_match.append(area_image)
del coord_image[m]
break
self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match]
self.dismatch_pixel += np.sum(self.dismatch)
self.all_pixel += size[0] * size[1]
self.PD += len(self.distance_match)
def get(self):
Final_FA = self.dismatch_pixel / self.all_pixel
Final_PD = self.PD / self.target
return Final_PD, float(Final_FA.cpu().detach().numpy())
def reset(self):
self.FA = np.zeros([self.bins + 1])
self.PD = np.zeros([self.bins + 1])
def batch_pix_accuracy(output, target):
if len(target.shape) == 3:
target = np.expand_dims(target.float(), axis=1)
elif len(target.shape) == 4:
target = target.float()
else:
raise ValueError("Unknown target dimension")
assert output.shape == target.shape, "Predict and Label Shape Don't Match"
predict = (output > 0).float()
pixel_labeled = (target > 0).float().sum()
pixel_correct = (((predict == target).float()) * ((target > 0)).float()).sum()
assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
return pixel_correct, pixel_labeled
def batch_intersection_union(output, target):
mini = 1
maxi = 1
nbins = 1
predict = (output > 0).float()
if len(target.shape) == 3:
target = np.expand_dims(target.float(), axis=1)
elif len(target.shape) == 4:
target = target.float()
else:
raise ValueError("Unknown target dimension")
intersection = predict * ((predict == target).float())
area_inter, _ = np.histogram(intersection.cpu(), bins=nbins, range=(mini, maxi))
area_pred, _ = np.histogram(predict.cpu(), bins=nbins, range=(mini, maxi))
area_lab, _ = np.histogram(target.cpu(), bins=nbins, range=(mini, maxi))
area_union = area_pred + area_lab - area_inter
assert (area_inter <= area_union).all(), \
"Error: Intersection area should be smaller than Union area"
return area_inter, area_union
class PD_FA():
def __init__(self, ):
super(PD_FA, self).__init__()
self.image_area_total = []
self.image_area_match = []
self.dismatch_pixel = 0
self.all_pixel = 0
self.PD = 0
self.target = 0
def update(self, preds, labels, size):
predits = np.array((preds).cpu()).astype('int64')
labelss = np.array((labels).cpu()).astype('int64')
image = measure.label(predits, connectivity=2)
coord_image = measure.regionprops(image)
label = measure.label(labelss, connectivity=2)
coord_label = measure.regionprops(label)
self.target += len(coord_label) # 目标总数 直接就搞GT的连通域个数
self.image_area_total = [] # 图像中预测的区域列表
self.image_area_match = []
self.distance_match = []
self.dismatch = []
for K in range(len(coord_image)):
area_image = np.array(coord_image[K].area)
self.image_area_total.append(area_image)
for i in range(len(coord_label)): # image 与 label 之间 根据中心点 进行连通域的确定
centroid_label = np.array(list(coord_label[i].centroid))
for m in range(len(coord_image)):
centroid_image = np.array(list(coord_image[m].centroid))
distance = np.linalg.norm(centroid_image - centroid_label)
area_image = np.array(coord_image[m].area)
if distance < 3:
self.distance_match.append(distance)
self.image_area_match.append(area_image)
del coord_image[m] # 匹配上一个之后就 清除一个
break
self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match] # 在image里面 但是不在label里面
self.dismatch_pixel += np.sum(self.dismatch) # Fa 虚警个数 像素的虚警
# print(self.dismatch_pixel)
self.all_pixel += size[0] * size[1]
self.PD += len(self.distance_match) # 如果中心点之间距离在3一下 就算Pd 所以Pd 是匹配上了的目标的个数
def get(self):
Final_FA = self.dismatch_pixel / self.all_pixel
Final_PD = self.PD / self.target
return Final_PD, float(Final_FA.cpu().detach().numpy())
def reset(self):
self.FA = np.zeros([self.bins + 1])
self.PD = np.zeros([self.bins + 1])
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
parser = argparse.ArgumentParser(description="PyTorch BasicIRSTD test")
parser.add_argument('--ROC_thr', type=int, default=10, help='num')
parser.add_argument("--model_names", default=['SCTrans'], type=list,
help="model_name: 'ACM', 'Ours01', 'DNANet', 'ISNet', 'ACMNet', 'Ours01', 'ISTDU-Net', 'U-Net', 'RISTDnet'")
parser.add_argument("--pth_dirs", default=['SIRST3/SCTransNet_NUAA_NUDT_IRSTD1K.pth.tar'], type=list)
parser.add_argument("--dataset_dir", default=r'D:\05TGARS\Upload\datasets', type=str, help="train_dataset_dir")
parser.add_argument("--dataset_names", default=['NUAA-SIRST', 'NUDT-SIRST', 'IRSTD-1K'], type=list,
help="dataset_name: 'NUAA-SIRST', 'NUDT-SIRST', 'IRSTD-1K', 'SIRST3', 'NUDT-SIRST-Sea'")
parser.add_argument("--img_norm_cfg", default=None, type=dict,
help="specific a img_norm_cfg, default=None (using img_norm_cfg values of each dataset)")
parser.add_argument("--save_img", default=False, type=bool, help="save image of or not")
parser.add_argument("--save_img_dir", type=str, default=r'D:\SCI\01_02_SCI\Result/',
help="path of saved image")
parser.add_argument("--save_log", type=str, default=r'D:\05TGARS\Upload\log/', help="path of saved .pth")
parser.add_argument("--threshold", type=float, default=0.5)
global opt
opt = parser.parse_args()
def test():
test_set = TestSetLoader(opt.dataset_dir, opt.train_dataset_name, opt.test_dataset_name, opt.img_norm_cfg)
test_loader = DataLoader(dataset=test_set, num_workers=1, batch_size=1, shuffle=False)
# *************************固定阈值**********************
# 计算mIOU 完全OK
IOU = mIoU()
# 计算nIOU 完全OK
nIoU_metric = SamplewiseSigmoidMetric(nclass=1, score_thresh=0)
# 计算PD_FA 完全OK
eval_05 = PD_FA()
ROC_05 = ROCMetric05(nclass=1, bins=10)
config_vit = config.get_SCTrans_config()
# CPU
net = SCTransNet(config_vit, mode='test', deepsuper=True)
state_dict = torch.load(opt.pth_dir, map_location='cpu')
# # CUDA
# net = SCTransNet(config_vit, mode='test', deepsuper=True).cuda()
# state_dict = torch.load(opt.pth_dir)
new_state_dict = OrderedDict()
#
for k, v in state_dict['state_dict'].items():
name = k[6:] # remove `module.`,表面从第7个key值字符取到最后一个字符,正好去掉了module.
new_state_dict[name] = v # 新字典的key值对应的value为一一对应的值。
net.load_state_dict(new_state_dict)
net.eval()
tbar = tqdm(test_loader)
with torch.no_grad():
for idx_iter, (img, gt_mask, size, img_dir) in enumerate(tbar):
# img = Variable(img)
# CPU
pred = net.forward(img)
pred = pred[:, :, :size[0], :size[1]]
gt_mask = gt_mask[:, :, :size[0], :size[1]]
# # CUDA:
# pred = net.forward(img).cuda()
# pred = pred[:, :, :size[0], :size[1]].cuda()
# gt_mask = gt_mask[:, :, :size[0], :size[1]].cuda()
# Fix threshold ##########################################################
# IOU
IOU.update((pred > 0.5), gt_mask) # 像素
# nIOU
nIoU_metric.update(pred, gt_mask) # 像素
eval_05.update((pred[0, 0, :, :] > opt.threshold).cpu(), gt_mask[0, 0, :, :], size) # 目标
ROC_05.update(pred, gt_mask)
# save img
if opt.save_img == True:
img_save = transforms.ToPILImage()((pred[0, 0, :, :]).cpu())
if not os.path.exists(opt.save_img_dir + opt.test_dataset_name + '/' + opt.model_name):
os.makedirs(opt.save_img_dir + opt.test_dataset_name + '/' + opt.model_name)
img_save.save(opt.save_img_dir + opt.test_dataset_name + '/' + opt.model_name + '/' + img_dir[0] + '.png')
# 0.5
# IOU OK Good!
pixAcc, mIOU = IOU.get()
# # nIOU OK Good!
nIoU = nIoU_metric.get()
# # Pd Fa
results2 = eval_05.get()
#
# # FP
ture_positive_rate, false_positive_rate, recall, precision, FP, F1_score = ROC_05.get()
print('pixAcc: %.4f| mIoU: %.4f | nIoU: %.4f | Pd: %.4f| Fa: %.4f |F1: %.4f'
% (pixAcc * 100, mIOU * 100, nIoU * 100, results2[0] * 100, results2[1] * 1e+6, F1_score * 100))
if __name__ == '__main__':
opt.f = open(opt.save_log + 'test_' + (time.ctime()).replace(' ', '_').replace(':', '_') + '.txt', 'w')
if opt.pth_dirs == None:
for i in range(len(opt.model_names)):
opt.model_name = opt.model_names[i]
print(opt.model_name)
opt.f.write(opt.model_name + '_400.pth.tar' + '\n')
for dataset_name in opt.dataset_names:
opt.dataset_name = dataset_name
opt.train_dataset_name = opt.dataset_name
opt.test_dataset_name = opt.dataset_name
print(dataset_name)
opt.f.write(opt.dataset_name + '\n')
opt.pth_dir = opt.save_log + opt.dataset_name + '/' + opt.model_name + '_400.pth.tar'
test()
print('\n')
opt.f.write('\n')
opt.f.close()
else:
for model_name in opt.model_names:
for dataset_name in opt.dataset_names:
for pth_dir in opt.pth_dirs:
# if dataset_name in pth_dir and model_name in pth_dir:
opt.test_dataset_name = dataset_name
opt.model_name = model_name
opt.train_dataset_name = pth_dir.split('/')[0]
print(pth_dir)
opt.f.write(pth_dir)
print(opt.test_dataset_name)
opt.f.write(opt.test_dataset_name + '\n')
opt.pth_dir = opt.save_log + pth_dir
test()
print('\n')
opt.f.write('\n')
opt.f.close()
================================================
FILE: train.py
================================================
import argparse
import time
import os
import cv2
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from torch.autograd import Variable
from torch.utils.data import DataLoader
from dataset import *
from metrics import *
from utils import *
import model.Config as config
from torch.utils.tensorboard import SummaryWriter
from model.SCTransNet import SCTransNet as SCTransNet
parser = argparse.ArgumentParser(description="PyTorch BasicIRSTD train")
parser.add_argument("--model_names", default=['SCTransNet'], type=list, help="'ACM', 'ALCNet', 'DNANet', 'ISNet', 'UIUNet', 'RDIAN', 'RISTDnet'")
parser.add_argument("--dataset_names", default=['SIRST3'], type=list)
# SIRST3: NUAA NUDT IRSTD-1K
parser.add_argument("--optimizer_name", default='Adam', type=str, help="optimizer name: AdamW, Adam, Adagrad, SGD")
parser.add_argument("--epochs", default=1000, type=int, help="optimizer name: AdamW, Adam, Adagrad, SGD")
parser.add_argument("--begin_test", default=500, type=int)
parser.add_argument("--every_test", default=1, type=int)
parser.add_argument("--every_save_pth", default=1000, type=int)
parser.add_argument("--every_print", default=10, type=int)
parser.add_argument("--dataset_dir", default=r'./datasets')
parser.add_argument("--batchSize", type=int, default=16, help="Training batch sizse")
parser.add_argument("--patchSize", type=int, default=256, help="Training patch size")
parser.add_argument("--save", default=r'./log', type=str, help="Save path of checkpoints")
parser.add_argument("--log_dir", type=str, default="./otherlogs/SCTransNet", help='path of log files')
parser.add_argument("--img_norm_cfg", default=None, type=dict)
parser.add_argument("--threads", type=int, default=0, help="Number of threads for data loader to use")
parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for test")
parser.add_argument("--seed", type=int, default=42, help="Threshold for test")
parser.add_argument("--resume", default=False, type=list, help="Resume from exisiting checkpoints (default: None)")
global opt
opt = parser.parse_args()
seed_pytorch(opt.seed)
config_vit = config.get_SCTrans_config()
def weights_init_kaiming(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif classname.find('Linear') != -1:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)
def train():
train_set = TrainSetLoader(dataset_dir=opt.dataset_dir, dataset_name=opt.dataset_name, patch_size=opt.patchSize,
img_norm_cfg=opt.img_norm_cfg)
train_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
net = Net(model_name=opt.model_name, mode='train').cuda()
net.apply(weights_init_kaiming)
net.train()
epoch_state = 0
total_loss_list = []
total_loss_epoch = []
if not os.path.exists(opt.log_dir):
os.makedirs(opt.log_dir)
writer = SummaryWriter(opt.log_dir)
if opt.resume:
# for resume_pth in opt.resume:
# if opt.dataset_name in resume_pth and opt.model_name in resume_pth:
ckpt = torch.load('XX\\UCT04_best.pth.tar')
net.load_state_dict(ckpt['state_dict'])
epoch_state = ckpt['epoch']
total_loss_list = ckpt['total_loss']
# for i in range(len(opt.scheduler_settings['step'])):
# opt.scheduler_settings['step'][i] = opt.scheduler_settings['step'][i] - ckpt['epoch']
### Default settings of SCTransNet
if opt.optimizer_name == 'Adam':
opt.optimizer_settings = {'lr': 0.001}
opt.scheduler_name = 'CosineAnnealingLR'
opt.scheduler_settings = {'epochs': opt.epochs, 'eta_min': 1e-5, 'last_epoch': -1}
### Default settings of DNANet
if opt.optimizer_name == 'Adagrad':
opt.optimizer_settings = {'lr': 0.05}
opt.scheduler_name = 'CosineAnnealingLR'
opt.scheduler_settings = {'epochs': opt.epochs, 'min_lr': 1e-5}
### Default settings of EGEUNet
if opt.optimizer_name == 'AdamW':
opt.optimizer_settings = {'lr': 0.001, 'betas': (0.9, 0.999), "eps": 1e-8, "weight_decay": 1e-2,
"amsgrad": False}
opt.scheduler_name = 'CosineAnnealingLR'
opt.scheduler_settings = {'epochs': opt.epochs, 'T_max': 50, 'eta_min': 1e-5, 'last_epoch': -1}
opt.nEpochs = opt.scheduler_settings['epochs']
optimizer, scheduler = get_optimizer(net, opt.optimizer_name, opt.scheduler_name, opt.optimizer_settings,
opt.scheduler_settings)
for idx_epoch in range(epoch_state, opt.nEpochs):
net.train()
results1 = (0, 0)
results2 = (0, 0)
for idx_iter, (img, gt_mask) in enumerate(train_loader):
img, gt_mask = Variable(img).cuda(), Variable(gt_mask).cuda()
if img.shape[0] == 1:
continue
preds = net.forward(img)
loss = net.loss(preds, gt_mask)
total_loss_epoch.append(loss.detach().cpu())
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
if (idx_epoch + 1) % opt.every_print == 0:
total_loss_list.append(float(np.array(total_loss_epoch).mean()))
print(time.ctime()[4:-5] + ' Epoch---%d, total_loss---%f, lr---%f,'
% (idx_epoch + 1, total_loss_list[-1], scheduler.get_last_lr()[0]))
opt.f.write(time.ctime()[4:-5] + ' Epoch---%d, total_loss---%f,\n'
% (idx_epoch + 1, total_loss_list[-1]))
total_loss_epoch = []
# Log the scalar values
writer.add_scalar('loss', total_loss_list[-1], idx_epoch + 1)
writer.add_scalar('lr', scheduler.get_last_lr()[0], idx_epoch + 1)
# 500
if (idx_epoch + 1) >= opt.begin_test and (idx_epoch + 1) % opt.every_test == 0:
test_set = TestSetLoader(opt.dataset_dir, opt.dataset_name, opt.dataset_name, img_norm_cfg=opt.img_norm_cfg)
test_loader = DataLoader(dataset=test_set, num_workers=1, batch_size=1, shuffle=False)
net.eval()
with torch.no_grad():
eval_mIoU = mIoU()
eval_PD_FA = PD_FA()
test_loss = []
for idx_iter, (img, gt_mask, size, _) in enumerate(test_loader):
img = Variable(img).cuda()
pred = net.forward(img)
if isinstance(pred, tuple):
pred = pred[-1]
elif isinstance(pred, list):
pred = pred[-1]
else:
pred = pred
pred = pred[:, :, :size[0], :size[1]]
gt_mask = gt_mask[:, :, :size[0], :size[1]]
# if pred.size() != gt_mask.size():
# print('1111')
loss = net.loss(pred, gt_mask.cuda())
test_loss.append(loss.detach().cpu())
eval_mIoU.update((pred > opt.threshold).cpu(), gt_mask.cpu())
eval_PD_FA.update((pred[0, 0, :, :] > opt.threshold).cpu(), gt_mask[0, 0, :, :], size)
test_loss.append(float(np.array(test_loss).mean()))
results1 = eval_mIoU.get()
results2 = eval_PD_FA.get()
writer.add_scalar('mIOU', results1[-1], idx_epoch + 1)
writer.add_scalar('testloss', test_loss[-1], idx_epoch + 1)
if (idx_epoch + 1) % opt.every_save_pth == 0:
save_pth = opt.save + '/' + opt.dataset_name + '/' + opt.model_name + '_' + str(idx_epoch + 1) + '.pth.tar'
save_checkpoint({
'epoch': idx_epoch + 1,
'state_dict': net.state_dict(),
'total_loss': total_loss_list,
}, save_pth)
test(save_pth)
if idx_epoch == 0:
best_mIOU = results1
best_Pd = results2
if results1[1] > best_mIOU[1]:
best_mIOU = results1
best_Pd = results2
print('------save the best model epoch', opt.model_name,'_%d ------' % (idx_epoch + 1))
opt.f.write("the best model epoch \t" + str(idx_epoch + 1) + '\n')
print("pixAcc, mIoU:\t" + str(best_mIOU))
print("testloss:\t" + str(test_loss[-1]))
print("PD, FA:\t" + str(best_Pd))
opt.f.write("pixAcc, mIoU:\t" + str(best_mIOU) + '\n')
opt.f.write("PD, FA:\t" + str(best_Pd) + '\n')
save_pth = opt.save + '/' + opt.dataset_name + '/' + opt.model_name + '_' + str(idx_epoch + 1) + '_' + 'best' + '.pth.tar'
save_checkpoint({
'epoch': idx_epoch + 1,
'state_dict': net.state_dict(),
'total_loss': total_loss_list,
}, save_pth)
# last epoch
if (idx_epoch + 1) == opt.nEpochs and (idx_epoch + 1) % opt.every_save_pth != 0:
save_pth = opt.save + '/' + opt.dataset_name + '/' + opt.model_name + '_' + str(idx_epoch + 1) + '.pth.tar'
save_checkpoint({
'epoch': idx_epoch + 1,
'state_dict': net.state_dict(),
'total_loss': total_loss_list,
}, save_pth)
test(save_pth)
def test(save_pth):
test_set = TestSetLoader(opt.dataset_dir, opt.dataset_name, opt.dataset_name, img_norm_cfg=opt.img_norm_cfg)
test_loader = DataLoader(dataset=test_set, num_workers=1, batch_size=1, shuffle=False)
net = Net(model_name=opt.model_name, mode='test').cuda()
ckpt = torch.load(save_pth)
net.load_state_dict(ckpt['state_dict'])
net.eval()
with torch.no_grad():
eval_mIoU = mIoU()
eval_PD_FA = PD_FA()
test_loss_a = []
for idx_iter, (img, gt_mask, size, _) in enumerate(test_loader):
img = Variable(img).cuda()
pred = net.forward(img)
if pred.size() != gt_mask.size():
print('1111')
pred = pred[:, :, :size[0], :size[1]]
gt_mask = gt_mask[:, :, :size[0], :size[1]]
loss = net.loss(pred, gt_mask.cuda())
test_loss_a.append(loss.detach().cpu())
eval_mIoU.update((pred > opt.threshold).cpu(), gt_mask.cpu())
eval_PD_FA.update((pred[0, 0, :, :] > opt.threshold).cpu(), gt_mask[0, 0, :, :], size)
test_loss_a.append(float(np.array(test_loss_a).mean()))
results1 = eval_mIoU.get()
results2 = eval_PD_FA.get()
print('== == == == == == == ', opt.model_name, ' == == == == == == ==')
print("pixAcc, mIoU:\t" + str(results1))
print("testloss:\t" + str(test_loss_a[-1]))
print("PD, FA:\t" + str(results2))
opt.f.write("pixAcc, mIoU:\t" + str(results1) + '\n')
opt.f.write("PD, FA:\t" + str(results2) + '\n')
def save_checkpoint(state, save_path):
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
torch.save(state, save_path)
return save_path
class Net(nn.Module):
def __init__(self, model_name, mode):
super(Net, self).__init__()
self.model_name = model_name
# ************************************************loss*************************************************#
self.cal_loss = nn.BCELoss(size_average=True)
if model_name == 'SCTransNet':
if mode == 'train':
self.model = SCTransNet(config_vit, mode='train', deepsuper=True)
else:
self.model = SCTransNet(config_vit, mode='test', deepsuper=True)
def forward(self, img):
return self.model(img)
def loss(self, preds, gt_masks):
if isinstance(preds, list):
loss_total = 0
for i in range(len(preds)):
pred = preds[i]
gt_mask = gt_masks[i]
loss = self.cal_loss(pred, gt_mask)
loss_total = loss_total + loss
return loss_total / len(preds)
elif isinstance(preds, tuple):
a = []
for i in range(len(preds)):
pred = preds[i]
loss = self.cal_loss(pred, gt_masks)
a.append(loss)
loss_total = a[0] + a[1] + a[2] + a[3] + a[4] + a[5]
return loss_total
else:
loss = self.cal_loss(preds, gt_masks)
return loss
if __name__ == '__main__':
for dataset_name in opt.dataset_names:
opt.dataset_name = dataset_name
for model_name in opt.model_names:
opt.model_name = model_name
if not os.path.exists(opt.save):
os.makedirs(opt.save)
opt.f = open(opt.save + '/' + opt.dataset_name + '_' + opt.model_name + '_' + (time.ctime()).replace(' ',
'_').replace(
':', '_') + '.txt', 'w')
print(opt.dataset_name + '\t' + opt.model_name)
train()
print('\n')
opt.f.close()
================================================
FILE: utils.py
================================================
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data.dataset import Dataset
import random
import matplotlib.pyplot as plt
import cv2
import numpy as np
import os
import math
import torch.nn as nn
from skimage import measure
from warmup_scheduler import GradualWarmupScheduler
import torch.nn.functional as F
import os
from torch.nn import init
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
def seed_pytorch(seed=42):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def weights_init_xavier(m):
classname = m.__class__.__name__
if classname.find('Conv2d') != -1 and classname.find('SplAtConv2d') == -1:
init.xavier_normal(m.weight.data)
# def weights_init_xavier(m):
# classname = m.__class__.__name__
# if classname.find('Conv2d') != -1:
# # init.kaiming_normal_(m.weight.data,a=0, mode='fan_in', nonlinearity='leaky_relu')
# init.xavier_normal(m.weight.data)
def weights_init_kaiming(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif classname.find('Linear') != -1:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)
class Get_gradient_nopadding(nn.Module):
def __init__(self):
super(Get_gradient_nopadding, self).__init__()
kernel_v = [[0, -1, 0],
[0, 0, 0],
[0, 1, 0]]
kernel_h = [[0, 0, 0],
[-1, 0, 1],
[0, 0, 0]]
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False).cuda()
self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False).cuda()
def forward(self, x):
x0 = x[:, 0]
x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=1)
x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=1)
x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)
return x0
def random_crop(img, mask, patch_size, pos_prob=None):
h, w = img.shape
if min(h, w) < patch_size:
img = np.pad(img, ((0, max(h, patch_size) - h), (0, max(w, patch_size) - w)),
mode='constant') # 将不足 256的一边填充至256
mask = np.pad(mask, ((0, max(h, patch_size) - h), (0, max(w, patch_size) - w)),
mode='constant') # label 与image 进行相同的变换
h, w = img.shape
while 1:
h_start = random.randint(0, h - patch_size)
h_end = h_start + patch_size
w_start = random.randint(0, w - patch_size)
w_end = w_start + patch_size
img_patch = img[h_start:h_end, w_start:w_end]
mask_patch = mask[h_start:h_end, w_start:w_end]
if pos_prob == None or random.random() > pos_prob:
break
elif mask_patch.sum() > 0:
break
return img_patch, mask_patch
def Normalized(img, img_norm_cfg):
return (img - img_norm_cfg['mean']) / img_norm_cfg['std']
def Denormalization(img, img_norm_cfg):
return img * img_norm_cfg['std'] + img_norm_cfg['mean']
def get_img_norm_cfg(dataset_name, dataset_dir):
if dataset_name == 'NUAA-SIRST':
img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
elif dataset_name == 'NUDT-SIRST':
img_norm_cfg = dict(mean=107.80905151367188, std=33.02274703979492)
elif dataset_name == 'IRSTD-1K':
img_norm_cfg = dict(mean=87.4661865234375, std=39.71953201293945)
elif dataset_name == 'SIRST2':
img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
elif dataset_name == 'SIRST3':
img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
elif dataset_name == 'NUDT-SIRST-Sea':
img_norm_cfg = dict(mean=43.62403869628906, std=18.91838264465332)
elif dataset_name == 'SIRST4':
img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
elif dataset_name == 'SIRST5':
img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
elif dataset_name == 'SIRST6':
img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
elif dataset_name == 'SIRST7':
img_norm_cfg = dict(mean=101.06385040283203, std=34.619606018066406)
elif dataset_name == 'IRDST-real':
img_norm_cfg = {'mean': 101.54053497314453, 'std': 56.49856185913086}
else:
with open(dataset_dir + '/' + dataset_name + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:
train_list = f.read().splitlines()
with open(dataset_dir + '/' + dataset_name + '/img_idx/test_' + dataset_name + '.txt', 'r') as f:
test_list = f.read().splitlines()
img_list = train_list + test_list
img_dir = dataset_dir + '/' + dataset_name + '/images/'
mean_list = []
std_list = []
for img_pth in img_list:
try:
img = Image.open((img_dir + img_pth).replace('//', '/') + '.png').convert('I')
except:
try:
img = Image.open((img_dir + img_pth).replace('//', '/') + '.jpg').convert('I')
except:
img = Image.open((img_dir + img_pth).replace('//', '/') + '.bmp').convert('I')
img = np.array(img, dtype=np.float32)
mean_list.append(img.mean())
std_list.append(img.std())
img_norm_cfg = dict(mean=float(np.array(mean_list).mean()), std=float(np.array(std_list).mean()))
return img_norm_cfg
def get_optimizer(net, optimizer_name, scheduler_name, optimizer_settings, scheduler_settings):
if optimizer_name == 'Adam':
optimizer = torch.optim.Adam(net.parameters(), lr=optimizer_settings['lr'])
if optimizer_name == 'Adamweight':
optimizer = torch.optim.Adam(net.parameters(), lr=optimizer_settings['lr'], weight_decay=1e-3)
elif optimizer_name == 'Adagrad':
optimizer = torch.optim.Adagrad(net.parameters(), lr=optimizer_settings['lr'])
elif optimizer_name == 'SGD':
optimizer = torch.optim.SGD(net.parameters(), lr=optimizer_settings['lr'],
momentum=0.9,
weight_decay=scheduler_settings['weight_decay'])
# elif optimizer_name == 'AdamW':
# optimizer = torch.optim.AdamW(net.parameters(), lr=optimizer_settings['lr'], betas=optimizer_settings['betas'],
# eps=optimizer_settings['eps'], weight_decay=optimizer_settings['weight_decay'],
# amsgrad=optimizer_settings['amsgrad'])
if scheduler_name == 'MultiStepLR':
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=scheduler_settings['step'],
gamma=scheduler_settings['gamma'])
# elif scheduler_name == 'DNACosineAnnealingLR':
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'],
# eta_min=scheduler_settings['eta_min'])
elif scheduler_name == 'CosineAnnealingLR':
warmup_epochs = 10
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'] - warmup_epochs,
eta_min=scheduler_settings['eta_min'])
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,
after_scheduler=scheduler_cosine)
elif scheduler_name == 'CosineAnnealingLRw50':
warmup_epochs = 50
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'] - warmup_epochs,
eta_min=scheduler_settings['eta_min'])
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,
after_scheduler=scheduler_cosine)
elif scheduler_name == 'CosineAnnealingLRw0':
# warmup_epochs = 0
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'], eta_min=scheduler_settings['eta_min'])
# scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'] - warmup_epochs,
# eta_min=1e-5)
# scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,
# after_scheduler=scheduler_cosine)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['T_max'],
# eta_min=scheduler_settings['eta_min'],
# last_epoch=scheduler_settings['last_epoch'])
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'], eta_min=scheduler_settings['eta_min'])
return optimizer, scheduler
def PadImg(img, times=32):
h, w = img.shape
if not h % times == 0:
img = np.pad(img, ((0, (h // times + 1) * times - h), (0, 0)), mode='constant')
if not w % times == 0:
img = np.pad(img, ((0, 0), (0, (w // times + 1) * times - w)), mode='constant')
return img
================================================
FILE: warmup_scheduler.py
================================================
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
class GradualWarmupScheduler(_LRScheduler):
""" Gradually warm-up(increasing) learning rate in optimizer.
Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
在optimizer中会设置一个基础学习率base lr,
当multiplier>1时,预热机制会在total_epoch内把学习率从base lr逐渐增加到multiplier*base lr,再接着开始正常的scheduler
当multiplier==1.0时,预热机制会在total_epoch内把学习率从0逐渐增加到base lr,再接着开始正常的scheduler
Args:
optimizer (Optimizer): Wrapped optimizer.
multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
total_epoch: target learning rate is reached at total_epoch, gradually
after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
"""
def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
self.multiplier = multiplier
if self.multiplier < 1.:
raise ValueError('multiplier should be greater thant or equal to 1.')
self.total_epoch = total_epoch
self.after_scheduler = after_scheduler
self.finished = False
super(GradualWarmupScheduler, self).__init__(optimizer)
def get_lr(self):
if self.last_epoch > self.total_epoch:
if self.after_scheduler and (not self.finished):
self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
self.finished = True
# !这是很关键的一个环节,需要直接返回新的base-lr
return [base_lr for base_lr in self.after_scheduler.base_lrs]
if self.multiplier == 1.0:
return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
else:
return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
def step_ReduceLROnPlateau(self, metrics, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
print('warmuping...')
if self.last_epoch <= self.total_epoch:
warmup_lr=None
if self.multiplier == 1.0:
warmup_lr = [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
else:
warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
param_group['lr'] = lr
else:
if epoch is None:
self.after_scheduler.step(metrics, None)
else:
self.after_scheduler.step(metrics,epoch - self.total_epoch)
def step(self, epoch=None, metrics=None):
if type(self.after_scheduler) != ReduceLROnPlateau:
if self.finished and self.after_scheduler:
if epoch is None:
self.after_scheduler.step(None)
else:
self.after_scheduler.step(epoch - self.total_epoch)
self._last_lr = self.after_scheduler.get_last_lr()
else:
return super(GradualWarmupScheduler, self).step(epoch)
else:
self.step_ReduceLROnPlateau(metrics, epoch)
gitextract_03y9khf8/ ├── README.md ├── dataset.py ├── datasets/ │ └── SIRST3/ │ └── img_idx/ │ ├── test_SIRST3.txt │ └── train_SIRST3.txt ├── metrics.py ├── model/ │ ├── Config.py │ └── SCTransNet.py ├── test.py ├── train.py ├── utils.py └── warmup_scheduler.py
SYMBOL INDEX (158 symbols across 8 files)
FILE: dataset.py
class TrainSetLoader (line 8) | class TrainSetLoader(Dataset):
method __init__ (line 9) | def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg...
method __getitem__ (line 22) | def __getitem__(self, idx):
method __len__ (line 54) | def __len__(self):
class TrainSetLoader02 (line 57) | class TrainSetLoader02(Dataset):
method __init__ (line 58) | def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg...
method __getitem__ (line 71) | def __getitem__(self, idx):
method __len__ (line 103) | def __len__(self):
class TrainSetLoader03 (line 107) | class TrainSetLoader03(Dataset):
method __init__ (line 108) | def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg...
method __getitem__ (line 121) | def __getitem__(self, idx):
method __len__ (line 153) | def __len__(self):
class TrainSetLoader04 (line 157) | class TrainSetLoader04(Dataset):
method __init__ (line 158) | def __init__(self, dataset_dir, dataset_name, patch_size, img_norm_cfg...
method __getitem__ (line 171) | def __getitem__(self, idx):
method __len__ (line 203) | def __len__(self):
class TestSetLoader (line 206) | class TestSetLoader(Dataset):
method __init__ (line 207) | def __init__(self, dataset_dir, train_dataset_name, test_dataset_name,...
method __getitem__ (line 218) | def __getitem__(self, idx):
method __len__ (line 246) | def __len__(self):
class EvalSetLoader (line 250) | class EvalSetLoader(Dataset):
method __init__ (line 251) | def __init__(self, dataset_dir, mask_pred_dir, test_dataset_name, mode...
method __getitem__ (line 260) | def __getitem__(self, idx):
method __len__ (line 279) | def __len__(self):
class augumentation (line 283) | class augumentation(object):
method __call__ (line 284) | def __call__(self, input, target):
FILE: metrics.py
class ROCMetric (line 6) | class ROCMetric():
method __init__ (line 10) | def __init__(self, nclass, bins):
method update (line 24) | def update(self, preds, labels):
method get (line 36) | def get(self):
method reset (line 45) | def reset(self):
class mIoU (line 53) | class mIoU():
method __init__ (line 55) | def __init__(self):
method update (line 59) | def update(self, preds, labels):
method get (line 67) | def get(self):
method reset (line 73) | def reset(self):
class PD_FA (line 80) | class PD_FA():
method __init__ (line 81) | def __init__(self, ):
method update (line 90) | def update(self, preds, labels, size):
method get (line 127) | def get(self):
method reset (line 132) | def reset(self):
function batch_pix_accuracy (line 137) | def batch_pix_accuracy(output, target):
function batch_intersection_union (line 153) | def batch_intersection_union(output, target):
FILE: model/Config.py
function get_SCTrans_config (line 17) | def get_SCTrans_config():
FILE: model/SCTransNet.py
function get_CTranS_config (line 22) | def get_CTranS_config():
class Channel_Embeddings (line 39) | class Channel_Embeddings(nn.Module):
method __init__ (line 40) | def __init__(self, config, patchsize, img_size, in_channels):
method forward (line 53) | def forward(self, x):
class Reconstruct (line 60) | class Reconstruct(nn.Module):
method __init__ (line 61) | def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
method forward (line 73) | def forward(self, x):
class Attention_org (line 86) | class Attention_org(nn.Module):
method __init__ (line 87) | def __init__(self, config, vis, channel_num):
method forward (line 148) | def forward(self, emb1, emb2, emb3, emb4, emb_all):
function to_3d (line 211) | def to_3d(x):
function to_4d (line 215) | def to_4d(x, h, w):
class BiasFree_LayerNorm (line 219) | class BiasFree_LayerNorm(nn.Module):
method __init__ (line 220) | def __init__(self, normalized_shape):
method forward (line 231) | def forward(self, x):
class WithBias_LayerNorm (line 236) | class WithBias_LayerNorm(nn.Module):
method __init__ (line 237) | def __init__(self, normalized_shape):
method forward (line 249) | def forward(self, x):
class LayerNorm3d (line 255) | class LayerNorm3d(nn.Module):
method __init__ (line 256) | def __init__(self, dim, LayerNorm_type):
method forward (line 263) | def forward(self, x):
class eca_layer_2d (line 267) | class eca_layer_2d(nn.Module):
method __init__ (line 268) | def __init__(self, channel, k_size=3):
method forward (line 279) | def forward(self, x):
class FeedForward (line 287) | class FeedForward(nn.Module):
method __init__ (line 288) | def __init__(self, dim, ffn_expansion_factor, bias):
method forward (line 304) | def forward(self, x):
class Block_ViT (line 315) | class Block_ViT(nn.Module):
method __init__ (line 316) | def __init__(self, config, vis, channel_num):
method forward (line 337) | def forward(self, emb1, emb2, emb3, emb4):
class Encoder (line 380) | class Encoder(nn.Module):
method __init__ (line 381) | def __init__(self, config, vis, channel_num):
method forward (line 393) | def forward(self, emb1, emb2, emb3, emb4):
class ChannelTransformer (line 406) | class ChannelTransformer(nn.Module):
method __init__ (line 407) | def __init__(self, config, vis, img_size, channel_num=[64, 128, 256, 5...
method forward (line 425) | def forward(self, en1, en2, en3, en4):
function get_activation (line 446) | def get_activation(activation_type):
function _make_nConv (line 454) | def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
class CBN (line 463) | class CBN(nn.Module):
method __init__ (line 464) | def __init__(self, in_channels, out_channels, activation='ReLU'):
method forward (line 471) | def forward(self, x):
class DownBlock (line 477) | class DownBlock(nn.Module):
method __init__ (line 478) | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReL...
method forward (line 483) | def forward(self, x):
class Flatten (line 488) | class Flatten(nn.Module):
method forward (line 489) | def forward(self, x):
class CCA (line 493) | class CCA(nn.Module):
method __init__ (line 494) | def __init__(self, F_g, F_x):
method forward (line 504) | def forward(self, g, x):
class UpBlock_attention (line 516) | class UpBlock_attention(nn.Module):
method __init__ (line 517) | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReL...
method forward (line 523) | def forward(self, x, skip_x):
class Res_block (line 530) | class Res_block(nn.Module):
method __init__ (line 531) | def __init__(self, in_channels, out_channels, stride=1):
method forward (line 546) | def forward(self, x):
class SCTransNet (line 561) | class SCTransNet(nn.Module):
method __init__ (line 562) | def __init__(self, config, n_channels=1, n_classes=1, img_size=256, vi...
method _make_layer (line 594) | def _make_layer(self, block, input_channels, output_channels, num_bloc...
method forward (line 601) | def forward(self, x):
FILE: test.py
function cal_tp_pos_fp_neg (line 17) | def cal_tp_pos_fp_neg(output, target, nclass, score_thresh):
class SamplewiseSigmoidMetric (line 44) | class SamplewiseSigmoidMetric(object):
method __init__ (line 48) | def __init__(self, nclass, score_thresh=0.5):
method update (line 54) | def update(self, preds, labels):
method get (line 85) | def get(self):
method reset (line 97) | def reset(self):
function batch_intersection_union_n (line 105) | def batch_intersection_union_n(output, target, nclass, score_thresh):
class ROCMetric05 (line 149) | class ROCMetric05():
method __init__ (line 153) | def __init__(self, nclass, bins):
method update (line 167) | def update(self, preds, labels):
method get (line 179) | def get(self):
method reset (line 189) | def reset(self):
class mIoU (line 197) | class mIoU():
method __init__ (line 199) | def __init__(self):
method update (line 203) | def update(self, preds, labels):
method get (line 211) | def get(self):
method reset (line 217) | def reset(self):
class PDFA (line 224) | class PDFA():
method __init__ (line 225) | def __init__(self, ):
method update (line 234) | def update(self, preds, labels, size):
method get (line 271) | def get(self):
method reset (line 276) | def reset(self):
function batch_pix_accuracy (line 281) | def batch_pix_accuracy(output, target):
function batch_intersection_union (line 297) | def batch_intersection_union(output, target):
class PD_FA (line 320) | class PD_FA():
method __init__ (line 321) | def __init__(self, ):
method update (line 330) | def update(self, preds, labels, size):
method get (line 369) | def get(self):
method reset (line 374) | def reset(self):
function test (line 402) | def test():
FILE: train.py
function weights_init_kaiming (line 45) | def weights_init_kaiming(m):
function train (line 56) | def train():
function test (line 210) | def test(save_pth):
function save_checkpoint (line 246) | def save_checkpoint(state, save_path):
class Net (line 253) | class Net(nn.Module):
method __init__ (line 254) | def __init__(self, model_name, mode):
method forward (line 264) | def forward(self, img):
method loss (line 267) | def loss(self, preds, gt_masks):
FILE: utils.py
function seed_pytorch (line 22) | def seed_pytorch(seed=42):
function weights_init_xavier (line 31) | def weights_init_xavier(m):
function weights_init_kaiming (line 44) | def weights_init_kaiming(m):
class Get_gradient_nopadding (line 55) | class Get_gradient_nopadding(nn.Module):
method __init__ (line 56) | def __init__(self):
method forward (line 69) | def forward(self, x):
function random_crop (line 79) | def random_crop(img, mask, patch_size, pos_prob=None):
function Normalized (line 105) | def Normalized(img, img_norm_cfg):
function Denormalization (line 109) | def Denormalization(img, img_norm_cfg):
function get_img_norm_cfg (line 113) | def get_img_norm_cfg(dataset_name, dataset_dir):
function get_optimizer (line 160) | def get_optimizer(net, optimizer_name, scheduler_name, optimizer_setting...
function PadImg (line 212) | def PadImg(img, times=32):
FILE: warmup_scheduler.py
class GradualWarmupScheduler (line 4) | class GradualWarmupScheduler(_LRScheduler):
method __init__ (line 17) | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler...
method get_lr (line 26) | def get_lr(self):
method step_ReduceLROnPlateau (line 38) | def step_ReduceLROnPlateau(self, metrics, epoch=None):
method step (line 57) | def step(self, epoch=None, metrics=None):
Condensed preview — 11 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (139K chars).
[
{
"path": "README.md",
"chars": 6816,
"preview": "# SCTransNet: Spatial-channel Cross Transformer Network for Infrared Small Target Detection [[Paper]](https://ieeexplore"
},
{
"path": "dataset.py",
"chars": 13821,
"preview": "from utils import *\r\nimport matplotlib.pyplot as plt\r\nimport os\r\n\r\nos.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'\r\n\r\n\r\nclas"
},
{
"path": "datasets/SIRST3/img_idx/test_SIRST3.txt",
"chars": 8980,
"preview": "Misc_338\r\nMisc_379\r\nMisc_422\r\nMisc_73\r\nMisc_321\r\nMisc_162\r\nMisc_372\r\nMisc_185\r\nMisc_420\r\nMisc_143\r\nMisc_137\r\nMisc_224\r\nM"
},
{
"path": "datasets/SIRST3/img_idx/train_SIRST3.txt",
"chars": 13697,
"preview": "Misc_119\r\nMisc_64\r\nMisc_90\r\nMisc_364\r\nMisc_250\r\nMisc_351\r\nMisc_39\r\nMisc_313\r\nMisc_179\r\nMisc_344\r\nMisc_421\r\nMisc_398\r\nMis"
},
{
"path": "metrics.py",
"chars": 6621,
"preview": "import numpy as np\r\nimport torch\r\nfrom skimage import measure\r\n\r\n\r\nclass ROCMetric():\r\n \"\"\"Computes pixAcc and mIoU m"
},
{
"path": "model/Config.py",
"chars": 904,
"preview": "# -*- coding: utf-8 -*-\n# @Author : Shuai Yuan\n# @File : Config.py\n# @Software: PyCharm\n# coding=utf-8\nimport os\nimp"
},
{
"path": "model/SCTransNet.py",
"chars": 28653,
"preview": "# -*- coding: utf-8 -*-\n# -*- coding: utf-8 -*-\n# @Author : Shuai Yuan\n# @File : SCTransNet.py\n# @Software: PyCharm\n"
},
{
"path": "test.py",
"chars": 19981,
"preview": "import argparse\r\nfrom torch.autograd import Variable\r\nfrom torch.utils.data import DataLoader\r\nfrom tqdm import tqdm\r\nim"
},
{
"path": "train.py",
"chars": 13707,
"preview": "import argparse\r\nimport time\r\nimport os\r\nimport cv2\r\nos.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\r\nos.environ[\"CUDA_VI"
},
{
"path": "utils.py",
"chars": 10171,
"preview": "import torch\r\nimport numpy as np\r\nfrom PIL import Image\r\nfrom torchvision import transforms\r\nfrom torch.utils.data.datas"
},
{
"path": "warmup_scheduler.py",
"chars": 3610,
"preview": "from torch.optim.lr_scheduler import _LRScheduler\r\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\r\n\r\nclass Gradu"
}
]
About this extraction
This page contains the full source code of the xdFai/SCTransNet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 11 files (124.0 KB), approximately 37.8k tokens, and a symbol index with 158 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.