Repository: Yanfeng-Zhou/XNet
Branch: main
Commit: 58181894053f
Files: 108
Total size: 1.1 MB
Directory structure:
gitextract_k1uwhl9g/
├── .idea/
│ ├── XNet.iml
│ ├── deployment.xml
│ ├── inspectionProfiles/
│ │ └── profiles_settings.xml
│ ├── misc.xml
│ ├── modules.xml
│ ├── vcs.xml
│ └── workspace.xml
├── LICENSE
├── README.md
├── config/
│ ├── __init__.py
│ ├── augmentation/
│ │ ├── __init__.py
│ │ └── online_aug.py
│ ├── dataset_config/
│ │ ├── __init__.py
│ │ └── dataset_cfg.py
│ ├── eval_config/
│ │ ├── __init__.py
│ │ └── eval.py
│ ├── ramps/
│ │ ├── __init__.py
│ │ └── ramps.py
│ ├── train_test_config/
│ │ ├── __init__.py
│ │ └── train_test_config.py
│ ├── visdom_config/
│ │ ├── __init__.py
│ │ └── visual_visdom.py
│ └── warmup_config/
│ ├── __init__.py
│ └── warmup.py
├── dataload/
│ ├── __init__.py
│ ├── dataset_2d.py
│ └── dataset_3d.py
├── loss/
│ ├── __init__.py
│ └── loss_function.py
├── models/
│ ├── __init__.py
│ ├── getnetwork.py
│ ├── networks_2d/
│ │ ├── __init__.py
│ │ ├── aerial_lanenet.py
│ │ ├── hrnet.py
│ │ ├── mwcnn.py
│ │ ├── resunet.py
│ │ ├── resunet_plusplus.py
│ │ ├── swinunet.py
│ │ ├── u2net.py
│ │ ├── unet.py
│ │ ├── unet_3plus.py
│ │ ├── unet_cct.py
│ │ ├── unet_plusplus.py
│ │ ├── unet_urpc.py
│ │ ├── wavesnet.py
│ │ ├── wds.py
│ │ └── xnet.py
│ └── networks_3d/
│ ├── __init__.py
│ ├── conresnet.py
│ ├── cotr.py
│ ├── dmfnet.py
│ ├── espnet3d.py
│ ├── res_unet3d.py
│ ├── transbts.py
│ ├── unet3d.py
│ ├── unet3d_cct.py
│ ├── unet3d_dtc.py
│ ├── unet3d_urpc.py
│ ├── unetr.py
│ ├── vnet.py
│ ├── vnet_cct.py
│ ├── vnet_dtc.py
│ └── xnet3d.py
├── requirements.txt
├── test.py
├── test_3d.py
├── test_ConResNet.py
├── test_DTC.py
├── test_xnet.py
├── test_xnet3d.py
├── tools/
│ ├── Atrial/
│ │ ├── __init__.py
│ │ ├── postprocess.py
│ │ └── preprocess.py
│ ├── LiTS/
│ │ ├── __init__.py
│ │ ├── postprocess.py
│ │ ├── preprocess.py
│ │ └── split_train_val.py
│ ├── __init__.py
│ ├── eval.py
│ ├── mask2sdf.py
│ ├── res_image_mask.py
│ ├── wavelet2D.py
│ └── wavelet3D.py
├── train_semi_CCT.py
├── train_semi_CCT_3d.py
├── train_semi_CPS.py
├── train_semi_CPS_3d.py
├── train_semi_CT.py
├── train_semi_CT_3d.py
├── train_semi_DTC.py
├── train_semi_EM.py
├── train_semi_EM_3d.py
├── train_semi_MT.py
├── train_semi_MT_3d.py
├── train_semi_UAMT.py
├── train_semi_UAMT_3d.py
├── train_semi_URPC.py
├── train_semi_URPC_3d.py
├── train_semi_XNet.py
├── train_semi_XNet3d.py
├── train_sup.py
├── train_sup_3d.py
├── train_sup_ConResNet.py
├── train_sup_XNet.py
├── train_sup_XNet3d.py
├── train_sup_XNet_sb.py
├── train_sup_alnet.py
└── train_sup_wds.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .idea/XNet.iml
================================================
================================================
FILE: .idea/deployment.xml
================================================
================================================
FILE: .idea/inspectionProfiles/profiles_settings.xml
================================================
================================================
FILE: .idea/misc.xml
================================================
================================================
FILE: .idea/modules.xml
================================================
================================================
FILE: .idea/vcs.xml
================================================
================================================
FILE: .idea/workspace.xml
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2024 Yanfeng Zhou
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# XNet: Wavelet-Based Low and High Frequency Fusion Networks for Fully- and Semi-Supervised Semantic Segmentation of Biomedical Images
This is the official code of [XNet: Wavelet-Based Low and High Frequency Fusion Networks for Fully- and Semi-Supervised Semantic Segmentation of Biomedical Images](https://openaccess.thecvf.com/content/ICCV2023/html/Zhou_XNet_Wavelet-Based_Low_and_High_Frequency_Fusion_Networks_for_Fully-_ICCV_2023_paper.html) (ICCV 2023).
## Overview
Architecture of XNet.
Visualization of dual-branch inputs. (a) Raw image. (b) Wavelet transform results. (c) Low-frequency image. (d) High-frequency image.
Architecture of the LF and HF fusion module.
## Quantitative Comparison
Comparison with fully- and semi-supervised state-of-the-art models on the GlaS and CREMI test sets. Semi-supervised models are based on UNet. DS indicates deep supervision. * indicates lightweight models. ‡ indicates training for 1000 epochs. - indicates training failed. **Red** and **bold** indicate the best and second-best performance.
Comparison with fully- and semi-supervised state-of-the-art models on the LA and LiTS test sets. Because of GPU memory limitations, some semi-supervised models use smaller architectures: ✝ and * indicate models based on a lightweight 3D UNet (half the channels) and VNet, respectively. ‡ indicates training for 1000 epochs. - indicates training failed. **Red** and **bold** indicate the best and second-best performance.
## Qualitative Comparison
Qualitative results on GlaS, CREMI, LA and LiTS. (a) Raw images. (b) Ground truth. (c) MT. (d) Semi-supervised XNet (3D XNet). (e) UNet (3D UNet). (f) Fully supervised XNet (3D XNet). The orange arrows highlight differences among the results.
## Reimplemented Architecture
We have reimplemented several 2D and 3D models for semi-supervised and fully supervised semantic segmentation.
## Requirements
```
albumentations==0.5.2
einops==0.4.1
MedPy==0.4.0
numpy==1.20.2
opencv_python==4.2.0.34
opencv_python_headless==4.5.1.48
Pillow==8.0.0
PyWavelets==1.1.1
scikit_image==0.18.1
scikit_learn==1.0.1
scipy==1.4.1
SimpleITK==2.1.0
timm==0.6.7
torch==1.8.0+cu111
torchio==0.18.53
torchvision==0.9.0+cu111
tqdm==4.65.0
visdom==0.1.8.9
```
## Usage
**Data preparation**
Your dataset directory tree should look like this:
> See [tools/wavelet2D.py](https://github.com/Yanfeng-Zhou/XNet/blob/main/tools/wavelet2D.py) and [tools/wavelet3D.py](https://github.com/Yanfeng-Zhou/XNet/blob/main/tools/wavelet3D.py) for how the **L** (low-frequency) and **H** (high-frequency) images are generated.
```
dataset
├── train_sup_100
│   ├── L
│   │   ├── 1.tif
│   │   ├── 2.tif
│   │   └── ...
│   ├── H
│   │   ├── 1.tif
│   │   ├── 2.tif
│   │   └── ...
│   └── mask
│       ├── 1.tif
│       ├── 2.tif
│       └── ...
├── train_sup_20
│   ├── L
│   ├── H
│   └── mask
├── train_unsup_80
│   ├── L
│   └── H
└── val
    ├── L
    ├── H
    └── mask
```
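The **L** and **H** folders are produced by `tools/wavelet2D.py` and `tools/wavelet3D.py`, whose contents are not reproduced here. The sketch below is a rough illustration only. It assumes a single-level Haar transform in which **L** is reconstructed from the approximation (LL) sub-band and **H** from the three detail sub-bands (LH, HL, HH). `haar_low_high` is a hypothetical name. The repository's tools may use other wavelets (the dataset config lists statistics for Haar, db2, bior1.5, coif1, dmey, ...) and may also rescale intensities.

```python
import numpy as np


def haar_low_high(image):
    """Split a 2D image into low- and high-frequency parts (one Haar level).

    Hypothetical helper: L is the inverse transform of the LL band alone,
    H is the inverse transform of the LH, HL and HH bands alone.
    """
    img = np.asarray(image, dtype=np.float64)
    h, w = img.shape
    if h % 2 or w % 2:
        raise ValueError("height and width must be even")
    # blocks[i, r, j, c] == img[2*i + r, 2*j + c]
    blocks = img.reshape(h // 2, 2, w // 2, 2)
    ll = blocks.sum(axis=(1, 3)) / 2.0  # orthonormal Haar LL coefficients
    # Inverse Haar with all detail coefficients set to zero: every pixel of a
    # 2x2 block becomes ll / 2, i.e. the block mean, upsampled back to h x w.
    low = np.repeat(np.repeat(ll / 2.0, 2, axis=0), 2, axis=1)
    # The inverse DWT is linear, so the details-only reconstruction is the rest.
    high = img - low
    return low, high


low, high = haar_low_high(np.arange(16, dtype=float).reshape(4, 4))
```

For Haar, reconstructing from LL alone gives the 2×2 block mean, so **H** is simply the image minus **L**. This shortcut holds only for Haar. Other wavelets need a proper inverse DWT, for example PyWavelets' `idwt2`.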
**Supervised training**
```
python -m torch.distributed.launch --nproc_per_node=4 train_sup_XNet.py
```
**Semi-supervised training**
```
python -m torch.distributed.launch --nproc_per_node=4 train_semi_XNet.py
```
**Testing**
```
python -m torch.distributed.launch --nproc_per_node=4 test.py
```
## Citation
If our work is useful for your research, please cite our paper:
```
@InProceedings{Zhou_2023_ICCV,
author = {Zhou, Yanfeng and Huang, Jiaxing and Wang, Chenlong and Song, Le and Yang, Ge},
title = {XNet: Wavelet-Based Low and High Frequency Fusion Networks for Fully- and Semi-Supervised Semantic Segmentation of Biomedical Images},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2023},
pages = {21085-21096}
}
```
================================================
FILE: config/__init__.py
================================================
================================================
FILE: config/augmentation/__init__.py
================================================
================================================
FILE: config/augmentation/online_aug.py
================================================
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchio import transforms as T
import torchio as tio


def data_transform_2d():
    data_transforms = {
        'train': A.Compose(
            [
                A.Resize(128, 128, p=1),
                A.Flip(p=0.75),
                A.Transpose(p=0.5),
                A.RandomRotate90(p=1),
            ],
            additional_targets={'image2': 'image', 'mask2': 'mask'}
        ),
        'val': A.Compose(
            [
                A.Resize(128, 128, p=1),
            ],
            additional_targets={'image2': 'image', 'mask2': 'mask'}
        ),
        'test': A.Compose(
            [
                A.Resize(128, 128, p=1),
            ],
            additional_targets={'image2': 'image', 'mask2': 'mask'}
        )
    }
    return data_transforms
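`additional_targets={'image2': 'image', 'mask2': 'mask'}` tells Albumentations to apply the same sampled transform to a second image and a second mask. For XNet, these are presumably the H input next to the L input, so the two branches stay spatially aligned. The NumPy sketch below mimics that idea without Albumentations by sampling one flip/transpose/rot90 and applying it to every array. `paired_geometric_aug` is a hypothetical name, and its sampling only approximates the `'train'` pipeline above.

```python
import numpy as np


def paired_geometric_aug(rng, *arrays):
    """Apply ONE randomly sampled flip/transpose/rot90 to every array.

    Hypothetical helper; axes 0 and 1 are treated as H and W, so both
    HxW and HxWxC arrays are supported.
    """
    ops = []
    if rng.random() < 0.75:  # roughly A.Flip(p=0.75)
        axes = [(0,), (1,), (0, 1)][int(rng.integers(3))]  # vertical, horizontal or both
        ops.append(lambda a: np.flip(a, axis=axes))
    if rng.random() < 0.5:  # roughly A.Transpose(p=0.5)
        ops.append(lambda a: np.swapaxes(a, 0, 1))
    k = int(rng.integers(4))  # roughly A.RandomRotate90(p=1)
    ops.append(lambda a: np.rot90(a, k, axes=(0, 1)))

    out = []
    for a in arrays:
        for op in ops:  # identical sequence of ops for every input
            a = op(a)
        out.append(np.ascontiguousarray(a))
    return out


rng = np.random.default_rng(0)
img = np.arange(12).reshape(3, 4)
low, high, mask = paired_geometric_aug(rng, img, img.copy(), img.copy())
```

Sampling once and reusing the same `ops` list is the whole point. Augmenting L, H and the mask independently would break the pixel correspondence the model is trained on.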
def data_normalize_2d(mean, std):
    data_normalize = A.Compose(
        [
            A.Normalize(mean, std),
            ToTensorV2()
        ],
        additional_targets={'image2': 'image', 'mask2': 'mask'}
    )
    return data_normalize


def data_transform_aerial_lanenet(H, W):
    data_transforms = A.Compose([
        A.Resize(H, W, p=1),
        ToTensorV2()
    ])
    return data_transforms


def data_transform_3d(normalization):
    data_transform = {
        'train': T.Compose([
            T.RandomFlip(),
            T.RandomBiasField(coefficients=(0.12, 0.15), order=2, p=0.2),
            T.OneOf({
                T.RandomNoise(): 0.5,
                T.RandomBlur(std=1): 0.5,
            }, p=0.2),
            T.ZNormalization(masking_method=normalization),
        ]),
        'val': T.Compose([
            # T.CropOrPad(pad_size),
            T.ZNormalization(masking_method=normalization),
            # T.Resize(target_shape=(512, 512, 512), p=1)
        ]),
        'test': T.Compose([
            # T.CropOrPad(pad_size),
            T.ZNormalization(masking_method=normalization),
            # T.Resize(target_shape=(512, 512, 512), p=1)
        ])
    }
    return data_transform
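For LiTS and Atrial, `data_transform_3d` receives `tio.ZNormalization.mean` as `masking_method`. As far as I understand torchio, this masking method selects the voxels brighter than the volume mean, and the normalization mean and standard deviation are computed over those voxels only (mostly tissue rather than background). The NumPy sketch below reproduces that idea. `znorm_above_mean` is a hypothetical name, and torchio's exact standard-deviation estimator may differ slightly.

```python
import numpy as np


def znorm_above_mean(volume):
    """Z-normalize a volume with statistics of the voxels above its mean.

    Hypothetical re-implementation of the idea behind
    T.ZNormalization(masking_method=tio.ZNormalization.mean).
    """
    vol = np.asarray(volume, dtype=np.float64)
    mask = vol > vol.mean()  # "foreground" voxels used for the statistics
    values = vol[mask]
    std = values.std()
    if std == 0:
        raise ValueError("masked voxels have zero variance")
    # The statistics come from the mask, but the whole volume is rescaled.
    return (vol - values.mean()) / std
```

Background voxels end up strongly negative because they are excluded from the statistics. This is usually the intended behavior for CT/MRI volumes with large empty regions.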
================================================
FILE: config/dataset_config/__init__.py
================================================
================================================
FILE: config/dataset_config/dataset_cfg.py
================================================
import numpy as np
import torchio as tio


def dataset_cfg(dataset_name):
    config = {
        'CREMI': {
            'IN_CHANNELS': 1,
            'NUM_CLASSES': 2,
            'MEAN': [0.503902],
            'STD': [0.110739],
            'MEAN_DB2_H': [0.505787],
            'STD_DB2_H': [0.115504],
            'PALETTE': list(np.array([
                [255, 255, 255],
                [0, 0, 0],
            ]).flatten())
        },
        'GlaS': {
            'IN_CHANNELS': 3,
            'NUM_CLASSES': 2,
            'MEAN': [0.787803, 0.512017, 0.784938],
            'STD': [0.428206, 0.507778, 0.426366],
            'MEAN_HAAR_H': [0.528318],
            'STD_HAAR_H': [0.076766],
            'MEAN_HAAR_L': [0.579144],
            'STD_HAAR_L': [0.227451],
            'MEAN_HAAR_HHL': [0.542428],
            'STD_HAAR_HHL': [0.142663],
            'MEAN_HAAR_HLL': [0.569150],
            'STD_HAAR_HLL': [0.220854],
            'MEAN_BIOR1.5_H': [0.525711],
            'STD_BIOR1.5_H': [0.076606],
            'MEAN_BIOR2.4_H': [0.516579],
            'STD_BIOR2.4_H': [0.078798],
            'MEAN_COIF1_H': [0.523858],
            'STD_COIF1_H': [0.081001],
            'MEAN_DB2_H': [0.505234],
            'STD_DB2_H': [0.080919],
            'MEAN_DMEY_H': [0.502698],
            'STD_DMEY_H': [0.078861],
            'PALETTE': list(np.array([
                [0, 0, 0],
                [255, 255, 255],
            ]).flatten())
        },
        'ISIC-2017': {
            'IN_CHANNELS': 3,
            'NUM_CLASSES': 2,
            'MEAN': [0.699002, 0.556046, 0.512134],
            'STD': [0.365650, 0.317347, 0.339400],
            'MEAN_DB2_H': [0.489676],
            'STD_DB2_H': [0.081749],
            'PALETTE': list(np.array([
                [0, 0, 0],
                [255, 255, 255],
            ]).flatten())
        },
        'LiTS': {
            'IN_CHANNELS': 1,
            'NUM_CLASSES': 3,
            'NORMALIZE': tio.ZNormalization.mean,
            'PATCH_SIZE': (112, 112, 32),
            'FORMAT': '.nii',
            'NUM_SAMPLE_TRAIN': 8,
            'NUM_SAMPLE_VAL': 12
        },
        'Atrial': {
            'IN_CHANNELS': 1,
            'NUM_CLASSES': 2,
            'NORMALIZE': tio.ZNormalization.mean,
            'PATCH_SIZE': (96, 96, 80),
            'FORMAT': '.nrrd',
            'NUM_SAMPLE_TRAIN': 4,
            'NUM_SAMPLE_VAL': 8
        },
    }
    return config[dataset_name]
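The `MEAN` and `STD` entries are fractions of the 0–255 range. That matches how `A.Normalize(mean, std)` in `data_normalize_2d` consumes them: with Albumentations' default `max_pixel_value=255.0`, each channel is mapped to `(pixel - mean * 255) / (std * 255)`. The single-channel `*_H` entries would presumably be used the same way for the high-frequency input. The sketch applies the formula by hand to an 8-bit RGB patch using the GlaS statistics. `normalize_like_albumentations` is a hypothetical helper, not part of the repository.

```python
import numpy as np

# GlaS statistics copied from dataset_cfg('GlaS')
GLAS_MEAN = [0.787803, 0.512017, 0.784938]
GLAS_STD = [0.428206, 0.507778, 0.426366]


def normalize_like_albumentations(img_uint8, mean, std, max_pixel_value=255.0):
    """Channel-wise (pixel - mean * max) / (std * max) on an HxWxC image."""
    img = np.asarray(img_uint8, dtype=np.float32)
    mean = np.asarray(mean, dtype=np.float32) * max_pixel_value
    std = np.asarray(std, dtype=np.float32) * max_pixel_value
    return (img - mean) / std  # broadcasts over the last (channel) axis


patch = np.full((4, 4, 3), 200, dtype=np.uint8)
normalized = normalize_like_albumentations(patch, GLAS_MEAN, GLAS_STD)
```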
================================================
FILE: config/eval_config/__init__.py
================================================
================================================
FILE: config/eval_config/eval.py
================================================
import numpy as np
from sklearn.metrics import confusion_matrix
from scipy.spatial.distance import directed_hausdorff
import torch


def evaluate(y_scores, y_true, interval=0.02):
    # Foreground probability of every pixel/voxel, flattened.
    y_scores = torch.softmax(y_scores, dim=1)
    y_scores = y_scores[:, 1, ...].cpu().detach().numpy().flatten()
    y_true = y_true.data.cpu().numpy().flatten()
    # Candidate thresholds in [0, 0.9) with step `interval`.
    thresholds = np.arange(0, 0.9, interval)
    jaccard = np.zeros(len(thresholds))
    dice = np.zeros(len(thresholds))
    y_true = y_true.astype(np.int8)  # astype returns a copy, so keep the result
    for indy in range(len(thresholds)):
        threshold = thresholds[indy]
        y_pred = (y_scores > threshold).astype(np.int8)
        sum_area = (y_pred + y_true)
        tp = float(np.sum(sum_area == 2))  # predicted 1 and labeled 1
        union = np.sum(sum_area == 1)  # disagreements, i.e. FP + FN
        jaccard[indy] = tp / float(union + tp)
        dice[indy] = 2 * tp / float(union + 2 * tp)
    # Report the threshold that maximizes Jaccard, together with its Dice.
    thred_indx = np.argmax(jaccard)
    m_jaccard = jaccard[thred_indx]
    m_dice = dice[thred_indx]
    return thresholds[thred_indx], m_jaccard, m_dice
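In `evaluate`, `sum_area == 2` counts true positives and `sum_area == 1` counts disagreeing pixels (FP + FN). So `jaccard = TP / (TP + FP + FN)` and `dice = 2TP / (2TP + FP + FN)`. Since Dice = 2J / (1 + J) is monotone in J, the threshold that maximizes Jaccard also maximizes Dice. The NumPy-only sketch reproduces the sweep on plain foreground probabilities (no softmax). `sweep_thresholds` is a hypothetical name. Unlike the original, which would raise `ZeroDivisionError`, it treats an empty prediction on an empty label as a perfect match (1.0), a common convention.

```python
import numpy as np


def sweep_thresholds(prob_fg, y_true, interval=0.02):
    """Return (best_threshold, jaccard, dice) over thresholds in [0, 0.9)."""
    prob_fg = np.asarray(prob_fg, dtype=np.float64).ravel()
    y_true = np.asarray(y_true).ravel().astype(bool)
    best = (0.0, -1.0, -1.0)
    for thr in np.arange(0, 0.9, interval):
        y_pred = prob_fg > thr
        tp = np.sum(y_pred & y_true)
        disagree = np.sum(y_pred ^ y_true)  # FP + FN
        if tp + disagree == 0:  # both empty: count as a perfect match
            jc = dc = 1.0
        else:
            jc = tp / (tp + disagree)
            dc = 2 * tp / (2 * tp + disagree)
        if jc > best[1]:  # keep the first maximum, like np.argmax
            best = (float(thr), float(jc), float(dc))
    return best


thr, jc, dc = sweep_thresholds([0.1, 0.4, 0.6, 0.95], [0, 0, 1, 1])
```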
def evaluate_multi(y_scores, y_true):
    y_scores = torch.softmax(y_scores, dim=1)
    y_pred = torch.max(y_scores, 1)[1]
    y_pred = y_pred.data.cpu().numpy().flatten()
    y_true = y_true.data.cpu().numpy().flatten()
    hist = confusion_matrix(y_true, y_pred)
    hist_diag = np.diag(hist)
    hist_sum_0 = hist.sum(axis=0)
    hist_sum_1 = hist.sum(axis=1)
    jaccard = hist_diag / (hist_sum_1 + hist_sum_0 - hist_diag)
    m_jaccard = np.nanmean(jaccard)
    dice = 2 * hist_diag / (hist_sum_1 + hist_sum_0)
    m_dice = np.nanmean(dice)
    return jaccard, m_jaccard, dice, m_dice
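`confusion_matrix(y_true, y_pred)` only includes labels that occur in either array. If a class (for example the LiTS tumor class) is missing from both the labels and the predictions of a batch, the matrix shrinks and per-class positions shift. Passing `labels=list(range(num_classes))` to `confusion_matrix` avoids that. The sketch below shows an equivalent `np.bincount` version with a fixed class count. `per_class_iou_dice` is a hypothetical helper that returns the same four values as `evaluate_multi`, but takes already arg-maxed predictions.

```python
import numpy as np


def per_class_iou_dice(y_pred, y_true, num_classes):
    """Per-class and mean IoU/Dice from integer label maps of any shape."""
    y_pred = np.asarray(y_pred, dtype=np.int64).ravel()
    y_true = np.asarray(y_true, dtype=np.int64).ravel()
    # hist[t, p] = number of pixels with label t predicted as p
    hist = np.bincount(num_classes * y_true + y_pred,
                       minlength=num_classes ** 2).reshape(num_classes, num_classes)
    tp = np.diag(hist).astype(np.float64)
    pred_count = hist.sum(axis=0)  # column sums
    true_count = hist.sum(axis=1)  # row sums
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = tp / (true_count + pred_count - tp)  # NaN for absent classes
        dice = 2 * tp / (true_count + pred_count)
    return iou, np.nanmean(iou), dice, np.nanmean(dice)


iou, miou, dice, mdice = per_class_iou_dice([0, 1, 1, 1], [0, 0, 1, 1], num_classes=3)
```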
================================================
FILE: config/ramps/__init__.py
================================================
================================================
FILE: config/ramps/ramps.py
================================================
import numpy as np


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))


def linear_rampup(current, rampup_length):
    """Linear rampup"""
    assert current >= 0 and rampup_length >= 0
    if current >= rampup_length:
        return 1.0
    else:
        return current / rampup_length


def cosine_rampdown(current, rampdown_length):
    """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
    assert 0 <= current <= rampdown_length
    return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
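In mean-teacher-style semi-supervised training, the unsupervised loss weight is commonly ramped up as `w(epoch) = w_max * sigmoid_rampup(epoch, rampup_length)`, so noisy early pseudo-labels barely contribute. The training scripts are not shown in this section, so how each one uses these ramps is an assumption. The sketch re-implements `sigmoid_rampup` with `math` so it runs standalone. `consistency_weight`, `w_max` and `rampup_length=40` are hypothetical.

```python
import math


def sigmoid_rampup(current, rampup_length):
    """Same formula as above: exp(-5 * (1 - t)^2) with t = current / length, clipped to [0, 1]."""
    if rampup_length == 0:
        return 1.0
    current = min(max(current, 0.0), rampup_length)
    phase = 1.0 - current / rampup_length
    return float(math.exp(-5.0 * phase * phase))


def consistency_weight(epoch, w_max=1.0, rampup_length=40):
    """Hypothetical schedule for the unsupervised loss weight."""
    return w_max * sigmoid_rampup(epoch, rampup_length)


for epoch in (0, 10, 20, 40, 100):
    print(epoch, round(consistency_weight(epoch), 4))
```

The weight starts at `w_max * exp(-5)` (about 0.7% of `w_max`) and reaches `w_max` exactly at `rampup_length`. After that it stays constant.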
================================================
FILE: config/train_test_config/__init__.py
================================================
================================================
FILE: config/train_test_config/train_test_config.py
================================================
import numpy as np
from config.eval_config.eval import evaluate, evaluate_multi
import torch
import os
from PIL import Image
import torchio as tio


def print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus):
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss


def print_train_loss_MT(train_loss_sup_1, train_loss_cps, train_loss, num_batches, print_num, print_num_half, print_num_minus):
    train_epoch_loss_sup1 = train_loss_sup_1 / num_batches['train_sup']
    train_epoch_loss_cps = train_loss_cps / num_batches['train_sup']
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train Sup Loss: {:.4f}'.format(train_epoch_loss_sup1).ljust(print_num_half, ' '), '| Train Unsup Loss: {:.4f}'.format(train_epoch_loss_cps).ljust(print_num_half, ' '), '|')
    print('| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss


def print_train_loss_ConResNet(train_loss_seg, train_loss_res, train_loss, num_batches, print_num, print_num_half, print_num_minus):
    train_epoch_loss_seg = train_loss_seg / num_batches['train_sup']
    train_epoch_loss_res = train_loss_res / num_batches['train_sup']
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train Seg Loss: {:.4f}'.format(train_epoch_loss_seg).ljust(print_num_half, ' '), '| Train Res Loss: {:.4f}'.format(train_epoch_loss_res).ljust(print_num_half, ' '), '|')
    print('| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss_seg, train_epoch_loss_res, train_epoch_loss


def print_train_loss_EM(train_loss_sup_1, train_loss_cps, train_loss, num_batches, print_num, print_num_minus):
    train_epoch_loss_sup1 = train_loss_sup_1 / num_batches['train_sup']
    train_epoch_loss_cps = train_loss_cps / num_batches['train_sup']
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train Sup Loss: {:.4f}'.format(train_epoch_loss_sup1).ljust(print_num_minus, ' '), '|')
    print('| Train Unsup Loss: {:.4f}'.format(train_epoch_loss_cps).ljust(print_num_minus, ' '), '|')
    print('| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss


def print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_cps, train_loss, num_batches, print_num, print_num_half):
    train_epoch_loss_sup1 = train_loss_sup_1 / num_batches['train_sup']
    train_epoch_loss_sup2 = train_loss_sup_2 / num_batches['train_sup']
    train_epoch_loss_cps = train_loss_cps / num_batches['train_sup']
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train Sup Loss 1: {:.4f}'.format(train_epoch_loss_sup1).ljust(print_num_half, ' '), '| Train Sup Loss 2: {:.4f}'.format(train_epoch_loss_sup2).ljust(print_num_half, ' '), '|')
    print('| Train Unsup Loss: {:.4f}'.format(train_epoch_loss_cps).ljust(print_num_half, ' '), '| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_half, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss
def print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus):
    val_epoch_loss = val_loss / num_batches['val']
    print('-' * print_num)
    print('| Val Loss: {:.4f}'.format(val_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return val_epoch_loss


def print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half):
    val_epoch_loss_sup1 = val_loss_sup_1 / num_batches['val']
    val_epoch_loss_sup2 = val_loss_sup_2 / num_batches['val']
    print('-' * print_num)
    print('| Val Sup Loss 1: {:.4f}'.format(val_epoch_loss_sup1).ljust(print_num_half, ' '), '| Val Sup Loss 2: {:.4f}'.format(val_epoch_loss_sup2).ljust(print_num_half, ' '), '|')
    print('-' * print_num)
    return val_epoch_loss_sup1, val_epoch_loss_sup2


def print_val_loss_ConResNet(val_loss_seg, val_loss_res, num_batches, print_num, print_num_half):
    val_epoch_loss_seg = val_loss_seg / num_batches['val']
    val_epoch_loss_res = val_loss_res / num_batches['val']
    print('-' * print_num)
    print('| Val Seg Loss: {:.4f}'.format(val_epoch_loss_seg).ljust(print_num_half, ' '), '| Val Res Loss: {:.4f}'.format(val_epoch_loss_res).ljust(print_num_half, ' '), '|')
    print('-' * print_num)
    return val_epoch_loss_seg, val_epoch_loss_res


def print_train_eval_sup(num_classes, score_list_train, mask_list_train, print_num):
    if num_classes == 2:
        eval_list = evaluate(score_list_train, mask_list_train)
        print('| Train Thr: {:.4f}'.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Train Jc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Train Dc: {:.4f}'.format(eval_list[2]).ljust(print_num, ' '), '|')
        train_m_jc = eval_list[1]
    else:
        eval_list = evaluate_multi(score_list_train, mask_list_train)
        np.set_printoptions(precision=4, suppress=True)
        print('| Train Jc: {}'.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Train Dc: {}'.format(eval_list[2]).ljust(print_num, ' '), '|')
        print('| Train mJc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Train mDc: {:.4f}'.format(eval_list[3]).ljust(print_num, ' '), '|')
        train_m_jc = eval_list[1]
    return eval_list, train_m_jc


def print_train_eval_XNet(num_classes, score_list_train1, score_list_train2, mask_list_train, print_num):
    if num_classes == 2:
        eval_list1 = evaluate(score_list_train1, mask_list_train)
        eval_list2 = evaluate(score_list_train2, mask_list_train)
        print('| Train Thr 1: {:.4f}'.format(eval_list1[0]).ljust(print_num, ' '), '| Train Thr 2: {:.4f}'.format(eval_list2[0]).ljust(print_num, ' '), '|')
        print('| Train Jc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Train Jc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')
        print('| Train Dc 1: {:.4f}'.format(eval_list1[2]).ljust(print_num, ' '), '| Train Dc 2: {:.4f}'.format(eval_list2[2]).ljust(print_num, ' '), '|')
        train_m_jc1 = eval_list1[1]
        train_m_jc2 = eval_list2[1]
    else:
        eval_list1 = evaluate_multi(score_list_train1, mask_list_train)
        eval_list2 = evaluate_multi(score_list_train2, mask_list_train)
        np.set_printoptions(precision=4, suppress=True)
        print('| Train Jc 1: {}'.format(eval_list1[0]).ljust(print_num, ' '), '| Train Jc 2: {}'.format(eval_list2[0]).ljust(print_num, ' '), '|')
        print('| Train Dc 1: {}'.format(eval_list1[2]).ljust(print_num, ' '), '| Train Dc 2: {}'.format(eval_list2[2]).ljust(print_num, ' '), '|')
        print('| Train mJc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Train mJc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')
        print('| Train mDc 1: {:.4f}'.format(eval_list1[3]).ljust(print_num, ' '), '| Train mDc 2: {:.4f}'.format(eval_list2[3]).ljust(print_num, ' '), '|')
        train_m_jc1 = eval_list1[1]
        train_m_jc2 = eval_list2[1]
    return eval_list1, eval_list2, train_m_jc1, train_m_jc2


def print_val_eval_sup(num_classes, score_list_val, mask_list_val, print_num):
    if num_classes == 2:
        eval_list = evaluate(score_list_val, mask_list_val)
        print('| Val Thr: {:.4f}'.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Val Jc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Val Dc: {:.4f}'.format(eval_list[2]).ljust(print_num, ' '), '|')
        val_m_jc = eval_list[1]
    else:
        eval_list = evaluate_multi(score_list_val, mask_list_val)
        np.set_printoptions(precision=4, suppress=True)
        print('| Val Jc: {} '.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Val Dc: {} '.format(eval_list[2]).ljust(print_num, ' '), '|')
        print('| Val mJc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Val mDc: {:.4f}'.format(eval_list[3]).ljust(print_num, ' '), '|')
        val_m_jc = eval_list[1]
    return eval_list, val_m_jc


def print_val_eval(num_classes, score_list_val1, score_list_val2, mask_list_val, print_num):
    if num_classes == 2:
        eval_list1 = evaluate(score_list_val1, mask_list_val)
        eval_list2 = evaluate(score_list_val2, mask_list_val)
        print('| Val Thr 1: {:.4f}'.format(eval_list1[0]).ljust(print_num, ' '), '| Val Thr 2: {:.4f}'.format(eval_list2[0]).ljust(print_num, ' '), '|')
        print('| Val Jc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Val Jc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')
        print('| Val Dc 1: {:.4f}'.format(eval_list1[2]).ljust(print_num, ' '), '| Val Dc 2: {:.4f}'.format(eval_list2[2]).ljust(print_num, ' '), '|')
        val_m_jc1 = eval_list1[1]
        val_m_jc2 = eval_list2[1]
    else:
        eval_list1 = evaluate_multi(score_list_val1, mask_list_val)
        eval_list2 = evaluate_multi(score_list_val2, mask_list_val)
        np.set_printoptions(precision=4, suppress=True)
        print('| Val Jc 1: {} '.format(eval_list1[0]).ljust(print_num, ' '), '| Val Jc 2: {}'.format(eval_list2[0]).ljust(print_num, ' '), '|')
        print('| Val Dc 1: {} '.format(eval_list1[2]).ljust(print_num, ' '), '| Val Dc 2: {}'.format(eval_list2[2]).ljust(print_num, ' '), '|')
        print('| Val mJc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Val mJc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')
        print('| Val mDc 1: {:.4f}'.format(eval_list1[3]).ljust(print_num, ' '), '| Val mDc 2: {:.4f}'.format(eval_list2[3]).ljust(print_num, ' '), '|')
        val_m_jc1 = eval_list1[1]
        val_m_jc2 = eval_list2[1]
    return eval_list1, eval_list2, val_m_jc1, val_m_jc2
def save_val_best_sup_2d(num_classes, best_list, model, score_list_val, name_list_val, eval_list, path_trained_model, path_seg_results, palette, model_name):
    if num_classes == 2:
        if best_list[1] < eval_list[1]:
            best_list = eval_list
            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))
            score_list_val = torch.softmax(score_list_val, dim=1)
            pred_results = score_list_val[:, 1, :, :].cpu().numpy()
            pred_results[pred_results > eval_list[0]] = 1
            pred_results[pred_results <= eval_list[0]] = 0
            assert len(name_list_val) == pred_results.shape[0]
            for i in range(len(name_list_val)):
                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
                color_results.putpalette(palette)
                color_results.save(os.path.join(path_seg_results, name_list_val[i]))
    else:
        if best_list[1] < eval_list[1]:
            best_list = eval_list
            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))
            pred_results = torch.max(score_list_val, 1)[1]
            pred_results = pred_results.cpu().numpy()
            assert len(name_list_val) == pred_results.shape[0]
            for i in range(len(name_list_val)):
                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
                color_results.putpalette(palette)
                color_results.save(os.path.join(path_seg_results, name_list_val[i]))
    return best_list
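For binary tasks, `save_val_best_sup_2d` turns logits into a label map in three steps: softmax, take the foreground channel, and compare it with the Jaccard-optimal threshold `eval_list[0]`. For multi-class tasks it takes the arg-max over the class axis. The NumPy sketch shows both paths on logits of shape `(N, C, H, W)`. `logits_to_labels` is a hypothetical helper, and softmax is re-implemented in NumPy instead of calling `torch.softmax`.

```python
import numpy as np


def softmax(x, axis=1):
    z = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def logits_to_labels(logits, threshold=None):
    """(N, C, H, W) logits -> (N, H, W) uint8 label maps."""
    if logits.shape[1] == 2:
        if threshold is None:
            raise ValueError("binary case needs a threshold (eval_list[0])")
        fg = softmax(logits, axis=1)[:, 1]  # foreground probability
        return (fg > threshold).astype(np.uint8)  # same rule as `> eval_list[0]`
    return logits.argmax(axis=1).astype(np.uint8)  # like torch.max(x, 1)[1]


logits = np.zeros((1, 2, 2, 2))
logits[0, 1, 0, 0] = 3.0
labels = logits_to_labels(logits, threshold=0.5)
```

The resulting `uint8` maps are what the original code wraps in `Image.fromarray(..., mode='P')` and colors with the dataset `PALETTE`.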
def save_val_best_sup_3d(num_classes, best_list, model, score_list_val, mask_list_val, eval_list, path_trained_model, path_seg_results, path_mask_results, model_name, format):
    if num_classes == 2:
        if best_list[1] < eval_list[1]:
            best_list = eval_list
            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))
    else:
        if best_list[1] < eval_list[1]:
            best_list = eval_list
            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))
    return best_list


def save_val_best_2d(num_classes, best_model, best_list, best_result, model1, model2, score_list_val_1, score_list_val_2, name_list_val, eval_list_1, eval_list_2, path_trained_model, path_seg_results, palette):
    if eval_list_1[1] < eval_list_2[1]:
        if best_list[1] < eval_list_2[1]:
            best_model = model2
            best_list = eval_list_2
            best_result = 'Result2'
            torch.save(model2.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result2', best_list[1])))
            if num_classes == 2:
                score_list_val_2 = torch.softmax(score_list_val_2, dim=1)
                pred_results = score_list_val_2[:, 1, ...].cpu().numpy()
                pred_results[pred_results > eval_list_2[0]] = 1
                pred_results[pred_results <= eval_list_2[0]] = 0
            else:
                pred_results = torch.max(score_list_val_2, 1)[1]
                pred_results = pred_results.cpu().numpy()
            assert len(name_list_val) == pred_results.shape[0]
            for i in range(len(name_list_val)):
                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
                color_results.putpalette(palette)
                color_results.save(os.path.join(path_seg_results, name_list_val[i]))
        else:
            best_model = best_model
            best_list = best_list
            best_result = best_result
    else:
        if best_list[1] < eval_list_1[1]:
            best_model = model1
            best_list = eval_list_1
            best_result = 'Result1'
            torch.save(model1.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result1', best_list[1])))
            if num_classes == 2:
                score_list_val_1 = torch.softmax(score_list_val_1, dim=1)
                pred_results = score_list_val_1[:, 1, ...].cpu().numpy()
                pred_results[pred_results > eval_list_1[0]] = 1
                pred_results[pred_results <= eval_list_1[0]] = 0
            else:
                pred_results = torch.max(score_list_val_1, 1)[1]
                pred_results = pred_results.cpu().numpy()
            assert len(name_list_val) == pred_results.shape[0]
            for i in range(len(name_list_val)):
                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
                color_results.putpalette(palette)
                color_results.save(os.path.join(path_seg_results, name_list_val[i]))
        else:
            best_model = best_model
            best_list = best_list
            best_result = best_result
    return best_list, best_model, best_result


def save_val_best_3d(num_classes, best_model, best_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, eval_list_1, eval_list_2, path_trained_model, path_seg_results, path_mask_results, format):
    if eval_list_1[1] < eval_list_2[1]:
        if best_list[1] < eval_list_2[1]:
            best_model = model2
            best_list = eval_list_2
            best_result = 'Result2'
            torch.save(model2.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result2', best_list[1])))
        else:
            best_model = best_model
            best_list = best_list
            best_result = best_result
    else:
        if best_list[1] < eval_list_1[1]:
            best_model = model1
            best_list = eval_list_1
            best_result = 'Result1'
            torch.save(model1.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result1', best_list[1])))
        else:
            best_model = best_model
            best_list = best_list
            best_result = best_result
    return best_list, best_model, best_result
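`save_val_best_2d` and `save_val_best_3d` compare the two XNet branches after each validation round. The branch with the higher Jaccard (`eval_list[1]`) becomes the candidate, with ties going to branch 1. It replaces the stored best only when it strictly beats `best_list[1]`. The pure-Python sketch below isolates that decision so it can be tested without models or files. `pick_best_branch` is a hypothetical name.

```python
def pick_best_branch(best_jc, jc1, jc2):
    """Return (branch_name or None, new_best_jc).

    None means neither branch beat the stored best, so nothing is saved.
    """
    if jc1 < jc2:
        candidate, jc = 'Result2', jc2
    else:  # ties favor branch 1, as in the original
        candidate, jc = 'Result1', jc1
    if best_jc < jc:  # strict improvement required
        return candidate, jc
    return None, best_jc


print(pick_best_branch(0.0, 0.7, 0.8))
```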
def draw_pred_sup(num_classes, mask_train_sup, mask_val, pred_train_sup, outputs_val, train_eval_list, val_eval_list):
mask_image_train_sup = mask_train_sup[0, :, :].data.cpu().numpy()
mask_image_val = mask_val[0, :, :].data.cpu().numpy()
if num_classes == 2:
pred_image_train_sup = pred_train_sup[0, 1, :, :].data.cpu().numpy()
pred_image_train_sup[pred_image_train_sup > train_eval_list[0]] = 1
pred_image_train_sup[pred_image_train_sup <= train_eval_list[0]] = 0
pred_image_val = outputs_val[0, 1, :, :].data.cpu().numpy()
pred_image_val[pred_image_val > val_eval_list[0]] = 1
pred_image_val[pred_image_val <= val_eval_list[0]] = 0
else:
pred_image_train_sup = torch.max(pred_train_sup, 1)[1]
pred_image_train_sup = pred_image_train_sup[0, :, :].cpu().numpy()
pred_image_val = torch.max(outputs_val, 1)[1]
pred_image_val = pred_image_val[0, :, :].cpu().numpy()
return mask_image_train_sup, pred_image_train_sup, mask_image_val, pred_image_val
def draw_pred_XNet(num_classes, mask_train, mask_val, pred_train_sup1, pred_train_sup2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2):
mask_image_train_sup = mask_train[0, :, :].data.cpu().numpy()
mask_image_val = mask_val[0, :, :].data.cpu().numpy()
if num_classes == 2:
pred_image_train_sup1 = pred_train_sup1[0, 1, :, :].data.cpu().numpy()
pred_image_train_sup1[pred_image_train_sup1 > train_eval_list1[0]] = 1
pred_image_train_sup1[pred_image_train_sup1 <= train_eval_list1[0]] = 0
pred_image_train_sup2 = pred_train_sup2[0, 1, :, :].data.cpu().numpy()
pred_image_train_sup2[pred_image_train_sup2 > train_eval_list2[0]] = 1
pred_image_train_sup2[pred_image_train_sup2 <= train_eval_list2[0]] = 0
pred_image_val1 = outputs_val1[0, 1, :, :].data.cpu().numpy()
pred_image_val1[pred_image_val1 > val_eval_list1[0]] = 1
pred_image_val1[pred_image_val1 <= val_eval_list1[0]] = 0
pred_image_val2 = outputs_val2[0, 1, :, :].data.cpu().numpy()
pred_image_val2[pred_image_val2 > val_eval_list2[0]] = 1
pred_image_val2[pred_image_val2 <= val_eval_list2[0]] = 0
else:
pred_image_train_sup1 = torch.max(pred_train_sup1, 1)[1]
pred_image_train_sup1 = pred_image_train_sup1[0, :, :].cpu().numpy()
pred_image_train_sup2 = torch.max(pred_train_sup2, 1)[1]
pred_image_train_sup2 = pred_image_train_sup2[0, :, :].cpu().numpy()
pred_image_val1 = torch.max(outputs_val1, 1)[1]
pred_image_val1 = pred_image_val1[0, :, :].cpu().numpy()
pred_image_val2 = torch.max(outputs_val2, 1)[1]
pred_image_val2 = pred_image_val2[0, :, :].cpu().numpy()
return mask_image_train_sup, pred_image_train_sup1, pred_image_train_sup2, mask_image_val, pred_image_val1, pred_image_val2
def draw_pred_MT(num_classes, mask_train, mask_val, pred_train_sup1, outputs_val1, outputs_val2, train_eval_list1, val_eval_list1, val_eval_list2):
mask_image_train_sup = mask_train[0, :, :].data.cpu().numpy()
mask_image_val = mask_val[0, :, :].data.cpu().numpy()
if num_classes == 2:
pred_image_train_sup1 = pred_train_sup1[0, 1, :, :].data.cpu().numpy()
pred_image_train_sup1[pred_image_train_sup1 > train_eval_list1[0]] = 1
pred_image_train_sup1[pred_image_train_sup1 <= train_eval_list1[0]] = 0
pred_image_val1 = outputs_val1[0, 1, :, :].data.cpu().numpy()
pred_image_val1[pred_image_val1 > val_eval_list1[0]] = 1
pred_image_val1[pred_image_val1 <= val_eval_list1[0]] = 0
pred_image_val2 = outputs_val2[0, 1, :, :].data.cpu().numpy()
pred_image_val2[pred_image_val2 > val_eval_list2[0]] = 1
pred_image_val2[pred_image_val2 <= val_eval_list2[0]] = 0
else:
pred_image_train_sup1 = torch.max(pred_train_sup1, 1)[1]
pred_image_train_sup1 = pred_image_train_sup1[0, :, :].cpu().numpy()
pred_image_val1 = torch.max(outputs_val1, 1)[1]
pred_image_val1 = pred_image_val1[0, :, :].cpu().numpy()
pred_image_val2 = torch.max(outputs_val2, 1)[1]
pred_image_val2 = pred_image_val2[0, :, :].cpu().numpy()
return mask_image_train_sup, pred_image_train_sup1, mask_image_val, pred_image_val1, pred_image_val2
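The `draw_pred_*` helpers above turn network scores into a displayable label map in one of two ways. Binary tasks threshold the foreground channel (index 1) against a tuned value taken from `*_eval_list[0]`. Multi-class tasks take the per-pixel argmax over the class axis.

The sketch below shows that rule on a single sample. It is a NumPy stand-in for the Torch code, and `prediction_to_label_map` is a hypothetical name that does not exist in the repo:

```python
import numpy as np

def prediction_to_label_map(probs, num_classes, threshold=0.5):
    """Convert a (C, H, W) score map for one sample into an (H, W) label map.

    Binary case: foreground channel 1 strictly above `threshold` -> 1, else 0.
    Multi-class case: index of the highest-scoring class at each pixel.
    """
    if num_classes == 2:
        return (probs[1] > threshold).astype(np.uint8)
    return np.argmax(probs, axis=0).astype(np.uint8)

# Two-class scores for a 2x2 image (channel 0 = background, channel 1 = foreground).
probs_2c = np.array([[[0.9, 0.2], [0.6, 0.4]],
                     [[0.1, 0.8], [0.4, 0.6]]])
binary_map = prediction_to_label_map(probs_2c, num_classes=2)
strict_map = prediction_to_label_map(probs_2c, num_classes=2, threshold=0.7)

# Three-class scores for a 1x2 image.
probs_3c = np.array([[[0.2, 0.1]], [[0.5, 0.3]], [[0.3, 0.6]]])
multi_map = prediction_to_label_map(probs_3c, num_classes=3)
```

Raising the threshold from 0.5 to 0.7 drops the 0.6 pixel from the foreground. This is why the repo carries a tuned threshold instead of a fixed 0.5.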
def print_best_sup(num_classes, best_val_list, print_num):
if num_classes == 2:
print('| Best Val Thr: {:.4f}'.format(best_val_list[0]).ljust(print_num, ' '), '|')
print('| Best Val Jc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')
print('| Best Val Dc: {:.4f}'.format(best_val_list[2]).ljust(print_num, ' '), '|')
else:
np.set_printoptions(precision=4, suppress=True)
print('| Best Val Jc: {}'.format(best_val_list[0]).ljust(print_num, ' '), '|')
print('| Best Val Dc: {}'.format(best_val_list[2]).ljust(print_num, ' '), '|')
print('| Best Val mJc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')
print('| Best Val mDc: {:.4f}'.format(best_val_list[3]).ljust(print_num, ' '), '|')
def print_best(num_classes, best_val_list, best_model, best_result, path_trained_model, print_num):
if num_classes == 2:
torch.save(best_model.state_dict(), os.path.join(path_trained_model, 'best_Jc_{:.4f}.pth'.format(best_val_list[1])))
print('| Best Result: {}'.format(best_result).ljust(print_num, ' '), '|')
print('| Best Val Thr: {:.4f}'.format(best_val_list[0]).ljust(print_num, ' '), '|')
print('| Best Val Jc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')
print('| Best Val Dc: {:.4f}'.format(best_val_list[2]).ljust(print_num, ' '), '|')
else:
torch.save(best_model.state_dict(), os.path.join(path_trained_model, 'best_Jc_{:.4f}.pth'.format(best_val_list[1])))
np.set_printoptions(precision=4, suppress=True)
print('| Best Result: {}'.format(best_result).ljust(print_num, ' '), '|')
print('| Best Val Jc: {}'.format(best_val_list[0]).ljust(print_num, ' '), '|')
print('| Best Val Dc: {}'.format(best_val_list[2]).ljust(print_num, ' '), '|')
print('| Best Val mJc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')
print('| Best Val mDc: {:.4f}'.format(best_val_list[3]).ljust(print_num, ' '), '|')
def print_test_eval(num_classes, score_list_test, mask_list_test, print_num):
if num_classes == 2:
eval_list = evaluate(score_list_test, mask_list_test)
print('| Test Thr: {:.4f}'.format(eval_list[0]).ljust(print_num, ' '), '|')
print('| Test Jc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
print('| Test Dc: {:.4f}'.format(eval_list[2]).ljust(print_num, ' '), '|')
else:
eval_list = evaluate_multi(score_list_test, mask_list_test)
np.set_printoptions(precision=4, suppress=True)
print('| Test Jc: {} '.format(eval_list[0]).ljust(print_num, ' '), '|')
print('| Test Dc: {} '.format(eval_list[2]).ljust(print_num, ' '), '|')
print('| Test mJc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
print('| Test mDc: {:.4f}'.format(eval_list[3]).ljust(print_num, ' '), '|')
return eval_list
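For binary tasks, `print_test_eval` reports a threshold together with the Jaccard index (Jc, i.e. IoU) and the Dice coefficient (Dc). For multi-class tasks it reports per-class arrays and their means. The `evaluate` and `evaluate_multi` functions live elsewhere in the repo and are not shown here.

The pure-Python sketch below computes both metrics for one pair of flat binary masks. It also checks the identity Dc = 2·Jc / (1 + Jc). The helper names are hypothetical. Scoring two empty masks as 1.0 is an assumed convention, not necessarily the repo's.

```python
def jaccard_dice(pred, mask):
    """Jaccard (IoU) and Dice for two flat binary sequences of equal length."""
    inter = sum(1 for p, m in zip(pred, mask) if p and m)
    p_sum = sum(1 for p in pred if p)
    m_sum = sum(1 for m in mask if m)
    union = p_sum + m_sum - inter
    # Assumption: two empty masks count as a perfect match.
    jc = inter / union if union else 1.0
    dc = 2 * inter / (p_sum + m_sum) if (p_sum + m_sum) else 1.0
    return jc, dc

def dice_from_jaccard(jc):
    # Dice and Jaccard are monotonically related: Dc = 2 * Jc / (1 + Jc).
    return 2 * jc / (1 + jc)

jc, dc = jaccard_dice([1, 1, 0, 0], [1, 0, 1, 0])  # one overlap out of three labelled pixels
```

Because the two metrics are monotonically related, ranking checkpoints by Jc (as `print_best` does in its file name) selects the same model as ranking by Dc on a single image.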
def save_test_2d(num_classes, score_list_test, name_list_test, threshold, path_seg_results, palette):
if num_classes == 2:
score_list_test = torch.softmax(score_list_test, dim=1)
pred_results = score_list_test[:, 1, ...].cpu().numpy()
pred_results[pred_results > threshold] = 1
pred_results[pred_results <= threshold] = 0
assert len(name_list_test) == pred_results.shape[0]
for i in range(len(name_list_test)):
color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
color_results.putpalette(palette)
color_results.save(os.path.join(path_seg_results, name_list_test[i]))
else:
pred_results = torch.max(score_list_test, 1)[1]
pred_results = pred_results.cpu().numpy()
assert len(name_list_test) == pred_results.shape[0]
for i in range(len(name_list_test)):
color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
color_results.putpalette(palette)
color_results.save(os.path.join(path_seg_results, name_list_test[i]))
def save_test_3d(num_classes, score_test, name_test, threshold, path_seg_results, affine):
if num_classes == 2:
score_list_test = torch.softmax(score_test, dim=0)
pred_results = score_list_test[1, ...].cpu()
pred_results[pred_results > threshold] = 1
pred_results[pred_results <= threshold] = 0
pred_results = pred_results.type(torch.uint8)
output_image = tio.ScalarImage(tensor=pred_results.unsqueeze(0), affine=affine)
output_image.save(os.path.join(path_seg_results, name_test))
else:
pred_results = torch.max(score_test, 0)[1]
pred_results = pred_results.cpu()
pred_results = pred_results.type(torch.uint8)
output_image = tio.ScalarImage(tensor=pred_results.unsqueeze(0), affine=affine)
output_image.save(os.path.join(path_seg_results, name_test))
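In the binary branch, `save_test_3d` applies a softmax over the class axis (dim 0 of a single volume's scores) and thresholds the foreground channel. It then adds a leading channel axis, because TorchIO images hold 4D tensors.

The NumPy sketch below reproduces that path on logits. It shows that at a threshold of 0.5, two-class softmax thresholding agrees with argmax, including the tie case. The function names are hypothetical.

```python
import numpy as np

def softmax(logits, axis):
    # Subtract the per-position max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def binary_volume_from_logits(logits, threshold=0.5):
    """(2, D, H, W) logits -> (D, H, W) uint8 mask, foreground prob > threshold."""
    fg = softmax(logits, axis=0)[1]
    return (fg > threshold).astype(np.uint8)

# Background logits are 0; foreground logits -1, 0 (a tie) and 2.
logits = np.zeros((2, 1, 1, 3))
logits[1, 0, 0] = [-1.0, 0.0, 2.0]
mask = binary_volume_from_logits(logits)
```

With two classes, the foreground probability equals sigmoid(l1 - l0). A tuned threshold other than 0.5 is therefore the only reason the binary branch can disagree with the argmax branch.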
================================================
FILE: config/visdom_config/__init__.py
================================================
================================================
FILE: config/visdom_config/visual_visdom.py
================================================
from visdom import Visdom
import os
def visdom_initialization_sup(env, port):
visdom = Visdom(env=env, port=port)
visdom.line([0.], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss'], width=550, height=350))
visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))
visdom.line([0.], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Loss'], width=550, height=350))
visdom.line([0.], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc'], width=550, height=350))
return visdom
def visualization_sup(vis, epoch, train_loss, train_m_jc, val_loss, val_m_jc):
vis.line([train_loss], [epoch], win='train_loss', update='append')
vis.line([train_m_jc], [epoch], win='train_jc', update='append')
vis.line([val_loss], [epoch], win='val_loss', update='append')
vis.line([val_m_jc], [epoch], win='val_jc', update='append')
def visual_image_sup(vis, mask_train, pred_train, mask_val, pred_val):
vis.heatmap(mask_train, win='train_mask', opts=dict(title='Train Mask', colormap='Viridis'))
vis.heatmap(pred_train, win='train_pred1', opts=dict(title='Train Pred', colormap='Viridis'))
vis.heatmap(mask_val, win='val_mask', opts=dict(title='Val Mask', colormap='Viridis'))
vis.heatmap(pred_val, win='val_pred1', opts=dict(title='Val Pred', colormap='Viridis'))
def visdom_initialization_XNet(env, port):
visdom = Visdom(env=env, port=port)
visdom.line([[0., 0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Sup1', 'Train Sup2', 'Train Unsup'], width=550, height=350))
visdom.line([[0., 0.]], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc1', 'Train Jc2'], width=550, height=350))
visdom.line([[0., 0.]], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Sup1', 'Val Sup2'], width=550, height=350))
visdom.line([[0., 0.]], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc1', 'Val Jc2'], width=550, height=350))
return visdom
def visualization_XNet(vis, epoch, train_loss, train_loss_sup1, train_loss_sup2, train_loss_cps, train_m_jc1, train_m_jc2, val_loss_sup1, val_loss_sup2, val_m_jc1, val_m_jc2):
vis.line([[train_loss, train_loss_sup1, train_loss_sup2, train_loss_cps]], [epoch], win='train_loss', update='append')
vis.line([[train_m_jc1, train_m_jc2]], [epoch], win='train_jc', update='append')
vis.line([[val_loss_sup1, val_loss_sup2]], [epoch], win='val_loss', update='append')
vis.line([[val_m_jc1, val_m_jc2]], [epoch], win='val_jc', update='append')
def visual_image_XNet(vis, mask_train, pred_train1, pred_train2, mask_val, pred_val1, pred_val2):
vis.heatmap(mask_train, win='train_mask', opts=dict(title='Train Mask', colormap='Viridis'))
vis.heatmap(pred_train1, win='train_pred1', opts=dict(title='Train Pred1', colormap='Viridis'))
vis.heatmap(pred_train2, win='train_pred2', opts=dict(title='Train Pred2', colormap='Viridis'))
vis.heatmap(mask_val, win='val_mask', opts=dict(title='Val Mask', colormap='Viridis'))
vis.heatmap(pred_val1, win='val_pred1', opts=dict(title='Val Pred1', colormap='Viridis'))
vis.heatmap(pred_val2, win='val_pred2', opts=dict(title='Val Pred2', colormap='Viridis'))
def visdom_initialization_MT(env, port):
visdom = Visdom(env=env, port=port)
visdom.line([[0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Sup', 'Train Unsup'], width=550, height=350))
visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))
visdom.line([[0., 0.]], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Sup1', 'Val Sup2'], width=550, height=350))
visdom.line([[0., 0.]], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc1', 'Val Jc2'], width=550, height=350))
return visdom
def visualization_MT(vis, epoch, train_loss, train_loss_sup1, train_loss_cps, train_m_jc1, val_loss_sup1, val_loss_sup2, val_m_jc1, val_m_jc2):
vis.line([[train_loss, train_loss_sup1, train_loss_cps]], [epoch], win='train_loss', update='append')
vis.line([train_m_jc1], [epoch], win='train_jc', update='append')
vis.line([[val_loss_sup1, val_loss_sup2]], [epoch], win='val_loss', update='append')
vis.line([[val_m_jc1, val_m_jc2]], [epoch], win='val_jc', update='append')
def visual_image_MT(vis, mask_train, pred_train1, mask_val, pred_val1, pred_val2):
vis.heatmap(mask_train, win='train_mask', opts=dict(title='Train Mask', colormap='Viridis'))
vis.heatmap(pred_train1, win='train_pred1', opts=dict(title='Train Pred', colormap='Viridis'))
vis.heatmap(mask_val, win='val_mask', opts=dict(title='Val Mask', colormap='Viridis'))
vis.heatmap(pred_val1, win='val_pred1', opts=dict(title='Val Pred1', colormap='Viridis'))
vis.heatmap(pred_val2, win='val_pred2', opts=dict(title='Val Pred2', colormap='Viridis'))
def visdom_initialization_EM(env, port):
visdom = Visdom(env=env, port=port)
visdom.line([[0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Sup', 'Train Unsup'], width=550, height=350))
visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))
visdom.line([0.], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Sup'], width=550, height=350))
visdom.line([0.], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc'], width=550, height=350))
return visdom
def visualization_EM(vis, epoch, train_loss, train_loss_sup1, train_loss_cps, train_m_jc1, val_loss_sup1, val_m_jc1):
vis.line([[train_loss, train_loss_sup1, train_loss_cps]], [epoch], win='train_loss', update='append')
vis.line([train_m_jc1], [epoch], win='train_jc', update='append')
vis.line([val_loss_sup1], [epoch], win='val_loss', update='append')
vis.line([val_m_jc1], [epoch], win='val_jc', update='append')
def visdom_initialization_ConResNet(env, port):
visdom = Visdom(env=env, port=port)
visdom.line([[0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Seg', 'Train Res'], width=550, height=350))
visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))
visdom.line([[0., 0.]], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Seg', 'Val Res'], width=550, height=350))
visdom.line([0.], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc'], width=550, height=350))
return visdom
def visualization_ConResNet(vis, epoch, train_loss, train_loss_seg, train_loss_res, train_m_jc1, val_loss_seg, val_loss_res, val_m_jc1):
vis.line([[train_loss, train_loss_seg, train_loss_res]], [epoch], win='train_loss', update='append')
vis.line([train_m_jc1], [epoch], win='train_jc', update='append')
vis.line([[val_loss_seg, val_loss_res]], [epoch], win='val_loss', update='append')
vis.line([val_m_jc1], [epoch], win='val_jc', update='append')
================================================
FILE: config/warmup_config/__init__.py
================================================
================================================
FILE: config/warmup_config/warmup.py
================================================
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
class GradualWarmupScheduler(_LRScheduler):
""" Gradually warm up (increase) the learning rate of the optimizer.
Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
Args:
optimizer (Optimizer): Wrapped optimizer.
multiplier: if multiplier > 1.0, the target learning rate is base_lr * multiplier; if multiplier == 1.0, the learning rate starts from 0 and ends at base_lr.
total_epoch: the target learning rate is reached gradually, at epoch total_epoch.
after_scheduler: scheduler used after total_epoch (e.g. ReduceLROnPlateau).
"""
def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
self.multiplier = multiplier
if self.multiplier < 1.:
raise ValueError('multiplier should be greater than or equal to 1.')
self.total_epoch = total_epoch
self.after_scheduler = after_scheduler
self.finished = False
super(GradualWarmupScheduler, self).__init__(optimizer)
def get_lr(self):
if self.last_epoch > self.total_epoch:
if self.after_scheduler:
if not self.finished:
self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
self.finished = True
return self.after_scheduler.get_last_lr()
return [base_lr * self.multiplier for base_lr in self.base_lrs]
if self.multiplier == 1.0:
return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
else:
return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
def step_ReduceLROnPlateau(self, metrics, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
if self.last_epoch <= self.total_epoch:
warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
param_group['lr'] = lr
else:
if epoch is None:
self.after_scheduler.step(metrics, None)
else:
self.after_scheduler.step(metrics, epoch - self.total_epoch)
def step(self, epoch=None, metrics=None):
if not isinstance(self.after_scheduler, ReduceLROnPlateau):
if self.finished and self.after_scheduler:
if epoch is None:
self.after_scheduler.step(None)
else:
self.after_scheduler.step(epoch - self.total_epoch)
self._last_lr = self.after_scheduler.get_last_lr()
else:
return super(GradualWarmupScheduler, self).step(epoch)
else:
self.step_ReduceLROnPlateau(metrics, epoch)
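During warm-up, `GradualWarmupScheduler.get_lr` ramps each base learning rate linearly. With `multiplier == 1.0` it goes from 0 up to `base_lr`. Otherwise it goes from `base_lr` up to `base_lr * multiplier`.

The pure-Python sketch below evaluates that formula for one parameter group, so a schedule can be checked without building an optimizer. `warmup_lr` is a hypothetical helper. It ignores `after_scheduler`, which takes over once `epoch > total_epoch`.

```python
def warmup_lr(base_lr, multiplier, total_epoch, epoch):
    """Learning rate given by GradualWarmupScheduler.get_lr for one param group.

    Past total_epoch this returns the plateau value base_lr * multiplier,
    i.e. the value used when no after_scheduler is attached.
    """
    if epoch > total_epoch:
        return base_lr * multiplier
    if multiplier == 1.0:
        # Linear ramp from 0 to base_lr.
        return base_lr * epoch / total_epoch
    # Linear ramp from base_lr to base_lr * multiplier.
    return base_lr * ((multiplier - 1.0) * epoch / total_epoch + 1.0)
```

For example, `base_lr=0.01`, `multiplier=1.0` and `total_epoch=5` give 0.0, 0.002, ..., 0.01 over epochs 0 to 5. With `multiplier=10.0` the ramp instead starts at `base_lr` and ends ten times higher.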
================================================
FILE: dataload/__init__.py
================================================
================================================
FILE: dataload/dataset_2d.py
================================================
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2
import numpy as np
import pywt
class dataset_itn(Dataset):
def __init__(self, data_dir, input1, augmentation_1, normalize_1, sup=True, num_images=None, **kwargs):
super(dataset_itn, self).__init__()
img_paths_1 = []
mask_paths = []
image_dir_1 = data_dir + '/' + input1
if sup:
mask_dir = data_dir + '/mask'
for image in os.listdir(image_dir_1):
image_path_1 = os.path.join(image_dir_1, image)
img_paths_1.append(image_path_1)
if sup:
mask_path = os.path.join(mask_dir, image)
mask_paths.append(mask_path)
if sup:
assert len(img_paths_1) == len(mask_paths)
if num_images is not None:
len_img_paths = len(img_paths_1)
quotient = num_images // len_img_paths
remainder = num_images % len_img_paths
if num_images <= len_img_paths:
img_paths_1 = img_paths_1[:num_images]
else:
rand_indices = torch.randperm(len_img_paths).tolist()
new_indices = rand_indices[:remainder]
img_paths_1 = img_paths_1 * quotient
img_paths_1 += [img_paths_1[i] for i in new_indices]
if sup:
mask_paths = mask_paths * quotient
mask_paths += [mask_paths[i] for i in new_indices]
self.img_paths_1 = img_paths_1
self.mask_paths = mask_paths
self.augmentation_1 = augmentation_1
self.normalize_1 = normalize_1
self.sup = sup
self.kwargs = kwargs
def __getitem__(self, index):
img_path_1 = self.img_paths_1[index]
img_1 = Image.open(img_path_1)
img_1 = np.array(img_1)
if self.sup:
mask_path = self.mask_paths[index]
mask = Image.open(mask_path)
mask = np.array(mask)
augment_1 = self.augmentation_1(image=img_1, mask=mask)
img_1 = augment_1['image']
mask_1 = augment_1['mask']
normalize_1 = self.normalize_1(image=img_1, mask=mask_1)
img_1 = normalize_1['image']
mask_1 = normalize_1['mask']
mask_1 = mask_1.long()
sample = {'image': img_1, 'mask': mask_1, 'ID': os.path.split(mask_path)[1]}
else:
augment_1 = self.augmentation_1(image=img_1)
img_1 = augment_1['image']
normalize_1 = self.normalize_1(image=img_1)
img_1 = normalize_1['image']
sample = {'image': img_1, 'ID': os.path.split(img_path_1)[1]}
return sample
def __len__(self):
return len(self.img_paths_1)
def imagefloder_itn(data_dir, input1, data_transform_1, data_normalize_1, sup=True, num_images=None, **kwargs):
dataset = dataset_itn(data_dir=data_dir,
input1=input1,
augmentation_1=data_transform_1,
normalize_1=data_normalize_1,
sup=sup,
num_images=num_images,
**kwargs
)
return dataset
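When `num_images` is set, `dataset_itn` either keeps the first `num_images` paths or oversamples. Oversampling repeats the whole list `num_images // len(paths)` times and tops it up with `num_images % len(paths)` randomly chosen distinct entries.

The sketch below reproduces that scheme for a single list. It swaps `torch.randperm` for Python's `random.Random.sample` so it runs without Torch. `resample_paths` and its `seed` argument are hypothetical. When several aligned lists (images, masks) must stay paired, the same index list has to be applied to each of them, as the repo does with `new_indices`.

```python
import random

def resample_paths(paths, num_images, seed=None):
    """Trim or oversample `paths` to exactly `num_images` entries."""
    n = len(paths)
    if num_images <= n:
        return paths[:num_images]
    quotient, remainder = divmod(num_images, n)
    rng = random.Random(seed)
    # Distinct indices for the top-up, like taking a prefix of a random permutation.
    extra = rng.sample(range(n), remainder)
    return paths * quotient + [paths[i] for i in extra]

trimmed = resample_paths(['a', 'b', 'c'], 2)
oversampled = resample_paths(['a', 'b', 'c'], 7, seed=0)
```

Every path appears either `quotient` or `quotient + 1` times, so oversampling a small labelled set stays close to uniform.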
class dataset_iitnn(Dataset):
def __init__(self, data_dir, input1, input2, augmentation1, normalize_1, normalize_2, sup=True,
num_images=None, **kwargs):
super(dataset_iitnn, self).__init__()
img_paths_1 = []
img_paths_2 = []
mask_paths = []
image_dir_1 = data_dir + '/' + input1
image_dir_2 = data_dir + '/' + input2
if sup:
mask_dir = data_dir + '/mask'
for image in os.listdir(image_dir_1):
image_path_1 = os.path.join(image_dir_1, image)
img_paths_1.append(image_path_1)
image_path_2 = os.path.join(image_dir_2, image)
img_paths_2.append(image_path_2)
if sup:
mask_path = os.path.join(mask_dir, image)
mask_paths.append(mask_path)
assert len(img_paths_1) == len(img_paths_2)
if sup:
assert len(img_paths_1) == len(mask_paths)
if num_images is not None:
len_img_paths = len(img_paths_1)
quotient = num_images // len_img_paths
remainder = num_images % len_img_paths
if num_images <= len_img_paths:
img_paths_1 = img_paths_1[:num_images]
img_paths_2 = img_paths_2[:num_images]
else:
rand_indices = torch.randperm(len_img_paths).tolist()
new_indices = rand_indices[:remainder]
img_paths_1 = img_paths_1 * quotient
img_paths_1 += [img_paths_1[i] for i in new_indices]
img_paths_2 = img_paths_2 * quotient
img_paths_2 += [img_paths_2[i] for i in new_indices]
if sup:
mask_paths = mask_paths * quotient
mask_paths += [mask_paths[i] for i in new_indices]
self.img_paths_1 = img_paths_1
self.img_paths_2 = img_paths_2
self.mask_paths = mask_paths
self.augmentation_1 = augmentation1
self.normalize_1 = normalize_1
self.normalize_2 = normalize_2
self.sup = sup
self.kwargs = kwargs
def __getitem__(self, index):
img_path_1 = self.img_paths_1[index]
img_1 = Image.open(img_path_1)
img_1 = np.array(img_1)
img_path_2 = self.img_paths_2[index]
img_2 = Image.open(img_path_2)
img_2 = np.array(img_2)
if self.sup:
mask_path = self.mask_paths[index]
mask = Image.open(mask_path)
mask = np.array(mask)
augment_1 = self.augmentation_1(image=img_1, image2=img_2, mask=mask)
img_1 = augment_1['image']
img_2 = augment_1['image2']
mask = augment_1['mask']
normalize_1 = self.normalize_1(image=img_1, mask=mask)
img_1 = normalize_1['image']
mask = normalize_1['mask']
mask = mask.long()
normalize_2 = self.normalize_2(image=img_2)
img_2 = normalize_2['image']
sample = {'image': img_1, 'image_2': img_2, 'mask': mask, 'ID': os.path.split(mask_path)[1]}
else:
augment_1 = self.augmentation_1(image=img_1, image2=img_2)
img_1 = augment_1['image']
img_2 = augment_1['image2']
normalize_1 = self.normalize_1(image=img_1)
img_1 = normalize_1['image']
normalize_2 = self.normalize_2(image=img_2)
img_2 = normalize_2['image']
sample = {'image': img_1, 'image_2': img_2, 'ID': os.path.split(img_path_1)[1]}
return sample
def __len__(self):
return len(self.img_paths_1)
def imagefloder_iitnn(data_dir, input1, input2, data_transform_1, data_normalize_1, data_normalize_2, sup=True, num_images=None, **kwargs):
dataset = dataset_iitnn(data_dir=data_dir,
input1=input1,
input2=input2,
augmentation1=data_transform_1,
normalize_1=data_normalize_1,
normalize_2=data_normalize_2,
sup=sup,
num_images=num_images,
**kwargs
)
return dataset
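`dataset_iitnn` pairs its two inputs (and the mask) purely by file name. It assumes that every name listed in `<data_dir>/<input1>` also exists under `<input2>` and `mask`. A missing partner only surfaces later, when `__getitem__` tries to open it.

The sketch below builds the same aligned path tuples up front. It sorts the listing, because `os.listdir` order is arbitrary, and fails early if a partner file is absent. `paired_paths` and the directory names in the demo are hypothetical.

```python
import os
import tempfile

def paired_paths(data_dir, input1, input2, with_mask=True):
    """Return sorted (input1, input2[, mask]) path tuples matched by file name."""
    dir1 = os.path.join(data_dir, input1)
    dirs = [dir1, os.path.join(data_dir, input2)]
    if with_mask:
        dirs.append(os.path.join(data_dir, 'mask'))
    pairs = []
    for name in sorted(os.listdir(dir1)):
        paths = tuple(os.path.join(d, name) for d in dirs)
        missing = [p for p in paths if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(f'missing partner file(s): {missing}')
        pairs.append(paths)
    return pairs

# Demo on a throwaway tree with two images per directory.
with tempfile.TemporaryDirectory() as root:
    for sub in ('image', 'H', 'mask'):
        os.makedirs(os.path.join(root, sub))
        for name in ('b.png', 'a.png'):
            open(os.path.join(root, sub, name), 'wb').close()
    demo = [tuple(os.path.relpath(p, root) for p in t)
            for t in paired_paths(root, 'image', 'H')]
    # Delete one partner file and confirm the check catches it.
    os.remove(os.path.join(root, 'H', 'b.png'))
    try:
        paired_paths(root, 'image', 'H')
        missing_detected = False
    except FileNotFoundError:
        missing_detected = True
```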
class dataset_wds(Dataset):
def __init__(self, data_dir, augmentation1, normalize_LL, normalize_LH, normalize_HL, normalize_HH, **kwargs):
super(dataset_wds, self).__init__()
img_paths_LL = []
img_paths_LH = []
img_paths_HL = []
img_paths_HH = []
mask_paths = []
image_dir_LL = data_dir + '/LL'
image_dir_LH = data_dir + '/LH'
image_dir_HL = data_dir + '/HL'
image_dir_HH = data_dir + '/HH'
mask_dir = data_dir + '/mask'
for image in os.listdir(image_dir_LL):
image_path_LL = os.path.join(image_dir_LL, image)
img_paths_LL.append(image_path_LL)
image_path_LH = os.path.join(image_dir_LH, image)
img_paths_LH.append(image_path_LH)
image_path_HL = os.path.join(image_dir_HL, image)
img_paths_HL.append(image_path_HL)
image_path_HH = os.path.join(image_dir_HH, image)
img_paths_HH.append(image_path_HH)
mask_path = os.path.join(mask_dir, image)
mask_paths.append(mask_path)
self.img_paths_LL = img_paths_LL
self.img_paths_LH = img_paths_LH
self.img_paths_HL = img_paths_HL
self.img_paths_HH = img_paths_HH
self.mask_paths = mask_paths
self.augmentation_1 = augmentation1
self.normalize_LL = normalize_LL
self.normalize_LH = normalize_LH
self.normalize_HL = normalize_HL
self.normalize_HH = normalize_HH
self.kwargs = kwargs
def __getitem__(self, index):
img_path_LL = self.img_paths_LL[index]
img_LL = Image.open(img_path_LL)
img_LL = np.array(img_LL)
img_path_LH = self.img_paths_LH[index]
img_LH = Image.open(img_path_LH)
img_LH = np.array(img_LH)
img_path_HL = self.img_paths_HL[index]
img_HL = Image.open(img_path_HL)
img_HL = np.array(img_HL)
img_path_HH = self.img_paths_HH[index]
img_HH = Image.open(img_path_HH)
img_HH = np.array(img_HH)
mask_path = self.mask_paths[index]
mask = Image.open(mask_path)
mask = np.array(mask)
augment_1 = self.augmentation_1(image=img_LL, mask=mask, imageLH=img_LH, imageHL=img_HL, imageHH=img_HH)
img_LL = augment_1['image']
img_LH = augment_1['imageLH']
img_HL = augment_1['imageHL']
img_HH = augment_1['imageHH']
mask_1 = augment_1['mask']
normalize_LL = self.normalize_LL(image=img_LL, mask=mask_1)
img_LL = normalize_LL['image']
mask_1 = normalize_LL['mask']
mask_1 = mask_1.long()
normalize_LH = self.normalize_LH(image=img_LH)
img_LH = normalize_LH['image']
normalize_HL = self.normalize_HL(image=img_HL)
img_HL = normalize_HL['image']
normalize_HH = self.normalize_HH(image=img_HH)
img_HH = normalize_HH['image']
sample = {'image_LL': img_LL, 'image_LH': img_LH, 'image_HL': img_HL, 'image_HH': img_HH, 'mask': mask_1, 'ID': os.path.split(mask_path)[1]}
return sample
def __len__(self):
return len(self.img_paths_LL)
def imagefloder_wds(data_dir, data_transform_1, data_normalize_LL, data_normalize_LH, data_normalize_HL, data_normalize_HH, **kwargs):
dataset = dataset_wds(data_dir=data_dir,
augmentation1=data_transform_1,
normalize_LL=data_normalize_LL,
normalize_LH=data_normalize_LH,
normalize_HL=data_normalize_HL,
normalize_HH=data_normalize_HH,
**kwargs
)
return dataset
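`dataset_wds` reads four precomputed wavelet subband images (`LL`, `LH`, `HL`, `HH`) from disk. The tree also contains `tools/wavelet2D.py`, which is not shown here and presumably produces them.

To illustrate what the subbands hold, here is a NumPy sketch of one level of an orthonormal 2D Haar transform. It is an assumption-level illustration, not the repo's preprocessing. Which detail band is called `LH` versus `HL` (and its sign) differs between libraries. Here `LH` is the top-minus-bottom difference of each 2x2 block.

```python
import numpy as np

def haar_dwt2(img):
    """One level of an orthonormal 2D Haar transform on an even-sized 2D array."""
    img = np.asarray(img, dtype=np.float64)
    a = img[0::2, 0::2]  # top-left pixel of each 2x2 block
    b = img[0::2, 1::2]  # top-right
    c = img[1::2, 0::2]  # bottom-left
    d = img[1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2  # local average (approximation)
    lh = (a + b - c - d) / 2  # top minus bottom: responds to horizontal edges
    hl = (a - b + c - d) / 2  # left minus right: responds to vertical edges
    hh = (a - b - c + d) / 2  # diagonal detail
    return ll, lh, hl, hh
```

Because the transform is orthonormal, the total squared energy of the four subbands equals that of the input. A flat image puts everything in `LL`, while a vertical stripe shows up in `HL`.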
class dataset_aerial_lanenet(Dataset):
def __init__(self, data_dir, augmentation1, normalize_1, normalize_l1, normalize_l2, normalize_l3, normalize_l4, **kwargs):
super(dataset_aerial_lanenet, self).__init__()
img_paths = []
mask_paths = []
image_dir = data_dir + '/image'
mask_dir = data_dir + '/mask'
for image in os.listdir(image_dir):
image_path = os.path.join(image_dir, image)
img_paths.append(image_path)
mask_path = os.path.join(mask_dir, image)
mask_paths.append(mask_path)
self.img_paths = img_paths
self.mask_paths = mask_paths
self.augmentation_1 = augmentation1
self.normalize_1 = normalize_1
self.normalize_l4 = normalize_l4
self.normalize_l3 = normalize_l3
self.normalize_l2 = normalize_l2
self.normalize_l1 = normalize_l1
self.kwargs = kwargs
def __getitem__(self, index):
img_path = self.img_paths[index]
img = Image.open(img_path)
img = np.array(img)
mask_path = self.mask_paths[index]
mask = Image.open(mask_path)
mask = np.array(mask)
augment_1 = self.augmentation_1(image=img, mask=mask)
img = augment_1['image']
mask = augment_1['mask']
img_ = np.array(Image.fromarray(img).convert('L'))
_, l4, l3, l2, l1 = pywt.wavedec2(img_, 'db2', level=4)
l4 = np.array(l4).transpose(1, 2, 0)
l3 = np.array(l3).transpose(1, 2, 0)
l2 = np.array(l2).transpose(1, 2, 0)
l1 = np.array(l1).transpose(1, 2, 0)
normalize_l4 = self.normalize_l4(image=l4)
l4 = normalize_l4['image'].float()
normalize_l3 = self.normalize_l3(image=l3)
l3 = normalize_l3['image'].float()
normalize_l2 = self.normalize_l2(image=l2)
l2 = normalize_l2['image'].float()
normalize_l1 = self.normalize_l1(image=l1)
l1 = normalize_l1['image'].float()
normalize_1 = self.normalize_1(image=img, mask=mask)
img = normalize_1['image']
mask = normalize_1['mask'].long()
sample = {'image': img, 'image_l1': l1, 'image_l2': l2, 'image_l3': l3, 'image_l4': l4, 'mask': mask, 'ID': os.path.split(mask_path)[1]}
return sample
def __len__(self):
return len(self.img_paths)
def imagefloder_aerial_lanenet(data_dir, data_transform, data_normalize, data_normalize_l1, data_normalize_l2, data_normalize_l3, data_normalize_l4, **kwargs):
dataset = dataset_aerial_lanenet(data_dir=data_dir,
augmentation1=data_transform,
normalize_1=data_normalize,
normalize_l1=data_normalize_l1,
normalize_l2=data_normalize_l2,
normalize_l3=data_normalize_l3,
normalize_l4=data_normalize_l4,
**kwargs
)
return dataset
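`dataset_aerial_lanenet` unpacks `pywt.wavedec2(img_, 'db2', level=4)` as `_, l4, l3, l2, l1`. PyWavelets returns the coarsest approximation first and then detail tuples from the coarsest level (4) to the finest (1). Each `(cH, cV, cD)` tuple is stacked and transposed to an H×W×3 array.

The sketch below predicts those shapes without PyWavelets installed. It assumes PyWavelets' non-periodization coefficient length, floor((n + filter_len − 1) / 2), with the default `'symmetric'` mode and the 4-tap `db2` filter. `wavedec2_shapes` and `stack_details` are hypothetical helpers.

```python
import numpy as np

def dwt_len(n, filter_len=4):
    # Coefficient length per level for PyWavelets' non-periodization modes.
    return (n + filter_len - 1) // 2

def wavedec2_shapes(h, w, level=4, filter_len=4):
    """Detail-band shapes in wavedec2 order: coarsest level first."""
    shapes = []
    for _ in range(level):
        h, w = dwt_len(h, filter_len), dwt_len(w, filter_len)
        shapes.append((h, w))
    return shapes[::-1]

def stack_details(details):
    # (cH, cV, cD) of equal shape -> (H, W, 3), matching transpose(1, 2, 0) in the dataset.
    return np.array(details).transpose(1, 2, 0)
```

For a 256×256 input this gives 18×18 for `l4` and 129×129 for `l1`. The four detail inputs therefore have different spatial sizes from each other and from the image.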
================================================
FILE: dataload/dataset_3d.py
================================================
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2
import numpy as np
import torchio as tio
import SimpleITK as sitk
from torchio.data import UniformSampler, LabelSampler
class dataset_it(Dataset):
def __init__(self, data_dir, input1, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):
super(dataset_it, self).__init__()
self.subjects_1 = []
image_dir_1 = data_dir + '/' + input1
if sup:
mask_dir = data_dir + '/mask'
for i in os.listdir(image_dir_1):
image_path_1 = os.path.join(image_dir_1, i)
if sup:
mask_path = os.path.join(mask_dir, i)
subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), mask=tio.LabelMap(mask_path), ID=i)
else:
subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), ID=i)
self.subjects_1.append(subject_1)
if num_images is not None:
len_img_paths = len(self.subjects_1)
quotient = num_images // len_img_paths
remainder = num_images % len_img_paths
if num_images <= len_img_paths:
self.subjects_1 = self.subjects_1[:num_images]
else:
rand_indices = torch.randperm(len_img_paths).tolist()
new_indices = rand_indices[:remainder]
self.subjects_1 = self.subjects_1 * quotient
self.subjects_1 += [self.subjects_1[i] for i in new_indices]
self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)
self.queue_train_set_1 = tio.Queue(
subjects_dataset=self.dataset_1,
max_length=queue_length,
samples_per_volume=samples_per_volume,
sampler=UniformSampler(patch_size),
# sampler=LabelSampler(patch_size),
num_workers=num_workers,
shuffle_subjects=shuffle_subjects,
shuffle_patches=shuffle_patches
)
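The 3D dataset classes do not define their own `__getitem__` or `__len__`. They build a `tio.Queue` (`queue_train_set_1`), which the training scripts presumably wrap in a `DataLoader`. As we understand TorchIO's `Queue`, each epoch draws `samples_per_volume` patches from every subject, so the queue's length is `num_subjects * samples_per_volume`. TorchIO's documentation also recommends `num_workers=0` for the outer `DataLoader`, because the queue already loads subjects with its own workers.

The pure-Python sketch below estimates iterations per epoch from those numbers. `queue_batches_per_epoch` is a hypothetical helper.

```python
import math

def queue_batches_per_epoch(num_subjects, samples_per_volume, batch_size, drop_last=False):
    """Batches per epoch when a DataLoader iterates over a patch queue."""
    patches = num_subjects * samples_per_volume
    return patches // batch_size if drop_last else math.ceil(patches / batch_size)
```

With 20 subjects and the default `samples_per_volume=5`, a batch size of 4 gives 25 iterations per epoch. Setting `num_images` changes `num_subjects` and therefore the epoch length.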
class dataset_it_dtc(Dataset):
def __init__(self, data_dir, input1, num_classes, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):
super(dataset_it_dtc, self).__init__()
self.subjects_1 = []
image_dir_1 = data_dir + '/' + input1
if sup:
mask_dir_1 = data_dir + '/mask'
mask_dir_2 = data_dir + '/mask_sdf1'
if num_classes == 3:
mask_dir_3 = data_dir + '/mask_sdf2'
for i in os.listdir(image_dir_1):
image_path_1 = os.path.join(image_dir_1, i)
if sup:
mask_path_1 = os.path.join(mask_dir_1, i)
mask_path_2 = os.path.join(mask_dir_2, i)
if num_classes == 3:
mask_path_3 = os.path.join(mask_dir_3, i)
subject_1 = tio.Subject(
image=tio.ScalarImage(image_path_1),
mask=tio.LabelMap(mask_path_1),
mask2=tio.LabelMap(mask_path_2),
mask3=tio.LabelMap(mask_path_3),
ID=i)
else:
subject_1 = tio.Subject(
image=tio.ScalarImage(image_path_1),
mask=tio.LabelMap(mask_path_1),
mask2=tio.LabelMap(mask_path_2),
ID=i)
else:
subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), ID=i)
self.subjects_1.append(subject_1)
if num_images is not None:
len_img_paths = len(self.subjects_1)
quotient = num_images // len_img_paths
remainder = num_images % len_img_paths
if num_images <= len_img_paths:
self.subjects_1 = self.subjects_1[:num_images]
else:
rand_indices = torch.randperm(len_img_paths).tolist()
new_indices = rand_indices[:remainder]
self.subjects_1 = self.subjects_1 * quotient
self.subjects_1 += [self.subjects_1[i] for i in new_indices]
self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)
self.queue_train_set_1 = tio.Queue(
subjects_dataset=self.dataset_1,
max_length=queue_length,
samples_per_volume=samples_per_volume,
sampler=UniformSampler(patch_size),
# sampler=LabelSampler(patch_size),
num_workers=num_workers,
shuffle_subjects=shuffle_subjects,
shuffle_patches=shuffle_patches
)
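`dataset_it_dtc` loads precomputed signed-distance targets (`mask_sdf1`, plus `mask_sdf2` when `num_classes == 3`) alongside the label map. These appear to feed a dual-task (DTC-style) network that regresses a distance map. The tree includes `tools/mask2sdf.py`, not shown here, which presumably generates them.

The brute-force NumPy sketch below illustrates one common convention: negative inside the object, positive outside, and magnitude equal to the distance to the nearest pixel of the other label. The sign convention and the lack of normalization are assumptions, and `signed_distance` is a hypothetical name. For real volumes, a Euclidean distance transform such as `scipy.ndimage.distance_transform_edt` would replace the O(N·M) loop.

```python
import numpy as np

def signed_distance(mask):
    """Brute-force signed Euclidean distance map for a small binary mask.

    Foreground pixels get minus the distance to the nearest background pixel;
    background pixels get plus the distance to the nearest foreground pixel.
    A mask with no foreground or no background returns all zeros.
    """
    mask = np.asarray(mask, dtype=bool)
    fg = np.argwhere(mask)
    bg = np.argwhere(~mask)
    out = np.zeros(mask.shape, dtype=np.float64)
    if len(fg) == 0 or len(bg) == 0:
        return out
    for idx in np.ndindex(mask.shape):
        other = bg if mask[idx] else fg
        d = np.sqrt(((other - np.array(idx)) ** 2).sum(axis=1)).min()
        out[idx] = -d if mask[idx] else d
    return out
```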
class dataset_iit(Dataset):
def __init__(self, data_dir, input1, input2, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):
super(dataset_iit, self).__init__()
self.subjects_1 = []
image_dir_1 = data_dir + '/' + input1
image_dir_2 = data_dir + '/' + input2
if sup:
mask_dir_1 = data_dir + '/mask'
for i in os.listdir(image_dir_1):
image_path_1 = os.path.join(image_dir_1, i)
image_path_2 = os.path.join(image_dir_2, i)
if sup:
mask_path_1 = os.path.join(mask_dir_1, i)
subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), mask=tio.LabelMap(mask_path_1), ID=i)
else:
subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), ID=i)
self.subjects_1.append(subject_1)
if num_images is not None:
len_img_paths = len(self.subjects_1)
quotient = num_images // len_img_paths
remainder = num_images % len_img_paths
if num_images <= len_img_paths:
self.subjects_1 = self.subjects_1[:num_images]
else:
rand_indices = torch.randperm(len_img_paths).tolist()
new_indices = rand_indices[:remainder]
self.subjects_1 = self.subjects_1 * quotient
self.subjects_1 += [self.subjects_1[i] for i in new_indices]
self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)
self.queue_train_set_1 = tio.Queue(
subjects_dataset=self.dataset_1,
max_length=queue_length,
samples_per_volume=samples_per_volume,
sampler=UniformSampler(patch_size),
# sampler=LabelSampler(patch_size),
num_workers=num_workers,
shuffle_subjects=shuffle_subjects,
shuffle_patches=shuffle_patches
)
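`dataset_iit` assumes that every file in `data_dir/input1` has a same-named counterpart in `data_dir/input2` and, when `sup=True`, in `data_dir/mask`. The sketch below collects those paths into plain dictionaries so the pairing can be checked without loading any volumes. `build_subject_records` and its `mask_dirs` parameter are hypothetical. Unlike the original `os.listdir` loop, it sorts filenames so the order is deterministic.

```python
import os


def build_subject_records(data_dir, input1, input2, sup=True, mask_dirs=("mask",)):
    """Collect per-case file paths in the layout dataset_iit expects (hypothetical helper).

    mask_dirs=("mask", "mask_res") reproduces the dataset_iit_conresnet layout.
    """
    records = []
    for name in sorted(os.listdir(os.path.join(data_dir, input1))):
        rec = {
            "image": os.path.join(data_dir, input1, name),
            "image2": os.path.join(data_dir, input2, name),
            "ID": name,
        }
        if sup:
            for k, d in enumerate(mask_dirs):
                # First mask is "mask", the following ones "mask2", "mask3", ...
                key = "mask" if k == 0 else "mask%d" % (k + 1)
                rec[key] = os.path.join(data_dir, d, name)
        records.append(rec)
    return records
```

Each record corresponds to one `tio.Subject(...)` call in the class: the image keys map to `tio.ScalarImage` and the mask keys to `tio.LabelMap`.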
class dataset_iit_conresnet(Dataset):
def __init__(self, data_dir, input1, input2, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):
super(dataset_iit_conresnet, self).__init__()
self.subjects_1 = []
image_dir_1 = data_dir + '/' + input1
image_dir_2 = data_dir + '/' + input2
if sup:
mask_dir_1 = data_dir + '/mask'
mask_dir_2 = data_dir + '/mask_res'
for i in os.listdir(image_dir_1):
image_path_1 = os.path.join(image_dir_1, i)
image_path_2 = os.path.join(image_dir_2, i)
if sup:
mask_path_1 = os.path.join(mask_dir_1, i)
mask_path_2 = os.path.join(mask_dir_2, i)
subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), mask=tio.LabelMap(mask_path_1), mask2=tio.LabelMap(mask_path_2), ID=i)
else:
subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), ID=i)
self.subjects_1.append(subject_1)
if num_images is not None:
len_img_paths = len(self.subjects_1)
quotient = num_images // len_img_paths
remainder = num_images % len_img_paths
if num_images <= len_img_paths:
self.subjects_1 = self.subjects_1[:num_images]
else:
rand_indices = torch.randperm(len_img_paths).tolist()
new_indices = rand_indices[:remainder]
self.subjects_1 = self.subjects_1 * quotient
self.subjects_1 += [self.subjects_1[i] for i in new_indices]
self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)
self.queue_train_set_1 = tio.Queue(
subjects_dataset=self.dataset_1,
max_length=queue_length,
samples_per_volume=samples_per_volume,
sampler=UniformSampler(patch_size),
# sampler=LabelSampler(patch_size),
num_workers=num_workers,
shuffle_subjects=shuffle_subjects,
shuffle_patches=shuffle_patches
)
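`dataset_iit_conresnet` additionally loads a `mask_res` volume for ConResNet's inter-slice residual branch. This section does not show how that file is produced (the repository has a separate `tools/res_image_mask.py`). The sketch below assumes the residual is the absolute difference between each slice and the previous one along a chosen axis, with the first slice set to zero. Treat both the axis and that definition as assumptions. `slice_residual` is a hypothetical name.

```python
import numpy as np


def slice_residual(volume, axis=0):
    """Absolute difference between neighbouring slices along `axis` (assumed definition).

    The first slice has no predecessor, so its residual is zero and the
    output keeps the input shape.
    """
    vol = np.asarray(volume, dtype=np.float32)
    first = np.take(vol, [0], axis=axis)  # keeps the axis with length 1
    return np.abs(np.diff(vol, axis=axis, prepend=first))
```

For a binary mask, this marks voxels whose label changes between adjacent slices. That is the kind of boundary signal a residual branch is meant to learn.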
================================================
FILE: loss/__init__.py
================================================
================================================
FILE: loss/loss_function.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import sys
from torch.nn.modules.loss import _Loss
class MixSoftmaxCrossEntropyLoss(nn.CrossEntropyLoss):
def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):
super(MixSoftmaxCrossEntropyLoss, self).__init__(ignore_index=ignore_index)
self.aux = aux
self.aux_weight = aux_weight
def _aux_forward(self, output, target, **kwargs):
# *preds, target = tuple(inputs)
loss = super(MixSoftmaxCrossEntropyLoss, self).forward(output[0], target)
for i in range(1, len(output)):
aux_loss = super(MixSoftmaxCrossEntropyLoss, self).forward(output[i], target)
loss += self.aux_weight * aux_loss
return loss
def forward(self, output, target):
# preds, target = tuple(inputs)
# inputs = tuple(list(preds) + [target])
if self.aux:
return self._aux_forward(output, target)
else:
return super(MixSoftmaxCrossEntropyLoss, self).forward(output, target)
class BinaryDiceLoss(nn.Module):
"""Dice loss of binary class
Args:
smooth: A float number to smooth loss, and avoid NaN error, default: 1
p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
predict: A tensor of shape [N, *]
target: A tensor of shape same with predict
reduction: Reduction method to apply, return mean over batch if 'mean',
return sum if 'sum', return a tensor of shape [N,] if 'none'
Returns:
Loss tensor according to arg reduction
Raise:
Exception if unexpected reduction
"""
def __init__(self, smooth=1, p=2, reduction='mean'):
super(BinaryDiceLoss, self).__init__()
self.smooth = smooth
self.p = p
self.reduction = reduction
def forward(self, predict, target, valid_mask):
assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
predict = predict.contiguous().view(predict.shape[0], -1)
target = target.contiguous().view(target.shape[0], -1).float()
valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1).float()
num = torch.sum(torch.mul(predict, target) * valid_mask, dim=1) * 2 + self.smooth
den = torch.sum((predict.pow(self.p) + target.pow(self.p)) * valid_mask, dim=1) + self.smooth
loss = 1 - num / den
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
elif self.reduction == 'none':
return loss
else:
raise Exception('Unexpected reduction {}'.format(self.reduction))
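`BinaryDiceLoss.forward` computes, per sample, `1 - (2*sum(x*y*m) + s) / (sum((x^p + y^p)*m) + s)`, where `m` is the valid mask and `s` is the smoothing term. The NumPy sketch below mirrors that arithmetic so the effect of `smooth`, `p` and the mask can be checked by hand. It is a re-implementation for illustration, not the module itself.

```python
import numpy as np


def binary_dice_loss(predict, target, valid_mask, smooth=1.0, p=2, reduction="mean"):
    """NumPy mirror of BinaryDiceLoss.forward (illustrative sketch)."""
    n = predict.shape[0]
    pr = predict.reshape(n, -1).astype(np.float64)
    tg = target.reshape(n, -1).astype(np.float64)
    vm = valid_mask.reshape(n, -1).astype(np.float64)
    num = 2 * np.sum(pr * tg * vm, axis=1) + smooth
    den = np.sum((pr ** p + tg ** p) * vm, axis=1) + smooth
    loss = 1 - num / den  # one value per sample
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction == "none":
        return loss
    raise ValueError("Unexpected reduction {}".format(reduction))
```

With target `[1, 1, 1, 0]` and an all-zero prediction, the numerator is `smooth = 1` and the denominator is `3 + 1`, giving a loss of `0.75`. A perfect prediction gives `(6 + 1) / (6 + 1)`, so the loss is exactly `0`. Masking out every foreground pixel leaves `num = den = smooth`, which is also a loss of `0`.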
class DiceLoss(nn.Module):
"""Dice loss, need one hot encode input"""
def __init__(self, weight=None, aux=False, aux_weight=0.4, ignore_index=-1, **kwargs):
super(DiceLoss, self).__init__()
self.kwargs = kwargs
self.weight = weight
self.ignore_index = ignore_index
self.aux = aux
self.aux_weight = aux_weight
def _base_forward(self, predict, target, valid_mask):
dice = BinaryDiceLoss(**self.kwargs)
total_loss = 0
predict = F.softmax(predict, dim=1)
for i in range(target.shape[-1]):
if i != self.ignore_index:
dice_loss = dice(predict[:, i], target[..., i], valid_mask)
if self.weight is not None:
assert self.weight.shape[0] == target.shape[-1], \
'Expect weight shape [{}], get[{}]'.format(target.shape[-1], self.weight.shape[0])
dice_loss *= self.weight[i]
total_loss += dice_loss
return total_loss / target.shape[-1]
def _aux_forward(self, output, target, **kwargs):
# *preds, target = tuple(inputs)
valid_mask = (target != self.ignore_index).long()
target_one_hot = F.one_hot(torch.clamp_min(target, 0), num_classes=output[0].shape[1])
loss = self._base_forward(output[0], target_one_hot, valid_mask)
for i in range(1, len(output)):
aux_loss = self._base_forward(output[i], target_one_hot, valid_mask)
loss += self.aux_weight * aux_loss
return loss
def forward(self, output, target):
# preds, target = tuple(inputs)
# inputs = tuple(list(preds) + [target])
if self.aux:
return self._aux_forward(output, target)
else:
valid_mask = (target != self.ignore_index).long()
target_one_hot = F.one_hot(torch.clamp_min(target, 0), num_classes=output.shape[1])
return self._base_forward(output, target_one_hot, valid_mask)
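`DiceLoss` applies a softmax over the channel axis, builds a valid mask from `ignore_index`, one-hot encodes the clamped labels, and averages a binary Dice term over the classes. The NumPy sketch below follows the same steps. It takes the class count from the logits, so a class missing from the batch still gets a one-hot column; an implicit `F.one_hot` call would instead size the encoding by the largest label present. The function names are hypothetical.

```python
import numpy as np


def softmax(x, axis=1):
    """Numerically stable softmax along `axis`."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def multiclass_dice_loss(logits, target, ignore_index=-1, smooth=1.0, p=2):
    """logits: (N, C, ...) raw scores; target: (N, ...) integer labels (sketch)."""
    n, c = logits.shape[:2]
    prob = softmax(logits, axis=1).reshape(n, c, -1)   # (N, C, P)
    tgt = target.reshape(n, -1)                        # (N, P)
    valid = (tgt != ignore_index).astype(np.float64)   # 0 where ignored
    one_hot = np.eye(c)[np.clip(tgt, 0, None)]         # (N, P, C), C from logits
    total = 0.0
    for k in range(c):
        pk, tk = prob[:, k], one_hot[..., k]           # (N, P) each
        num = 2 * np.sum(pk * tk * valid, axis=1) + smooth
        den = np.sum((pk ** p + tk ** p) * valid, axis=1) + smooth
        total += np.mean(1 - num / den)                # batch mean per class
    return total / c
```

Because `p = 2` gives `x^2 + y^2 >= 2xy`, each per-class term lies in `[0, 1)`, and so does the average.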
def softmax_mse_loss(input_logits, target_logits, sigmoid=False):
"""Takes softmax on both sides and returns MSE loss
Note:
- Returns the sum over all examples. Divide by the batch size afterwards
if you want the mean.
- Sends gradients to inputs but not the targets.
"""
assert input_logits.size() == target_logits.size()
if sigmoid:
input_softmax = torch.sigmoid(input_logits)
target_softmax = torch.sigmoid(target_logits)
else:
input_softmax = F.softmax(input_logits, dim=1)
target_softmax = F.softmax(target_logits, dim=1)
mse_loss = (input_softmax-target_softmax)**2
return mse_loss
def entropy_loss(p, C=2):
# p N*C*W*H*D
y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1) / float(np.log(C))
ent = torch.mean(y1)
return ent
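`entropy_loss` divides the per-voxel Shannon entropy by `log(C)`, so a uniform prediction scores about 1 and a confident one-hot prediction scores about 0. The device-independent NumPy sketch below reproduces that normalisation, including the `1e-6` epsilon. `normalized_entropy` is a hypothetical name.

```python
import numpy as np


def normalized_entropy(p, num_classes=2, eps=1e-6):
    """Mean over voxels of -sum_c p*log(p + eps) / log(C), for p of shape (N, C, ...)."""
    ent = -np.sum(p * np.log(p + eps), axis=1) / np.log(num_classes)
    return float(np.mean(ent))
```

Minimising this term pushes predictions on unlabeled voxels towards one class, which is why it is often used as a regulariser in semi-supervised training.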
class BCELossBoud(nn.Module):
def __init__(self, num_classes, weight=None, ignore_index=None, **kwargs):
super(BCELossBoud, self).__init__()
self.kwargs = kwargs
self.weight = weight
self.ignore_index = ignore_index
self.num_classes = num_classes
self.criterion = nn.BCEWithLogitsLoss()
def weighted_BCE_cross_entropy(self, output, target, weights = None):
if weights is not None:
assert len(weights) == 2
output = torch.clamp(output, min=1e-3, max=1-1e-3)
bce = weights[1] * (target * torch.log(output)) + weights[0] * ((1-target) * torch.log((1-output)))
else:
output = torch.clamp(output, min=1e-3, max=1 - 1e-3)
bce = target * torch.log(output) + (1-target) * torch.log((1-output))
return torch.neg(torch.mean(bce))
def forward(self, predict, target):
target_one_hot = F.one_hot(torch.clamp_min(target, 0), num_classes=self.num_classes).permute(0, 4, 1, 2, 3)
predict = torch.softmax(predict, 1)
bs, category, depth, width, height = target_one_hot.shape
bce_loss = []
for i in range(predict.shape[1]):
pred_i = predict[:,i]
targ_i = target_one_hot[:,i]
tt = np.log(depth * width * height / (target_one_hot[:, i].cpu().data.numpy().sum()+1))
bce_i = self.weighted_BCE_cross_entropy(pred_i, targ_i, weights=[1, tt])
bce_loss.append(bce_i)
bce_loss = torch.stack(bce_loss)
total_loss = bce_loss.mean()
return total_loss
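`BCELossBoud` weights the positive term of each class's BCE by `tt = log(D*W*H / (count + 1))`, where `D*W*H` is the size of one volume and `count` is that class's voxel count summed over the whole batch. The negative term keeps weight 1. The NumPy sketch below reproduces both pieces so the weights can be inspected. The helper names are hypothetical.

```python
import numpy as np


def rarity_weights(labels, num_classes):
    """Per-class positive weight log(V / (count + 1)), as in BCELossBoud.forward.

    labels: (B, D, W, H) integer class map. V is one volume's voxel count,
    while counts are summed over the whole batch.
    """
    voxels = np.prod(labels.shape[1:])
    counts = np.bincount(np.clip(labels, 0, None).ravel(), minlength=num_classes)
    return np.log(voxels / (counts + 1))


def weighted_bce(prob, target, w_pos, w_neg=1.0, eps=1e-3):
    """Clamped, weighted binary cross-entropy, averaged over all elements."""
    prob = np.clip(prob, eps, 1 - eps)
    bce = w_pos * target * np.log(prob) + w_neg * (1 - target) * np.log(1 - prob)
    return -np.mean(bce)
```

In a `2 x 2 x 2` volume where class 1 occupies one voxel, the weights are `log(8 / 8) = 0` for class 0 and `log(8 / 2) = log 4` for class 1. A class that fills almost the entire volume therefore contributes almost nothing through its positive term. With batch size above 1, the counts grow with the batch while `V` does not, so the weights shrink.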
class CustomKLLoss(_Loss):
'''
KL_Loss = mean(mean^2) + mean(std^2) - mean(log(std^2)) - 1
where each mean is taken over all N elements (the image voxels)
'''
def __init__(self, *args, **kwargs):
super(CustomKLLoss, self).__init__()
def forward(self, mean, std):
return torch.mean(torch.mul(mean, mean)) + torch.mean(torch.mul(std, std)) - torch.mean(
torch.log(torch.mul(std, std))) - 1
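`CustomKLLoss` is closely related to the closed-form KL divergence between `N(mu, sigma^2)` and `N(0, 1)`, which is `0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1)` per element. The expression in `forward` omits the factor 0.5, so it equals twice the mean element-wise KL. The NumPy sketch below checks that relationship numerically.

```python
import numpy as np


def custom_kl(mean, std):
    """NumPy mirror of CustomKLLoss.forward."""
    return (np.mean(mean * mean) + np.mean(std * std)
            - np.mean(np.log(std * std)) - 1)


def gaussian_kl_elementwise(mean, std):
    """Closed-form KL(N(mean, std^2) || N(0, 1)) for each element."""
    return 0.5 * (mean ** 2 + std ** 2 - np.log(std ** 2) - 1)
```

A standard normal (`mean = 0`, `std = 1`) gives zero for both. For any inputs, `custom_kl(m, s)` equals `2 * gaussian_kl_elementwise(m, s).mean()`.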
def segmentation_loss(loss='CE', aux=False, **kwargs):
if loss == 'dice' or loss == 'DICE':
seg_loss = DiceLoss(aux=aux)
elif loss == 'crossentropy' or loss == 'CE':
seg_loss = MixSoftmaxCrossEntropyLoss(aux=aux)
elif loss == 'bce':
seg_loss = nn.BCELoss(reduction='mean')
elif loss == 'bcebound':
seg_loss = BCELossBoud(num_classes=kwargs['num_classes'])
else:
print('sorry, the loss you input is not supported yet')
sys.exit()
return seg_loss
# if __name__ == '__main__':
# from models import *
# criterion = segmentation_loss(loss='dice')
# # criterion = nn.CrossEntropyLoss()
#
# model = unet(1, 2)
# model.eval()
# input = torch.rand(3, 1, 128, 128)
# mask = torch.zeros(3, 128, 128).long()
#
# mask[:, 40:100, 30:60] = 1
# output = model(input)
#
# loss = criterion(output, mask)
# print(loss)
# # loss.requires_grad_(True)
# # loss.backward()
================================================
FILE: models/__init__.py
================================================
# 2d
from .networks_2d.xnet import XNet, XNet_1_1_m, XNet_1_2_m, XNet_2_1_m, XNet_3_2_m, XNet_2_3_m, XNet_3_3_m, XNet_sb
from .networks_2d.unet import unet, r2_unet, attention_unet
from .networks_2d.unet_plusplus import unet_plusplus
from .networks_2d.hrnet import hrnet18, hrnet32, hrnet48, hrnet64
from .networks_2d.swinunet import swinunet
from .networks_2d.unet_urpc import unet_urpc
from .networks_2d.unet_cct import unet_cct
from .networks_2d.resunet import res_unet
from .networks_2d.resunet_plusplus import res_unet_plusplus
from .networks_2d.u2net import u2net, u2net_small
from .networks_2d.unet_3plus import unet_3plus, unet_3plus_ds, unet_3plus_ds_cgm
from .networks_2d.wavesnet import wsegnet_vgg16_bn
from .networks_2d.mwcnn import mwcnn
from .networks_2d.aerial_lanenet import Aerial_LaneNet
from .networks_2d.wds import WDS
# 3d
from .networks_3d.unet3d import unet3d, unet3d_min
from .networks_3d.vnet import vnet
from .networks_3d.res_unet3d import res_unet3d
from .networks_3d.transbts import transbts
from .networks_3d.cotr import cotr
from .networks_3d.dmfnet import dmfnet
from .networks_3d.conresnet import conresnet
from .networks_3d.espnet3d import espnet3d
from .networks_3d.unetr import unertr
from .networks_3d.unet3d_urpc import unet3d_urpc
from .networks_3d.unet3d_cct import unet3d_cct, unet3d_cct_min
from .networks_3d.unet3d_dtc import unet3d_dtc
from .networks_3d.xnet3d import xnet3d
from .networks_3d.vnet_cct import vnet_cct
from .networks_3d.vnet_dtc import vnet_dtc
================================================
FILE: models/getnetwork.py
================================================
import sys
from models import *
import torch.nn as nn
def get_network(network, in_channels, num_classes, **kwargs):
# 2d networks
if network == 'xnet':
net = XNet(in_channels, num_classes)
elif network == 'xnet_sb':
net = XNet_sb(in_channels, num_classes)
elif network == 'xnet_1_1_m':
net = XNet_1_1_m(in_channels, num_classes)
elif network == 'xnet_1_2_m':
net = XNet_1_2_m(in_channels, num_classes)
elif network == 'xnet_2_1_m':
net = XNet_2_1_m(in_channels, num_classes)
elif network == 'xnet_3_2_m':
net = XNet_3_2_m(in_channels, num_classes)
elif network == 'xnet_2_3_m':
net = XNet_2_3_m(in_channels, num_classes)
elif network == 'xnet_3_3_m':
net = XNet_3_3_m(in_channels, num_classes)
elif network == 'unet':
net = unet(in_channels, num_classes)
elif network == 'unet_plusplus' or network == 'unet++':
net = unet_plusplus(in_channels, num_classes)
elif network == 'r2unet':
net = r2_unet(in_channels, num_classes)
elif network == 'attunet':
net = attention_unet(in_channels, num_classes)
elif network == 'hrnet18':
net = hrnet18(in_channels, num_classes)
elif network == 'hrnet48':
net = hrnet48(in_channels, num_classes)
elif network == 'resunet':
net = res_unet(in_channels, num_classes)
elif network == 'resunet++':
net = res_unet_plusplus(in_channels, num_classes)
elif network == 'u2net':
net = u2net(in_channels, num_classes)
elif network == 'u2net_s':
net = u2net_small(in_channels, num_classes)
elif network == 'unet3+':
net = unet_3plus(in_channels, num_classes)
elif network == 'unet3+_ds':
net = unet_3plus_ds(in_channels, num_classes)
elif network == 'unet3+_ds_cgm':
net = unet_3plus_ds_cgm(in_channels, num_classes)
elif network == 'swinunet':
net = swinunet(num_classes, 224) # img_size = 224
elif network == 'unet_urpc':
net = unet_urpc(in_channels, num_classes)
elif network == 'unet_cct':
net = unet_cct(in_channels, num_classes)
elif network == 'wavesnet':
net = wsegnet_vgg16_bn(in_channels, num_classes)
elif network == 'mwcnn':
net = mwcnn(in_channels, num_classes)
elif network == 'alnet':
net = Aerial_LaneNet(in_channels, num_classes)
elif network == 'wds':
net = WDS(in_channels, num_classes)
# 3d networks
elif network == 'xnet3d':
net = xnet3d(in_channels, num_classes)
elif network == 'unet3d':
net = unet3d(in_channels, num_classes)
elif network == 'unet3d_min':
net = unet3d_min(in_channels, num_classes)
elif network == 'unet3d_urpc':
net = unet3d_urpc(in_channels, num_classes)
elif network == 'unet3d_cct':
net = unet3d_cct(in_channels, num_classes)
elif network == 'unet3d_cct_min':
net = unet3d_cct_min(in_channels, num_classes)
elif network == 'unet3d_dtc':
net = unet3d_dtc(in_channels, num_classes)
elif network == 'vnet':
net = vnet(in_channels, num_classes)
elif network == 'vnet_cct':
net = vnet_cct(in_channels, num_classes)
elif network == 'vnet_dtc':
net = vnet_dtc(in_channels, num_classes)
elif network == 'resunet3d':
net = res_unet3d(in_channels, num_classes)
elif network == 'conresnet':
net = conresnet(in_channels, num_classes, img_shape=kwargs['img_shape'])
elif network == 'espnet3d':
net = espnet3d(in_channels, num_classes)
elif network == 'dmfnet':
net = dmfnet(in_channels, num_classes)
elif network == 'transbts':
net = transbts(in_channels, num_classes, img_shape=kwargs['img_shape'])
elif network == 'cotr':
net = cotr(in_channels, num_classes)
elif network == 'unertr':
net = unertr(in_channels, num_classes, img_shape=kwargs['img_shape'])
else:
print('the network you have entered is not supported yet')
sys.exit()
return net
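`get_network` dispatches on the name through a long `if/elif` chain. Only a few 3D models (`conresnet`, `transbts`, `unertr`) need the extra `img_shape` keyword. The sketch below shows the same dispatch as a dictionary registry that raises `ValueError` for unknown names instead of calling `sys.exit()`. `ToyNet`, `NETWORKS`, `ALIASES` and `build_network` are hypothetical stand-ins, not part of the repository.

```python
from typing import Callable, Dict


class ToyNet:
    """Hypothetical stand-in for a real model class such as XNet or unet3d."""

    def __init__(self, in_channels, num_classes, **kwargs):
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.kwargs = kwargs


# Each entry takes (in_channels, num_classes, **kwargs) and forwards only what it needs.
NETWORKS: Dict[str, Callable] = {
    "xnet": lambda c, k, **kw: ToyNet(c, k),
    "unet_plusplus": lambda c, k, **kw: ToyNet(c, k),
    "conresnet": lambda c, k, **kw: ToyNet(c, k, img_shape=kw["img_shape"]),
}
ALIASES = {"unet++": "unet_plusplus"}


def build_network(network, in_channels, num_classes, **kwargs):
    name = ALIASES.get(network, network)
    try:
        builder = NETWORKS[name]
    except KeyError:
        raise ValueError("unsupported network: {!r}".format(network)) from None
    return builder(in_channels, num_classes, **kwargs)
```

A registry keeps each name next to its constructor. An unsupported name then raises an exception the caller can catch, rather than terminating the process.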
================================================
FILE: models/networks_2d/__init__.py
================================================
================================================
FILE: models/networks_2d/aerial_lanenet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import functools
from torch.distributions.uniform import Uniform
import numpy as np
class basic_block(nn.Module):
def __init__(self, ch_in, ch_out):
super(basic_block, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False),
nn.ReLU(inplace=True))
def forward(self, x):
x = self.block(x)
return x
class Aerial_LaneNet(nn.Module):
def __init__(self, in_channels, num_classes):
super(Aerial_LaneNet, self).__init__()
l1, l2, l3, l4, l5 = 64, 128, 256, 512, 512
dropout = 0.2
# e1
self.conv1_1 = basic_block(in_channels, l1)
self.conv1_2 = basic_block(l1, l1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# e2
self.conv2_1 = basic_block(l1+3, l2)
self.conv2_2 = basic_block(l2, l2)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# e3
self.conv3_1 = basic_block(l2+3, l3)
self.conv3_2 = basic_block(l3, l3)
self.conv3_3 = basic_block(l3, l3)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# e4
self.conv4_1 = basic_block(l3+3, l4)
self.conv4_2 = basic_block(l4, l4)
self.conv4_3 = basic_block(l4, l4)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# e5
self.conv5_1 = basic_block(l4+3, l5)
self.conv5_2 = basic_block(l5, l5)
self.conv5_3 = basic_block(l5, l5)
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# e6
self.conv6_1 = basic_block(l5, 4096)
self.drop6_1 = nn.Dropout2d(dropout)
self.conv6_2 = basic_block(4096, 4096)
self.drop6_2 = nn.Dropout2d(dropout)
self.conv6_3 = nn.ConvTranspose2d(4096, l5, kernel_size=4, stride=2, padding=1, bias=False)
# d4
self.conv4_4 = basic_block(2*l5, l5)
self.drop4_4 = nn.Dropout2d(dropout)
self.conv4_5 = nn.ConvTranspose2d(l5, l3, kernel_size=4, stride=2, padding=1, bias=False)
# d3
self.conv3_4 = basic_block(2*l3, l3)
self.drop3_4 = nn.Dropout2d(dropout)
self.conv3_5 = nn.ConvTranspose2d(l3, l2, kernel_size=4, stride=2, padding=1, bias=False)
# d2
self.conv2_4 = basic_block(2*l2, l2)
self.drop2_4 = nn.Dropout2d(dropout)
self.conv2_5 = nn.ConvTranspose2d(l2, l1, kernel_size=4, stride=2, padding=1, bias=False)
# d1
self.conv1_3 = basic_block(2*l1, l1)
self.drop1_3 = nn.Dropout2d(dropout)
self.conv1_4 = nn.ConvTranspose2d(l1, num_classes, kernel_size=4, stride=2, padding=1, bias=False)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, x_wavelet_1, x_wavelet_2, x_wavelet_3, x_wavelet_4):
x1 = self.conv1_1(x)
x1 = self.conv1_2(x1)
x1 = self.pool1(x1)
x2 = torch.cat((x1, x_wavelet_1), dim=1)
x2 = self.conv2_1(x2)
x2 = self.conv2_2(x2)
x2 = self.pool2(x2)
x3 = torch.cat((x2, x_wavelet_2), dim=1)
x3 = self.conv3_1(x3)
x3 = self.conv3_2(x3)
x3 = self.conv3_3(x3)
x3 = self.pool3(x3)
x4 = torch.cat((x3, x_wavelet_3), dim=1)
x4 = self.conv4_1(x4)
x4 = self.conv4_2(x4)
x4 = self.conv4_3(x4)
x4 = self.pool4(x4)
x5 = torch.cat((x4, x_wavelet_4), dim=1)
x5 = self.conv5_1(x5)
x5 = self.conv5_2(x5)
x5 = self.conv5_3(x5)
x5 = self.pool5(x5)
x6 = self.conv6_1(x5)
x6 = self.drop6_1(x6)
x6 = self.conv6_2(x6)
x6 = self.drop6_2(x6)
x6 = self.conv6_3(x6)
x5 = torch.cat((x6, x4), dim=1)
x5 = self.conv4_4(x5)
x5 = self.drop4_4(x5)
x5 = self.conv4_5(x5)
x4 = torch.cat((x5, x3), dim=1)
x4 = self.conv3_4(x4)
x4 = self.drop3_4(x4)
x4 = self.conv3_5(x4)
x3 = torch.cat((x4, x2), dim=1)
x3 = self.conv2_4(x3)
x3 = self.drop2_4(x3)
x3 = self.conv2_5(x3)
x2 = torch.cat((x3, x1), dim=1)
x2 = self.conv1_3(x2)
x2 = self.drop1_3(x2)
x2 = self.conv1_4(x2)
return x2
# if __name__ == '__main__':
# from loss.loss_function import segmentation_loss
# criterion = segmentation_loss('dice', False)
# mask = torch.ones(2, 128, 128).long()
# model = Aerial_LaneNet(1, 5)
# model.train()
# input1 = torch.rand(2, 1, 128, 128)
# input2 = torch.rand(2, 3, 64, 64)
# input3 = torch.rand(2, 3, 32, 32)
# input4 = torch.rand(2, 3, 16, 16)
# input5 = torch.rand(2, 3, 8, 8)
#
# y = model(input1, input2, input3, input4, input5)
# loss_train = criterion(y, mask)
# loss_train.backward()
# # print(output)
# print(y.data.cpu().numpy().shape)
# print(loss_train)
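`Aerial_LaneNet.forward` concatenates each 3-channel wavelet input with the encoder output after `pool1` through `pool4`. These are `MaxPool2d(2, stride=2, ceil_mode=True)`, so each halves the size, rounding up. The decoder doubles the size with `ConvTranspose2d(kernel_size=4, stride=2, padding=1)` before every skip concatenation. The pure-Python sketch below traces one spatial dimension to show which wavelet sizes the model expects and whether the decoder concatenations line up. `aerial_lanenet_sizes` is a hypothetical helper.

```python
import math


def aerial_lanenet_sizes(size):
    """Trace one spatial dimension through Aerial_LaneNet (illustrative sketch).

    Returns (wavelet_sizes, decoder_ok, out_size):
    wavelet_sizes: sizes expected for x_wavelet_1 .. x_wavelet_4
    decoder_ok: True if every decoder torch.cat sees matching sizes
    out_size: spatial size of the logits when decoder_ok is True, else None
    """
    enc = []
    s = size
    for _ in range(5):
        s = math.ceil(s / 2)      # MaxPool2d(2, stride=2, ceil_mode=True)
        enc.append(s)             # sizes after pool1 .. pool5
    up = 2 * enc[4]               # conv6_3: (in - 1) * 2 - 2 + 4 = 2 * in
    for skip in (enc[3], enc[2], enc[1], enc[0]):
        if up != skip:
            return enc[:4], False, None
        up *= 2                   # conv4_5, conv3_5, conv2_5, conv1_4
    return enc[:4], True, up
```

For a 128-pixel input, this gives wavelet sizes `[64, 32, 16, 8]`, matching the commented-out test above, and a 128-pixel output. For 100 pixels, the encoder sizes are `50, 25, 13, 7, 4`. The first decoder stage then produces 8 where the skip connection has 7, so `torch.cat` would fail. A multiple of 32 is a safe choice for the input size.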
================================================
FILE: models/networks_2d/hrnet.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import logging
import functools
import numpy as np
import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
from torch.nn import init
try:
from .sync_bn.inplace_abn.bn import InPlaceABNSync
BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
except ImportError:
InPlaceABNSync = nn.BatchNorm2d
BatchNorm2d = nn.BatchNorm2d
BN_MOMENTUM = 0.01
logger = logging.getLogger(__name__)
model_urls = {
'hrnetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/hrnetv2_w48-imagenet.pth',
}
import sys
try:
from urllib import urlretrieve
except ImportError:
from urllib.request import urlretrieve
def load_url(url, model_dir='./pretrained', map_location=None):
if not os.path.exists(model_dir):
os.makedirs(model_dir)
filename = url.split('/')[-1]
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
urlretrieve(url, cached_file)
return torch.load(cached_file, map_location=map_location)
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=False)
self.conv2 = conv3x3(planes, planes)
self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out = out + residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
bias=False)
self.bn3 = BatchNorm2d(planes * self.expansion,
momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=False)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out = out + residual
out = self.relu(out)
return out
class HighResolutionModule(nn.Module):
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
num_channels, fuse_method, multi_scale_output=True):
super(HighResolutionModule, self).__init__()
self._check_branches(
num_branches, blocks, num_blocks, num_inchannels, num_channels)
self.num_inchannels = num_inchannels
self.fuse_method = fuse_method
self.num_branches = num_branches
self.multi_scale_output = multi_scale_output
self.branches = self._make_branches(
num_branches, blocks, num_blocks, num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(inplace=False)
def _check_branches(self, num_branches, blocks, num_blocks,
num_inchannels, num_channels):
if num_branches != len(num_blocks):
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
num_branches, len(num_blocks))
logger.error(error_msg)
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
num_branches, len(num_channels))
logger.error(error_msg)
raise ValueError(error_msg)
if num_branches != len(num_inchannels):
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
num_branches, len(num_inchannels))
logger.error(error_msg)
raise ValueError(error_msg)
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
stride=1):
downsample = None
if stride != 1 or \
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.num_inchannels[branch_index],
num_channels[branch_index] * block.expansion,
kernel_size=1, stride=stride, bias=False),
BatchNorm2d(num_channels[branch_index] * block.expansion,
momentum=BN_MOMENTUM),
)
layers = []
layers.append(block(self.num_inchannels[branch_index],
num_channels[branch_index], stride, downsample))
self.num_inchannels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
layers.append(block(self.num_inchannels[branch_index],
num_channels[branch_index]))
return nn.Sequential(*layers)
def _make_branches(self, num_branches, block, num_blocks, num_channels):
branches = []
for i in range(num_branches):
branches.append(
self._make_one_branch(i, block, num_blocks, num_channels))
return nn.ModuleList(branches)
def _make_fuse_layers(self):
if self.num_branches == 1:
return None
num_branches = self.num_branches
num_inchannels = self.num_inchannels
fuse_layers = []
for i in range(num_branches if self.multi_scale_output else 1):
fuse_layer = []
for j in range(num_branches):
if j > i:
fuse_layer.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_inchannels[i],
1,
1,
0,
bias=False),
BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
elif j == i:
fuse_layer.append(None)
else:
conv3x3s = []
for k in range(i - j):
if k == i - j - 1:
num_outchannels_conv3x3 = num_inchannels[i]
conv3x3s.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_outchannels_conv3x3,
3, 2, 1, bias=False),
BatchNorm2d(num_outchannels_conv3x3,
momentum=BN_MOMENTUM)))
else:
num_outchannels_conv3x3 = num_inchannels[j]
conv3x3s.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_outchannels_conv3x3,
3, 2, 1, bias=False),
BatchNorm2d(num_outchannels_conv3x3,
momentum=BN_MOMENTUM),
nn.ReLU(inplace=False)))
fuse_layer.append(nn.Sequential(*conv3x3s))
fuse_layers.append(nn.ModuleList(fuse_layer))
return nn.ModuleList(fuse_layers)
def get_num_inchannels(self):
return self.num_inchannels
def forward(self, x):
if self.num_branches == 1:
return [self.branches[0](x[0])]
for i in range(self.num_branches):
x[i] = self.branches[i](x[i])
x_fuse = []
for i in range(len(self.fuse_layers)):
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
for j in range(1, self.num_branches):
if i == j:
y = y + x[j]
elif j > i:
width_output = x[i].shape[-1]
height_output = x[i].shape[-2]
y = y + F.interpolate(
self.fuse_layers[i][j](x[j]),
size=[height_output, width_output],
mode='bilinear', align_corners=False)
else:
y = y + self.fuse_layers[i][j](x[j])
x_fuse.append(self.relu(y))
return x_fuse
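`HighResolutionModule._make_fuse_layers` builds one fusion path for every pair of output branch `i` and input branch `j`. A lower-resolution input (`j > i`) passes through a 1x1 conv and BatchNorm, and `forward` then bilinearly resizes it to branch `i`'s size. A higher-resolution input (`j < i`) passes through `i - j` stride-2 3x3 convs. The same branch is added unchanged. The pure-Python sketch below lists that plan. It assumes the usual HRNet layout in which each branch has half the resolution of the previous one, so the resize factor is `2 ** (j - i)`. `fuse_plan` is a hypothetical helper.

```python
def fuse_plan(num_branches, multi_scale_output=True):
    """Map (output branch i, input branch j) to the fusion op and its scale step."""
    plan = {}
    outputs = num_branches if multi_scale_output else 1
    for i in range(outputs):
        for j in range(num_branches):
            if j > i:
                plan[(i, j)] = ("conv1x1+bn, then upsample", 2 ** (j - i))
            elif j == i:
                plan[(i, j)] = ("identity", 1)
            else:
                plan[(i, j)] = ("stride-2 conv3x3 chain", i - j)
    return plan
```

With `multi_scale_output=False`, only the highest-resolution output (`i = 0`) is fused, which matches the `range(num_branches if self.multi_scale_output else 1)` loop.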
blocks_dict = {
'BASIC': BasicBlock,
'BOTTLENECK': Bottleneck
}
class HighResolutionNet(nn.Module):
def __init__(self, in_channels, extra, num_classes, **kwargs):
super(HighResolutionNet, self).__init__()
# stem net
self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1,
bias=False)
self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1,
bias=False)
self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=False)
self.stage1_cfg = extra['STAGE1']
num_channels = self.stage1_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage1_cfg['BLOCK']]
num_blocks = self.stage1_cfg['NUM_BLOCKS']
self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
stage1_out_channel = block.expansion * num_channels
self.stage2_cfg = extra['STAGE2']
num_channels = self.stage2_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage2_cfg['BLOCK']]
num_channels = [
num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition1 = self._make_transition_layer(
[stage1_out_channel], num_channels)
self.stage2, pre_stage_channels = self._make_stage(
self.stage2_cfg, num_channels)
self.stage3_cfg = extra['STAGE3']
num_channels = self.stage3_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage3_cfg['BLOCK']]
num_channels = [
num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition2 = self._make_transition_layer(
pre_stage_channels, num_channels)
self.stage3, pre_stage_channels = self._make_stage(
self.stage3_cfg, num_channels)
self.stage4_cfg = extra['STAGE4']
num_channels = self.stage4_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage4_cfg['BLOCK']]
num_channels = [
num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition3 = self._make_transition_layer(
pre_stage_channels, num_channels)
self.stage4, pre_stage_channels = self._make_stage(
self.stage4_cfg, num_channels, multi_scale_output=True)
last_inp_channels = int(np.sum(pre_stage_channels))
self.last_layer = nn.Sequential(
nn.Conv2d(
in_channels=last_inp_channels,
out_channels=last_inp_channels,
kernel_size=1,
stride=1,
padding=0),
BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM),
nn.ReLU(inplace=False),
nn.Conv2d(
in_channels=last_inp_channels,
out_channels=num_classes,
kernel_size=extra['FINAL_CONV_KERNEL'],
stride=1,
padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0)
)
def _make_transition_layer(
self, num_channels_pre_layer, num_channels_cur_layer):
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
transition_layers = []
for i in range(num_branches_cur):
if i < num_branches_pre:
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
transition_layers.append(nn.Sequential(
nn.Conv2d(num_channels_pre_layer[i],
num_channels_cur_layer[i],
3,
1,
1,
bias=False),
BatchNorm2d(
num_channels_cur_layer[i], momentum=BN_MOMENTUM),
nn.ReLU(inplace=False)))
else:
transition_layers.append(None)
else:
conv3x3s = []
for j in range(i + 1 - num_branches_pre):
inchannels = num_channels_pre_layer[-1]
outchannels = num_channels_cur_layer[i] \
if j == i - num_branches_pre else inchannels
conv3x3s.append(nn.Sequential(
nn.Conv2d(
inchannels, outchannels, 3, 2, 1, bias=False),
BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
nn.ReLU(inplace=False)))
transition_layers.append(nn.Sequential(*conv3x3s))
return nn.ModuleList(transition_layers)
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
)
layers = []
layers.append(block(inplanes, planes, stride, downsample))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(inplanes, planes))
return nn.Sequential(*layers)
def _make_stage(self, layer_config, num_inchannels,
multi_scale_output=True):
num_modules = layer_config['NUM_MODULES']
num_branches = layer_config['NUM_BRANCHES']
num_blocks = layer_config['NUM_BLOCKS']
num_channels = layer_config['NUM_CHANNELS']
block = blocks_dict[layer_config['BLOCK']]
fuse_method = layer_config['FUSE_METHOD']
modules = []
for i in range(num_modules):
# multi_scale_output is only applied to the last module
if not multi_scale_output and i == num_modules - 1:
reset_multi_scale_output = False
else:
reset_multi_scale_output = True
modules.append(
HighResolutionModule(num_branches,
block,
num_blocks,
num_inchannels,
num_channels,
fuse_method,
reset_multi_scale_output)
)
num_inchannels = modules[-1].get_num_inchannels()
return nn.Sequential(*modules), num_inchannels
def forward(self, x):
        size = x.shape[2:]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.layer1(x)
x_list = []
for i in range(self.stage2_cfg['NUM_BRANCHES']):
if self.transition1[i] is not None:
x_list.append(self.transition1[i](x))
else:
x_list.append(x)
y_list = self.stage2(x_list)
x_list = []
for i in range(self.stage3_cfg['NUM_BRANCHES']):
if self.transition2[i] is not None:
if i < self.stage2_cfg['NUM_BRANCHES']:
x_list.append(self.transition2[i](y_list[i]))
else:
x_list.append(self.transition2[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage3(x_list)
x_list = []
for i in range(self.stage4_cfg['NUM_BRANCHES']):
if self.transition3[i] is not None:
if i < self.stage3_cfg['NUM_BRANCHES']:
x_list.append(self.transition3[i](y_list[i]))
else:
x_list.append(self.transition3[i](y_list[-1]))
else:
x_list.append(y_list[i])
x = self.stage4(x_list)
# Upsampling
x0_h, x0_w = x[0].size(2), x[0].size(3)
x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
x = torch.cat([x[0], x1, x2, x3], 1)
x = self.last_layer(x)
x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
# outputs = []
# outputs.append(x)
# return outputs
return x
    def init_weights(self, pretrained=''):
        logger.info('=> init weights from normal distribution')
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                # The network is built with BatchNorm2d layers, not InPlaceABNSync.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
if os.path.isfile(pretrained):
pretrained_dict = torch.load(pretrained)
logger.info('=> loading pretrained model {}'.format(pretrained))
model_dict = self.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items()
if k in model_dict.keys()}
for k, _ in pretrained_dict.items():
logger.info(
'=> loading {} pretrained model {}'.format(k, pretrained))
model_dict.update(pretrained_dict)
self.load_state_dict(model_dict)
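# Illustrative sketch, not part of the original HRNet code: a pure-Python
# mirror of the branching rule in HighResolutionNet._make_transition_layer.
# The helper name _transition_plan is hypothetical. For every branch of the
# next stage it returns one of three things:
#   - None, meaning the branch passes through unchanged;
#   - a single stride-1 3x3 conv that only changes the channel count;
#   - a chain of stride-2 3x3 convs that all start from the lowest-resolution
#     branch of the previous stage.
# Each conv is reported as (in_channels, out_channels, stride). The usage
# example assumes stage 3 outputs its NUM_CHANNELS unchanged (a BasicBlock
# expansion of 1, as in the reference HRNet).
def _transition_plan(num_channels_pre_layer, num_channels_cur_layer):
    num_branches_pre = len(num_channels_pre_layer)
    plan = []
    for i, out_ch in enumerate(num_channels_cur_layer):
        if i < num_branches_pre:
            in_ch = num_channels_pre_layer[i]
            # Existing branch: only re-project when the width changes.
            plan.append(None if in_ch == out_ch else [(in_ch, out_ch, 1)])
        else:
            # New branch: (i + 1 - num_branches_pre) stride-2 convs; only the
            # last one switches to the new branch's width.
            in_ch = num_channels_pre_layer[-1]
            steps = []
            for j in range(i + 1 - num_branches_pre):
                conv_out = out_ch if j == i - num_branches_pre else in_ch
                steps.append((in_ch, conv_out, 2))
            plan.append(steps)
    return plan

# Usage, with the stage-3 -> stage-4 widths of extra_18:
#   _transition_plan([18, 36, 72], [18, 36, 72, 144])
#   -> [None, None, None, [(72, 144, 2)]]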
# class HRNet(nn.Module):
# def __init__(self, in_channels, extra, num_classes, **kwargs):
# super(HRNet, self).__init__()
# self.branch1 = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra)
# self.branch2 = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra)
#
# def forward(self, data, step=1):
# if not self.training:
# pred1 = self.branch1(data)
# return pred1
#
# if step == 1:
# return self.branch1(data)
# elif step == 2:
# return self.branch2(data)
extra_18 = {
'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},
'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (18, 36), 'FUSE_METHOD': 'SUM'},
'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (18, 36, 72), 'FUSE_METHOD': 'SUM'},
'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (18, 36, 72, 144), 'FUSE_METHOD': 'SUM'},
'FINAL_CONV_KERNEL': 1
}
extra_32 = {
'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},
'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (32, 64), 'FUSE_METHOD': 'SUM'},
'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (32, 64, 128), 'FUSE_METHOD': 'SUM'},
'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (32, 64, 128, 256), 'FUSE_METHOD': 'SUM'},
'FINAL_CONV_KERNEL': 1
}
extra_48 = {
'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},
'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (48, 96), 'FUSE_METHOD': 'SUM'},
'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (48, 96, 192), 'FUSE_METHOD': 'SUM'},
'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (48, 96, 192, 384), 'FUSE_METHOD': 'SUM'},
'FINAL_CONV_KERNEL': 1
}
extra_64 = {
'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},
'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (64, 128), 'FUSE_METHOD': 'SUM'},
'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (64, 128, 256), 'FUSE_METHOD': 'SUM'},
'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (64, 128, 256, 512), 'FUSE_METHOD': 'SUM'},
'FINAL_CONV_KERNEL': 1
}
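# Illustrative sketch, not part of the original code. HighResolutionNet.forward
# upsamples stage-4 branches 1-3 to the resolution of branch 0 and
# concatenates all four along dim 1. As a result, self.last_layer receives
# sum(STAGE4 NUM_CHANNELS) channels.
# Every new branch is created by stride-2 convs, so branch k is nominally
# 2**k times smaller per side than branch 0. F.interpolate resizes to the
# exact branch-0 size in any case.
# The helper name _fusion_summary is hypothetical.
def _fusion_summary(stage4_channels):
    total = sum(stage4_channels)
    upsample_factors = [2 ** k for k in range(len(stage4_channels))]
    return total, upsample_factors

# Usage (STAGE4 widths copied from the configs above):
#   _fusion_summary((18, 36, 72, 144))   -> (270, [1, 2, 4, 8])
#   _fusion_summary((32, 64, 128, 256))  -> (480, [1, 2, 4, 8])
#   _fusion_summary((48, 96, 192, 384))  -> (720, [1, 2, 4, 8])
#   _fusion_summary((64, 128, 256, 512)) -> (960, [1, 2, 4, 8])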
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
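# Illustrative sketch, not part of the original code. init_weights chooses an
# initializer by substring-matching the module's class *name*, not by
# checking its type. _init_branch is a hypothetical pure-Python mirror of
# that dispatch. It is useful for checking which modules a given class name
# would touch. Note that the 'kaiming' branch ignores `gain`.
def _init_branch(classname, has_weight):
    if has_weight and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
        return 'conv_or_linear'
    elif classname.find('BatchNorm2d') != -1:
        return 'batchnorm2d'
    return 'untouched'

# Usage ('ConvBlock' stands for a hypothetical container module):
#   _init_branch('ConvTranspose2d', True)  -> 'conv_or_linear'  (substring match)
#   _init_branch('BatchNorm2d', True)      -> 'batchnorm2d'
#   _init_branch('BatchNorm1d', True)      -> 'untouched'
#   _init_branch('ConvBlock', False)       -> 'untouched'  (no .weight attribute)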
# def hrnet18(in_channels, num_classes):
# model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_18)
# return model
#
# def hrnet32(in_channels, num_classes):
# model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_32)
# return model
#
# def hrnet48(in_channels, num_classes):
# model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_48)
# return model
#
# def hrnet64(in_channels, num_classes):
# model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_64)
# return model
def hrnet18(in_channels, num_classes):
model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_18)
init_weights(model, 'kaiming')
return model
def hrnet32(in_channels, num_classes):
model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_32)
init_weights(model, 'kaiming')
return model
def hrnet48(in_channels, num_classes):
model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_48)
init_weights(model, 'kaiming')
return model
def hrnet64(in_channels, num_classes):
model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_64)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = hrnet48(1,10)
# total = sum([param.nelement() for param in model.parameters()])
# from thop import profile,clever_format
#
# input = torch.randn(1, 1, 128, 128)
# flops, params = profile(model, inputs=(input, ))
# macs, params = clever_format([flops, params], "%.3f")
# print(macs)
# print(params)
# print(total)
# model.eval()
# input = torch.rand(1,1,256,256)
# output = model(input)
# output = output[0].data.cpu().numpy()
# print(output)
# print(output.shape)
================================================
FILE: models/networks_2d/mwcnn.py
================================================
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
def default_conv(in_channels, out_channels, kernel_size, bias=True, dilation=1):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size // 2) + dilation - 1, bias=bias, dilation=dilation)
def default_conv1(in_channels, out_channels, kernel_size, bias=True, groups=3):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size // 2), bias=bias, groups=groups)
# def shuffle_channel()
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups,
channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
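# Illustrative sketch, not part of the original code. channel_shuffle works
# in three steps:
#   1. reshape the channel axis to (groups, C // groups);
#   2. swap those two axes;
#   3. flatten the channels again.
# _shuffled_order is a hypothetical helper that applies the same permutation
# to a list of channel indices, so the resulting order can be inspected
# without torch.
def _shuffled_order(num_channels, groups):
    per_group = num_channels // groups
    # After the transpose, flattened position k * groups + g holds the
    # original channel g * per_group + k.
    return [g * per_group + k for k in range(per_group) for g in range(groups)]

# Usage:
#   _shuffled_order(6, 2) -> [0, 3, 1, 4, 2, 5]
#   _shuffled_order(6, 3) -> [0, 2, 4, 1, 3, 5]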
def pixel_down_shuffle(x, downscale_factor):
    batchsize, num_channels, height, width = x.size()
    out_height = height // downscale_factor
    out_width = width // downscale_factor
    input_view = x.contiguous().view(batchsize, num_channels, out_height, downscale_factor, out_width,
                                     downscale_factor)
    num_channels *= downscale_factor ** 2
    unshuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()
    return unshuffle_out.view(batchsize, num_channels, out_height, out_width)
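# Illustrative sketch, not part of the original code. pixel_down_shuffle moves
# each r x r spatial block into r*r channels. For input channel c, output
# channel c*r*r + i*r + j at (h, w) holds input pixel (h*r + i, w*r + j).
# This should match the ordering of torch.nn.functional.pixel_unshuffle in
# recent PyTorch releases.
# _pixel_down_shuffle_1ch is a hypothetical single-channel, nested-list
# version of that mapping.
def _pixel_down_shuffle_1ch(img, r):
    h, w = len(img), len(img[0])
    out = []
    for i in range(r):
        for j in range(r):
            # Output channel i*r + j collects the (i, j) offset of every block.
            out.append([[img[hh * r + i][ww * r + j] for ww in range(w // r)]
                        for hh in range(h // r)])
    return out

# Usage:
#   img = [[4 * row + col for col in range(4)] for row in range(4)]
#   _pixel_down_shuffle_1ch(img, 2)
#   -> [[[0, 2], [8, 10]], [[1, 3], [9, 11]], [[4, 6], [12, 14]], [[5, 7], [13, 15]]]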
def sp_init(x):
x01 = x[:, :, 0::2, :]
x02 = x[:, :, 1::2, :]
x_LL = x01[:, :, :, 0::2]
x_HL = x02[:, :, :, 0::2]
x_LH = x01[:, :, :, 1::2]
x_HH = x02[:, :, :, 1::2]
return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
def dwt_init(x):
x01 = x[:, :, 0::2, :] / 2
x02 = x[:, :, 1::2, :] / 2
x1 = x01[:, :, :, 0::2]
x2 = x02[:, :, :, 0::2]
x3 = x01[:, :, :, 1::2]
x4 = x02[:, :, :, 1::2]
x_LL = x1 + x2 + x3 + x4
x_HL = -x1 - x2 + x3 + x4
x_LH = -x1 + x2 - x3 + x4
x_HH = x1 - x2 - x3 + x4
return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
def iwt_init(x):
r = 2
in_batch, in_channel, in_height, in_width = x.size()
# print([in_batch, in_channel, in_height, in_width])
out_batch, out_channel, out_height, out_width = in_batch, int(
in_channel / (r ** 2)), r * in_height, r * in_width
x1 = x[:, 0:out_channel, :, :] / 2
x2 = x[:, out_channel:out_channel * 2, :, :] / 2
x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
    # Allocate on the input's device. The previous hard-coded .cuda() failed
    # on CPU-only machines; float32 keeps the old .float() behaviour.
    h = torch.zeros([out_batch, out_channel, out_height, out_width],
                    dtype=torch.float32, device=x.device)
h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
return h
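# Illustrative sketch, not part of the original code. dwt_init and iwt_init
# implement a one-level 2D Haar transform and its inverse.
# The hypothetical helpers below apply the same per-2x2-block formulas to a
# single-channel image stored as nested lists. This lets you check the
# perfect-reconstruction property without torch or a GPU.
# Band names follow dwt_init (x_LL, x_HL, x_LH, x_HH).
def _haar_dwt(img):
    half_h, half_w = len(img) // 2, len(img[0]) // 2
    ll, hl, lh, hh = [[[0.0] * half_w for _ in range(half_h)] for _ in range(4)]
    for r in range(half_h):
        for c in range(half_w):
            x1 = img[2 * r][2 * c] / 2          # even row, even col
            x2 = img[2 * r + 1][2 * c] / 2      # odd row, even col
            x3 = img[2 * r][2 * c + 1] / 2      # even row, odd col
            x4 = img[2 * r + 1][2 * c + 1] / 2  # odd row, odd col
            ll[r][c] = x1 + x2 + x3 + x4
            hl[r][c] = -x1 - x2 + x3 + x4
            lh[r][c] = -x1 + x2 - x3 + x4
            hh[r][c] = x1 - x2 - x3 + x4
    return ll, hl, lh, hh


def _haar_iwt(ll, hl, lh, hh):
    h, w = len(ll), len(ll[0])
    img = [[0.0] * (2 * w) for _ in range(2 * h)]
    for r in range(h):
        for c in range(w):
            x1, x2, x3, x4 = ll[r][c] / 2, hl[r][c] / 2, lh[r][c] / 2, hh[r][c] / 2
            img[2 * r][2 * c] = x1 - x2 - x3 + x4
            img[2 * r + 1][2 * c] = x1 - x2 + x3 - x4
            img[2 * r][2 * c + 1] = x1 + x2 - x3 - x4
            img[2 * r + 1][2 * c + 1] = x1 + x2 + x3 + x4
    return img

# Usage:
#   _haar_dwt([[1, 2], [3, 4]])              -> ([[5.0]], [[1.0]], [[2.0]], [[0.0]])
#   _haar_iwt(*_haar_dwt([[1, 2], [3, 4]]))  -> [[1.0, 2.0], [3.0, 4.0]]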
class Channel_Shuffle(nn.Module):
def __init__(self, conv_groups):
super(Channel_Shuffle, self).__init__()
self.conv_groups = conv_groups
self.requires_grad = False
def forward(self, x):
return channel_shuffle(x, self.conv_groups)
class SP(nn.Module):
def __init__(self):
super(SP, self).__init__()
self.requires_grad = False
def forward(self, x):
return sp_init(x)
class Pixel_Down_Shuffle(nn.Module):
def __init__(self):
super(Pixel_Down_Shuffle, self).__init__()
self.requires_grad = False
def forward(self, x):
return pixel_down_shuffle(x, 2)
class DWT(nn.Module):
def __init__(self):
super(DWT, self).__init__()
self.requires_grad = False
def forward(self, x):
return dwt_init(x)
class IWT(nn.Module):
def __init__(self):
super(IWT, self).__init__()
self.requires_grad = False
def forward(self, x):
return iwt_init(x)
class MeanShift(nn.Conv2d):
def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1)
std = torch.Tensor(rgb_std)
self.weight.data = torch.eye(3).view(3, 3, 1, 1)
self.weight.data.div_(std.view(3, 1, 1, 1))
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
self.bias.data.div_(std)
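        # Freeze the fixed normalisation weights. Assigning self.requires_grad
        # on the module does not reach its parameters, so set it per parameter.
        for p in self.parameters():
            p.requires_grad = False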
class MeanShift2(nn.Conv2d):
def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
super(MeanShift2, self).__init__(4, 4, kernel_size=1)
std = torch.Tensor(rgb_std)
self.weight.data = torch.eye(4).view(4, 4, 1, 1)
self.weight.data.div_(std.view(4, 1, 1, 1))
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
self.bias.data.div_(std)
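        # Freeze the fixed normalisation weights. Assigning self.requires_grad
        # on the module does not reach its parameters, so set it per parameter.
        for p in self.parameters():
            p.requires_grad = False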
class BasicBlock(nn.Sequential):
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, bias=False,
bn=False, act=nn.ReLU(True)):
m = [nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size // 2), stride=stride, bias=bias)
]
if bn: m.append(nn.BatchNorm2d(out_channels))
if act is not None: m.append(act)
super(BasicBlock, self).__init__(*m)
class BBlock(nn.Module):
def __init__(
self, conv, in_channels, out_channels, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(BBlock, self).__init__()
m = []
m.append(conv(in_channels, out_channels, kernel_size, bias=bias))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
x = self.body(x).mul(self.res_scale)
return x
class DBlock_com(nn.Module):
def __init__(
self, conv, in_channels, out_channels, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(DBlock_com, self).__init__()
m = []
m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
        m.append(conv(out_channels, out_channels, kernel_size, bias=bias, dilation=3))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
x = self.body(x)
return x
class DBlock_inv(nn.Module):
def __init__(
self, conv, in_channels, out_channels, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(DBlock_inv, self).__init__()
m = []
m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=3))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
        m.append(conv(out_channels, out_channels, kernel_size, bias=bias, dilation=2))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
x = self.body(x)
return x
class DBlock_com1(nn.Module):
def __init__(
self, conv, in_channels, out_channels, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(DBlock_com1, self).__init__()
m = []
m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
        m.append(conv(out_channels, out_channels, kernel_size, bias=bias, dilation=1))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
x = self.body(x)
return x
class DBlock_inv1(nn.Module):
def __init__(
self, conv, in_channels, out_channels, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(DBlock_inv1, self).__init__()
m = []
m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
        m.append(conv(out_channels, out_channels, kernel_size, bias=bias, dilation=1))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
x = self.body(x)
return x
class DBlock_com2(nn.Module):
def __init__(
self, conv, in_channels, out_channels, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(DBlock_com2, self).__init__()
m = []
m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
        m.append(conv(out_channels, out_channels, kernel_size, bias=bias, dilation=2))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
x = self.body(x)
return x
class DBlock_inv2(nn.Module):
def __init__(
self, conv, in_channels, out_channels, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(DBlock_inv2, self).__init__()
m = []
m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
        m.append(conv(out_channels, out_channels, kernel_size, bias=bias, dilation=2))
if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
x = self.body(x)
return x
class ShuffleBlock(nn.Module):
def __init__(
self, conv, in_channels, out_channels, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1, conv_groups=1):
super(ShuffleBlock, self).__init__()
m = []
m.append(conv(in_channels, out_channels, kernel_size, bias=bias))
m.append(Channel_Shuffle(conv_groups))
if bn: m.append(nn.BatchNorm2d(out_channels))
m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
x = self.body(x).mul(self.res_scale)
return x
class DWBlock(nn.Module):
def __init__(
self, conv, conv1, in_channels, out_channels, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(DWBlock, self).__init__()
m = []
m.append(conv(in_channels, out_channels, kernel_size, bias=bias))
if bn: m.append(nn.BatchNorm2d(out_channels))
m.append(act)
        m.append(conv1(out_channels, out_channels, 1, bias=bias))
if bn: m.append(nn.BatchNorm2d(out_channels))
m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
x = self.body(x).mul(self.res_scale)
return x
class ResBlock(nn.Module):
def __init__(
self, conv, n_feat, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(ResBlock, self).__init__()
m = []
for i in range(2):
m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
if bn: m.append(nn.BatchNorm2d(n_feat))
if i == 0: m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
res = self.body(x).mul(self.res_scale)
res += x
return res
class Block(nn.Module):
def __init__(
self, conv, n_feat, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(Block, self).__init__()
m = []
for i in range(4):
m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
if bn: m.append(nn.BatchNorm2d(n_feat))
if i == 0: m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
res = self.body(x).mul(self.res_scale)
# res += x
return res
class Upsampler(nn.Sequential):
def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
m = []
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
for _ in range(int(math.log(scale, 2))):
m.append(conv(n_feat, 4 * n_feat, 3, bias))
m.append(nn.PixelShuffle(2))
if bn: m.append(nn.BatchNorm2d(n_feat))
if act: m.append(act())
elif scale == 3:
m.append(conv(n_feat, 9 * n_feat, 3, bias))
m.append(nn.PixelShuffle(3))
if bn: m.append(nn.BatchNorm2d(n_feat))
if act: m.append(act())
else:
raise NotImplementedError
super(Upsampler, self).__init__(*m)
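# Illustrative sketch, not part of the original code. Upsampler builds one of
# two layouts:
#   - for a power-of-two scale: one conv (n_feat -> 4*n_feat) plus
#     PixelShuffle(2) stage per factor of two;
#   - for scale 3: a single conv (n_feat -> 9*n_feat) plus PixelShuffle(3).
# The hypothetical _upsampler_plan returns the list of shuffle factors.
# It uses int.bit_length() instead of math.log, so the stage count is exact.
# It also rejects scale 0, which the bit test `scale & (scale - 1) == 0`
# alone would accept.
def _upsampler_plan(scale):
    if scale > 0 and (scale & (scale - 1)) == 0:
        # Power of two: log2(scale) stages of PixelShuffle(2).
        return [2] * (scale.bit_length() - 1)
    elif scale == 3:
        return [3]
    raise NotImplementedError('only powers of two and 3 are supported')

# Usage:
#   _upsampler_plan(4) -> [2, 2]
#   _upsampler_plan(3) -> [3]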
class MWCNN(nn.Module):
def __init__(self, in_channels, num_classes, conv=default_conv):
super(MWCNN, self).__init__()
kernel_size = 3
self.scale_idx = 0
n_feats = 64
act = nn.ReLU(True)
self.DWT = DWT()
self.IWT = IWT()
n = 1
m_head = [BBlock(conv, in_channels, n_feats, kernel_size, act=act)]
d_l0 = []
d_l0.append(DBlock_com1(conv, n_feats, n_feats, kernel_size, act=act, bn=False))
d_l1 = [BBlock(conv, n_feats * 4, n_feats * 2, kernel_size, act=act, bn=False)]
d_l1.append(DBlock_com1(conv, n_feats * 2, n_feats * 2, kernel_size, act=act, bn=False))
d_l2 = []
d_l2.append(BBlock(conv, n_feats * 8, n_feats * 4, kernel_size, act=act, bn=False))
d_l2.append(DBlock_com1(conv, n_feats * 4, n_feats * 4, kernel_size, act=act, bn=False))
pro_l3 = []
pro_l3.append(BBlock(conv, n_feats * 16, n_feats * 8, kernel_size, act=act, bn=False))
pro_l3.append(DBlock_com(conv, n_feats * 8, n_feats * 8, kernel_size, act=act, bn=False))
pro_l3.append(DBlock_inv(conv, n_feats * 8, n_feats * 8, kernel_size, act=act, bn=False))
pro_l3.append(BBlock(conv, n_feats * 8, n_feats * 16, kernel_size, act=act, bn=False))
i_l2 = [DBlock_inv1(conv, n_feats * 4, n_feats * 4, kernel_size, act=act, bn=False)]
i_l2.append(BBlock(conv, n_feats * 4, n_feats * 8, kernel_size, act=act, bn=False))
i_l1 = [DBlock_inv1(conv, n_feats * 2, n_feats * 2, kernel_size, act=act, bn=False)]
i_l1.append(BBlock(conv, n_feats * 2, n_feats * 4, kernel_size, act=act, bn=False))
i_l0 = [DBlock_inv1(conv, n_feats, n_feats, kernel_size, act=act, bn=False)]
m_tail = [conv(n_feats, num_classes, kernel_size)]
self.head = nn.Sequential(*m_head)
self.d_l2 = nn.Sequential(*d_l2)
self.d_l1 = nn.Sequential(*d_l1)
self.d_l0 = nn.Sequential(*d_l0)
self.pro_l3 = nn.Sequential(*pro_l3)
self.i_l2 = nn.Sequential(*i_l2)
self.i_l1 = nn.Sequential(*i_l1)
self.i_l0 = nn.Sequential(*i_l0)
self.tail = nn.Sequential(*m_tail)
def forward(self, x):
x0 = self.d_l0(self.head(x))
x1 = self.d_l1(self.DWT(x0))
x2 = self.d_l2(self.DWT(x1))
x_ = self.IWT(self.pro_l3(self.DWT(x2))) + x2
x_ = self.IWT(self.i_l2(x_)) + x1
x_ = self.IWT(self.i_l1(x_)) + x0
x_ = self.tail(self.i_l0(x_))
return x_
def set_scale(self, scale_idx):
self.scale_idx = scale_idx
def mwcnn(in_channels, num_classes):
model = MWCNN(in_channels, num_classes)
return model
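# Illustrative sketch, not part of the original code: a (channels, H, W)
# trace of MWCNN.forward.
# DWT maps (C, H, W) -> (4C, H/2, W/2) and IWT does the reverse. Every conv
# keeps H and W, because default_conv pads a 3x3 conv by `dilation`.
# There are three DWT levels, so H and W must be divisible by 8 for the
# residual additions to line up.
# _mwcnn_trace is a hypothetical helper that assumes the default n_feats=64.
def _mwcnn_trace(num_classes, h, w, n_feats=64):
    if h % 8 or w % 8:
        raise ValueError('H and W must be divisible by 8 (three DWT levels)')

    def dwt(shape):
        c, hh, ww = shape
        return (4 * c, hh // 2, ww // 2)

    def iwt(shape):
        c, hh, ww = shape
        return (c // 4, hh * 2, ww * 2)

    x0 = (n_feats, h, w)                        # head + d_l0
    x1 = (2 * n_feats,) + dwt(x0)[1:]           # DWT, then d_l1: 4n -> 2n
    x2 = (4 * n_feats,) + dwt(x1)[1:]           # DWT, then d_l2: 8n -> 4n
    bottom = dwt(x2)                            # pro_l3 input and output: 16n
    # Each IWT output must match the encoder feature it is added to.
    assert iwt(bottom) == x2                    # ... + x2
    assert iwt((8 * n_feats,) + x2[1:]) == x1   # i_l2 ends at 8n, then + x1
    assert iwt((4 * n_feats,) + x1[1:]) == x0   # i_l1 ends at 4n, then + x0
    return {'x0': x0, 'x1': x1, 'x2': x2, 'pro_l3': bottom, 'out': (num_classes, h, w)}

# Usage:
#   _mwcnn_trace(2, 128, 128)['pro_l3'] -> (1024, 16, 16)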
# if __name__ == '__main__':
# from loss.loss_function import segmentation_loss
# criterion = segmentation_loss('dice', False)
# mask = torch.ones(2, 128, 128).long()
# model = mwcnn(1, 2)
# model.train()
# input1 = torch.rand(2, 1, 128, 128)
# y = model(input1)
# loss_train = criterion(y, mask)
# loss_train.backward()
# # print(output)
# print(y.data.cpu().numpy().shape)
# print(loss_train)
================================================
FILE: models/networks_2d/resunet.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
class ResidualConv(nn.Module):
def __init__(self, input_dim, output_dim, stride, padding):
super(ResidualConv, self).__init__()
self.conv_block = nn.Sequential(
nn.BatchNorm2d(input_dim),
nn.ReLU(),
nn.Conv2d(
input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
),
nn.BatchNorm2d(output_dim),
nn.ReLU(),
nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
)
self.conv_skip = nn.Sequential(
nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
nn.BatchNorm2d(output_dim),
)
def forward(self, x):
return self.conv_block(x) + self.conv_skip(x)
class Upsample(nn.Module):
def __init__(self, input_dim, output_dim, kernel, stride):
super(Upsample, self).__init__()
self.upsample = nn.ConvTranspose2d(
input_dim, output_dim, kernel_size=kernel, stride=stride
)
def forward(self, x):
return self.upsample(x)
class ResUnet(nn.Module):
def __init__(self, in_channels, num_classes, filters=[64, 128, 256, 512]):
super(ResUnet, self).__init__()
self.input_layer = nn.Sequential(
nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1),
nn.BatchNorm2d(filters[0]),
nn.ReLU(),
nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
)
self.input_skip = nn.Sequential(
nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1)
)
self.residual_conv_1 = ResidualConv(filters[0], filters[1], 2, 1)
self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)
self.bridge = ResidualConv(filters[2], filters[3], 2, 1)
self.upsample_1 = Upsample(filters[3], filters[3], 2, 2)
self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)
self.upsample_2 = Upsample(filters[2], filters[2], 2, 2)
self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)
self.upsample_3 = Upsample(filters[1], filters[1], 2, 2)
self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], 1, 1)
self.output_layer = nn.Conv2d(filters[0], num_classes, 1, 1)
def forward(self, x):
# Encode
x1 = self.input_layer(x) + self.input_skip(x)
x2 = self.residual_conv_1(x1)
x3 = self.residual_conv_2(x2)
# Bridge
x4 = self.bridge(x3)
# Decode
x4 = self.upsample_1(x4)
x5 = torch.cat([x4, x3], dim=1)
x6 = self.up_residual_conv1(x5)
x6 = self.upsample_2(x6)
x7 = torch.cat([x6, x2], dim=1)
x8 = self.up_residual_conv2(x7)
x8 = self.upsample_3(x8)
x9 = torch.cat([x8, x1], dim=1)
x10 = self.up_residual_conv3(x9)
output = self.output_layer(x10)
return output
def res_unet(in_channels, num_classes):
model = ResUnet(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
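# Illustrative sketch, not part of the original code. It applies the
# output-size formulas from the PyTorch docs for nn.Conv2d and
# nn.ConvTranspose2d (one spatial axis) to check where ResUnet's skip
# concatenations line up.
# _conv_out, _conv_transpose_out and _resunet_skip_pairs are hypothetical
# helper names.
# Each pair is (upsampled decoder size, encoder skip size), and
# torch.cat(dim=1) needs the two to be equal.
def _conv_out(size, kernel, stride=1, padding=0, dilation=1):
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def _conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def _resunet_skip_pairs(size):
    x1 = _conv_out(size, 3, 1, 1)   # input_layer / input_skip
    x2 = _conv_out(x1, 3, 2, 1)     # residual_conv_1 (stride 2)
    x3 = _conv_out(x2, 3, 2, 1)     # residual_conv_2 (stride 2)
    x4 = _conv_out(x3, 3, 2, 1)     # bridge (stride 2)
    # upsample_k is ConvTranspose2d(kernel=2, stride=2); up_residual_conv_k keeps the size.
    return [(_conv_transpose_out(x4, 2, 2), x3),
            (_conv_transpose_out(x3, 2, 2), x2),
            (_conv_transpose_out(x2, 2, 2), x1)]

# Usage:
#   _resunet_skip_pairs(128) -> [(32, 32), (64, 64), (128, 128)]
#   _resunet_skip_pairs(100) -> [(26, 25), (50, 50), (100, 100)]  # first cat fails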
# if __name__ == '__main__':
# model = res_unet(1,10)
# model.eval()
# input = torch.rand(2,1,128,128)
# output = model(input)
# output = output.data.cpu().numpy()
# # print(output)
# print(output.shape)
================================================
FILE: models/networks_2d/resunet_plusplus.py
================================================
import torch.nn as nn
import torch
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
class ResidualConv(nn.Module):
def __init__(self, input_dim, output_dim, stride, padding):
super(ResidualConv, self).__init__()
self.conv_block = nn.Sequential(
nn.BatchNorm2d(input_dim),
nn.ReLU(),
nn.Conv2d(
input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
),
nn.BatchNorm2d(output_dim),
nn.ReLU(),
nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
)
self.conv_skip = nn.Sequential(
nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
nn.BatchNorm2d(output_dim),
)
def forward(self, x):
return self.conv_block(x) + self.conv_skip(x)
class Upsample(nn.Module):
def __init__(self, input_dim, output_dim, kernel, stride):
super(Upsample, self).__init__()
self.upsample = nn.ConvTranspose2d(
input_dim, output_dim, kernel_size=kernel, stride=stride
)
def forward(self, x):
return self.upsample(x)
class Squeeze_Excite_Block(nn.Module):
def __init__(self, channel, reduction=16):
super(Squeeze_Excite_Block, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid(),
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
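# Illustrative sketch, not part of the original code. Squeeze_Excite_Block
# processes each sample as follows:
#   1. squeeze each channel to its spatial mean (AdaptiveAvgPool2d(1));
#   2. pass the C-vector through Linear(C, C // reduction) + ReLU, then
#      Linear(C // reduction, C) + Sigmoid, both Linear layers without bias;
#   3. rescale each channel by its gate.
# _squeeze_excite is a hypothetical pure-Python version for one sample.
# w1 and w2 stand in for the two Linear weight matrices, in nn.Linear's
# (out_features, in_features) layout.
import math


def _squeeze_excite(feature_maps, w1, w2):
    # feature_maps: C channels, each an H x W nested list.
    squeezed = [sum(map(sum, ch)) / (len(ch) * len(ch[0])) for ch in feature_maps]
    hidden = [max(0.0, sum(w * s for w, s in zip(row, squeezed))) for row in w1]
    gates = [1.0 / (1.0 + math.exp(-sum(w * v for w, v in zip(row, hidden)))) for row in w2]
    return [[[v * g for v in r] for r in ch] for ch, g in zip(feature_maps, gates)]

# Usage: with all-zero second-layer weights every gate is sigmoid(0) = 0.5.
#   _squeeze_excite([[[2.0, 4.0]], [[6.0, 8.0]]], [[1.0, 1.0]], [[0.0], [0.0]])
#   -> [[[1.0, 2.0]], [[3.0, 4.0]]]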
class ASPP(nn.Module):
def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
super(ASPP, self).__init__()
self.aspp_block1 = nn.Sequential(
nn.Conv2d(
in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
),
nn.ReLU(inplace=True),
nn.BatchNorm2d(out_dims),
)
self.aspp_block2 = nn.Sequential(
nn.Conv2d(
in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
),
nn.ReLU(inplace=True),
nn.BatchNorm2d(out_dims),
)
self.aspp_block3 = nn.Sequential(
nn.Conv2d(
in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
),
nn.ReLU(inplace=True),
nn.BatchNorm2d(out_dims),
)
self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
self._init_weights()
def forward(self, x):
x1 = self.aspp_block1(x)
x2 = self.aspp_block2(x)
x3 = self.aspp_block3(x)
out = torch.cat([x1, x2, x3], dim=1)
return self.output(out)
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class Upsample_(nn.Module):
def __init__(self, scale=2):
super(Upsample_, self).__init__()
self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale, align_corners=True)
def forward(self, x):
return self.upsample(x)
class AttentionBlock(nn.Module):
def __init__(self, input_encoder, input_decoder, output_dim):
super(AttentionBlock, self).__init__()
self.conv_encoder = nn.Sequential(
nn.BatchNorm2d(input_encoder),
nn.ReLU(),
nn.Conv2d(input_encoder, output_dim, 3, padding=1),
nn.MaxPool2d(2, 2),
)
self.conv_decoder = nn.Sequential(
nn.BatchNorm2d(input_decoder),
nn.ReLU(),
nn.Conv2d(input_decoder, output_dim, 3, padding=1),
)
self.conv_attn = nn.Sequential(
nn.BatchNorm2d(output_dim),
nn.ReLU(),
nn.Conv2d(output_dim, 1, 1),
)
def forward(self, x1, x2):
out = self.conv_encoder(x1) + self.conv_decoder(x2)
out = self.conv_attn(out)
return out * x2
class ResUnetPlusPlus(nn.Module):
def __init__(self, in_channels, num_classes, filters=[32, 64, 128, 256, 512]):
super(ResUnetPlusPlus, self).__init__()
self.input_layer = nn.Sequential(
nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1),
nn.BatchNorm2d(filters[0]),
nn.ReLU(),
nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
)
self.input_skip = nn.Sequential(
nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1)
)
self.squeeze_excite1 = Squeeze_Excite_Block(filters[0])
self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1)
self.squeeze_excite2 = Squeeze_Excite_Block(filters[1])
self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1)
self.squeeze_excite3 = Squeeze_Excite_Block(filters[2])
self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1)
self.aspp_bridge = ASPP(filters[3], filters[4])
self.attn1 = AttentionBlock(filters[2], filters[4], filters[4])
self.upsample1 = Upsample_(2)
self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1)
self.attn2 = AttentionBlock(filters[1], filters[3], filters[3])
self.upsample2 = Upsample_(2)
self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1)
self.attn3 = AttentionBlock(filters[0], filters[2], filters[2])
self.upsample3 = Upsample_(2)
self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1)
self.aspp_out = ASPP(filters[1], filters[0])
self.output_layer = nn.Conv2d(filters[0], num_classes, 1)
def forward(self, x):
x1 = self.input_layer(x) + self.input_skip(x)
x2 = self.squeeze_excite1(x1)
x2 = self.residual_conv1(x2)
x3 = self.squeeze_excite2(x2)
x3 = self.residual_conv2(x3)
x4 = self.squeeze_excite3(x3)
x4 = self.residual_conv3(x4)
x5 = self.aspp_bridge(x4)
x6 = self.attn1(x3, x5)
x6 = self.upsample1(x6)
x6 = torch.cat([x6, x3], dim=1)
x6 = self.up_residual_conv1(x6)
x7 = self.attn2(x2, x6)
x7 = self.upsample2(x7)
x7 = torch.cat([x7, x2], dim=1)
x7 = self.up_residual_conv2(x7)
x8 = self.attn3(x1, x7)
x8 = self.upsample3(x8)
x8 = torch.cat([x8, x1], dim=1)
x8 = self.up_residual_conv3(x8)
x9 = self.aspp_out(x8)
out = self.output_layer(x9)
return out
def res_unet_plusplus(in_channels, num_classes):
model = ResUnetPlusPlus(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = res_unet_plusplus(1,10)
# model.eval()
# input = torch.rand(2,1,128,128)
# output = model(input)
# output = output.data.cpu().numpy()
# # print(output)
# print(output.shape)
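The decoder widths in `ResUnetPlusPlus.__init__` come from concatenating the decoder path with the matching encoder feature (`x3`, `x2`, `x1`). The small pure-Python ledger below uses a hypothetical helper, `decoder_channels`, to recompute the `(in, out)` channel pairs of the three `up_residual_conv` layers from `filters`. It tracks channel counts only. Spatial sizes depend on `ASPP` and `AttentionBlock` internals that are defined elsewhere in the file. After the last stage, `aspp_out` maps `filters[1]` to `filters[0]`, and the 1x1 `output_layer` then produces `num_classes` channels.

```python
def decoder_channels(filters=(32, 64, 128, 256, 512)):
    # (in, out) channels of up_residual_conv1..3, read off ResUnetPlusPlus.__init__:
    # decoder-path channels + channels of the concatenated encoder skip
    f = filters
    return [
        (f[4] + f[2], f[3]),  # up_residual_conv1: x6 path + skip x3
        (f[3] + f[1], f[2]),  # up_residual_conv2: x7 path + skip x2
        (f[2] + f[0], f[1]),  # up_residual_conv3: x8 path + skip x1
    ]

print(decoder_channels())  # [(640, 256), (320, 128), (160, 64)]
```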
================================================
FILE: models/networks_2d/swinunet.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import logging
import math
from os.path import join as pjoin
import torch
import torch.nn as nn
import numpy as np
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
import torch.utils.checkpoint as checkpoint
logger = logging.getLogger(__name__)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous(
).view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size,
window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
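When `H` and `W` are multiples of `window_size`, `window_partition` and `window_reverse` are exact inverses. The NumPy re-expression below uses hypothetical names (`window_partition_np` / `window_reverse_np`), with `reshape`/`transpose` standing in for torch's `view`/`permute`. It checks the round trip on a toy tensor and confirms that windows are ordered row-major over the window grid of each image.

```python
import numpy as np

def window_partition_np(x, window_size):
    # (B, H, W, C) -> (num_windows*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)

def window_reverse_np(windows, window_size, H, W):
    # (num_windows*B, window_size, window_size, C) -> (B, H, W, C)
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)

x = np.arange(2 * 8 * 8 * 3).reshape(2, 8, 8, 3)
windows = window_partition_np(x, 4)
print(windows.shape)                                           # (8, 4, 4, 3)
print(np.array_equal(windows[1], x[0, :4, 4:8]))               # True: second window is top-right
print(np.array_equal(window_reverse_np(windows, 4, 8, 8), x))  # True
```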
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - \
coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - \
1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index",
relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C //
self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
# make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1).contiguous())
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).contiguous().reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
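The `relative_position_index` buffer maps every (query, key) token pair in a window to a row of `relative_position_bias_table`. The pure-Python re-derivation below needs no torch and uses a hypothetical function name, `relative_position_index`. It follows the same steps as `__init__`: row-major token coordinates, pairwise differences, shifting both offsets so they start at 0, and finally scaling the row offset by `2 * Ww - 1`.

```python
def relative_position_index(wh, ww):
    # token coordinates in row-major order, like flatten(stack(meshgrid(...)), 1)
    coords = [(h, w) for h in range(wh) for w in range(ww)]
    index = []
    for hi, wi in coords:
        row = []
        for hj, wj in coords:
            dh = hi - hj + (wh - 1)   # shift to start from 0
            dw = wi - wj + (ww - 1)
            row.append(dh * (2 * ww - 1) + dw)
        index.append(row)
    return index

for row in relative_position_index(2, 2):
    print(row)
# [4, 3, 1, 0]
# [5, 4, 2, 1]
# [7, 6, 4, 3]
# [8, 7, 5, 4]

idx = relative_position_index(7, 7)
print(max(max(r) for r in idx) + 1, (2 * 7 - 1) * (2 * 7 - 1))  # 169 169
```

All diagonal entries share the same zero-offset bucket. The largest index is one less than the table length, so every pair lands on a valid row of the `(2*Wh-1)*(2*Ww-1)` table.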
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if the window is at least as large as the input resolution, disable the shift and shrink the window to min(input_resolution)
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# nW, window_size, window_size, 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1,
self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(
attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
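The shifted-window mask splits the image into nine regions (three row slices times three column slices) and partitions that label map into windows. Every token pair whose labels differ gets -100, so the softmax effectively suppresses attention between regions that only became adjacent because of the cyclic shift. The NumPy re-expression below uses a hypothetical name, `sw_msa_mask`, and runs on a toy 4x4 map with `window_size=2` and `shift_size=1`.

```python
import numpy as np

def sw_msa_mask(H, W, window_size, shift_size):
    # label the 3 x 3 regions created by the cyclic shift (same slices as the block)
    img_mask = np.zeros((H, W), dtype=np.int64)
    slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[h, w] = cnt
            cnt += 1
    # partition the label map into windows: (nW, window_size*window_size)
    windows = (img_mask.reshape(H // window_size, window_size, W // window_size, window_size)
               .transpose(0, 2, 1, 3)
               .reshape(-1, window_size * window_size))
    diff = windows[:, None, :] - windows[:, :, None]
    return np.where(diff != 0, -100.0, 0.0)  # (nW, N, N), added to attention logits

mask = sw_msa_mask(4, 4, window_size=2, shift_size=1)
print(mask.shape)                              # (4, 4, 4)
print((mask == 0).sum(axis=(1, 2)).tolist())   # [16, 8, 8, 4]
```

The top-left window never straddles a region boundary, so it stays unmasked. The bottom-right window holds four different regions, so each of its tokens may only attend to itself.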
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
# nW*B, window_size, window_size, C
x_windows = window_partition(shifted_x, self.window_size)
# nW*B, window_size*window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
# W-MSA/SW-MSA
# nW*B, window_size*window_size, C
attn_windows = self.attn(x_windows, mask=self.attn_mask)
# merge windows
attn_windows = attn_windows.view(-1,
self.window_size, self.window_size, C)
shifted_x = window_reverse(
attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(
self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# residual connection around attention, then FFN with its own residual
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
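`SwinTransformerBlock.flops` sums four terms: the two LayerNorms, `nW` windows of attention, and the two MLP linears. Each term counts multiply-accumulates, so softmax, GELU and the bias-table addition are left out. The hand calculation below applies the method to the first encoder stage of the `SwinUnet` configuration (56x56 tokens, `dim=96`, 3 heads, 7x7 windows, `mlp_ratio=4`). The function names are hypothetical stand-ins for the two `flops` methods.

```python
def window_attention_flops(N, dim, num_heads):
    # mirrors WindowAttention.flops for one window of N tokens
    head_dim = dim // num_heads
    flops = N * dim * 3 * dim                # qkv projection
    flops += num_heads * N * head_dim * N    # q @ k^T
    flops += num_heads * N * N * head_dim    # attn @ v
    flops += N * dim * dim                   # output projection
    return flops

def swin_block_flops(H, W, dim, num_heads, window_size, mlp_ratio=4.0):
    # mirrors SwinTransformerBlock.flops
    flops = dim * H * W                                   # norm1
    nW = H * W / window_size / window_size
    flops += nW * window_attention_flops(window_size * window_size, dim, num_heads)
    flops += 2 * H * W * dim * dim * mlp_ratio            # fc1 + fc2
    flops += dim * H * W                                  # norm2
    return flops

print(window_attention_flops(49, 96, 3))   # 2267328
print(swin_block_flops(56, 56, 96, 3, 7))  # 376922112.0
```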
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
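`PatchMerging` halves the resolution by stacking each 2x2 neighbourhood along the channel axis (order: top-left, bottom-left, top-right, bottom-right). A LayerNorm and a bias-free `4*dim -> 2*dim` linear layer follow. The NumPy sketch below reproduces only the gathering step, on a 4x4 single-channel map, under the hypothetical name `patch_merge_gather`. The norm and reduction are omitted.

```python
import numpy as np

def patch_merge_gather(x, H, W):
    # x: (B, H*W, C) -> (B, H/2*W/2, 4*C), before self.norm and self.reduction
    B, L, C = x.shape
    assert L == H * W and H % 2 == 0 and W % 2 == 0
    x = x.reshape(B, H, W, C)
    x0 = x[:, 0::2, 0::2, :]  # top-left of each 2x2 block
    x1 = x[:, 1::2, 0::2, :]  # bottom-left
    x2 = x[:, 0::2, 1::2, :]  # top-right
    x3 = x[:, 1::2, 1::2, :]  # bottom-right
    return np.concatenate([x0, x1, x2, x3], axis=-1).reshape(B, -1, 4 * C)

x = np.arange(16).reshape(1, 16, 1)  # a 4x4 single-channel map, values 0..15
out = patch_merge_gather(x, 4, 4)
print(out.shape)           # (1, 4, 4)
print(out[0, 0].tolist())  # [0, 4, 1, 5]
```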
class PatchExpand(nn.Module):
def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.expand = nn.Linear(
dim, 2*dim, bias=False) if dim_scale == 2 else nn.Identity()
self.norm = norm_layer(dim // dim_scale)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
x = self.expand(x)
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c',
p1=2, p2=2, c=C//4)
x = x.view(B, -1, C//4)
x = self.norm(x)
return x
class FinalPatchExpand_X4(nn.Module):
def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.dim_scale = dim_scale
self.expand = nn.Linear(dim, 16*dim, bias=False)
self.output_dim = dim
self.norm = norm_layer(self.output_dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
x = self.expand(x)
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c',
p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
x = x.view(B, -1, self.output_dim)
x = self.norm(x)
return x
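Both expanding layers start by widening the channels with a linear layer: `dim -> 2*dim` in `PatchExpand` and `dim -> 16*dim` in `FinalPatchExpand_X4`. They then call `einops.rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c')` to move groups of channels into space, a channels-last, pixel-shuffle-style step. Because `(p1 p2 c)` is decomposed with `p1` outermost, the same rearrangement can be written as a NumPy reshape/transpose. In the sketch below, `expand_rearrange` is a hypothetical name, and the linear layer and LayerNorm are omitted.

```python
import numpy as np

def expand_rearrange(x, scale):
    # (B, H, W, scale*scale*c) -> (B, H*scale, W*scale, c),
    # like rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=scale, p2=scale)
    B, H, W, C = x.shape
    c = C // (scale * scale)
    x = x.reshape(B, H, W, scale, scale, c)  # split channels into (p1, p2, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)         # -> (B, H, p1, W, p2, c)
    return x.reshape(B, H * scale, W * scale, c)

x = np.arange(8).reshape(1, 1, 2, 4)  # a 1x2 token grid with 4 channels per token
print(expand_rearrange(x, 2)[0, :, :, 0].tolist())  # [[0, 1, 4, 5], [2, 3, 6, 7]]
```

With `scale=4` and `16*dim` input channels, as in `FinalPatchExpand_X4`, the output keeps `dim` channels at 4x the token-grid resolution.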
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (
i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(
drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
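Inside `BasicLayer`, even-indexed blocks use regular windows and odd-indexed blocks shift by `window_size // 2`. `SwinTransformerBlock.__init__` then disables the shift and shrinks the window whenever `min(input_resolution) <= window_size`. The pure-Python sketch below combines both rules in a hypothetical helper, `block_schedule`. Under the `SwinUnet` defaults (224x224 input, patch size 4, `window_size=7`, `depths=[2, 2, 6, 2]`), the 7x7 bottleneck stage ends up with no shifted blocks at all.

```python
def block_schedule(input_resolution, depth, window_size):
    # (window_size, shift_size) actually used by each block of a BasicLayer
    schedule = []
    for i in range(depth):
        ws = window_size
        shift = 0 if i % 2 == 0 else window_size // 2  # BasicLayer alternation
        if min(input_resolution) <= ws:                # SwinTransformerBlock clamp
            shift = 0
            ws = min(input_resolution)
        schedule.append((ws, shift))
    return schedule

for res, depth in zip([56, 28, 14, 7], [2, 2, 6, 2]):
    print(res, block_schedule((res, res), depth, 7))
# 56 [(7, 0), (7, 3)]
# 28 [(7, 0), (7, 3)]
# 14 [(7, 0), (7, 3), (7, 0), (7, 3), (7, 0), (7, 3)]
# 7 [(7, 0), (7, 0)]
```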
class BasicLayer_up(nn.Module):
""" A basic Swin Transformer decoder layer for one stage, optionally followed by patch expanding.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
upsample (nn.Module | None, optional): Upsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (
i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(
drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch expanding layer (upsample only acts as a switch; PatchExpand is always used)
if upsample is not None:
self.upsample = PatchExpand(
input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
else:
self.upsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.upsample is not None:
x = self.upsample(x)
return x
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] //
patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * \
(self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
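`PatchEmbed` uses a `Conv2d` whose kernel size equals its stride. Each output position therefore sees exactly one non-overlapping patch, which makes the layer a shared linear projection of flattened patches. The NumPy sketch below makes this explicit under a hypothetical name, `patch_embed_np`, with no normalization. It returns tokens in the same `B, Ph*Pw, C` layout as `flatten(2).transpose(1, 2)`.

```python
import numpy as np

def patch_embed_np(x, weight, bias, patch):
    # x: (B, C, H, W); weight: (E, C, patch, patch) as in nn.Conv2d; bias: (E,)
    B, C, H, W = x.shape
    E = weight.shape[0]
    ph, pw = H // patch, W // patch
    # cut into non-overlapping patches: (B, ph, pw, C, patch, patch)
    patches = x.reshape(B, C, ph, patch, pw, patch).transpose(0, 2, 4, 1, 3, 5)
    patches = patches.reshape(B, ph * pw, C * patch * patch)
    # one shared linear projection per patch -> (B, ph*pw, E), row-major over the patch grid
    return patches @ weight.reshape(E, -1).T + bias

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8))
w = rng.standard_normal((96, 3, 4, 4))
b = rng.standard_normal(96)
tokens = patch_embed_np(x, w, b, 4)
print(tokens.shape)  # (1, 4, 96)
```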
class SwinTransformerSys(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of output classes of the per-pixel segmentation head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, final_upsample="expand_first", **kwargs):
super().__init__()
print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths,
depths_decoder, drop_path_rate, num_classes))
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.num_features_up = int(embed_dim * 2)
self.mlp_ratio = mlp_ratio
self.final_upsample = final_upsample
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
sum(depths))] # stochastic depth decay rule
# build encoder and bottleneck layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(
depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (
i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
# build decoder layers
self.layers_up = nn.ModuleList()
self.concat_back_dim = nn.ModuleList()
for i_layer in range(self.num_layers):
concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)),
int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity()
if i_layer == 0:
layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)
else:
layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),
input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),
depth=depths[(
self.num_layers-1-i_layer)],
num_heads=num_heads[(
self.num_layers-1-i_layer)],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:(
self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])],
norm_layer=norm_layer,
upsample=PatchExpand if (
i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers_up.append(layer_up)
self.concat_back_dim.append(concat_linear)
self.norm = norm_layer(self.num_features)
self.norm_up = norm_layer(self.embed_dim)
if self.final_upsample == "expand_first":
print("---final upsample expand_first---")
self.up = FinalPatchExpand_X4(input_resolution=(
img_size//patch_size, img_size//patch_size), dim_scale=4, dim=embed_dim)
self.output = nn.Conv2d(
in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
#Encoder and Bottleneck
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
x_downsample = []
for layer in self.layers:
x_downsample.append(x)
x = layer(x)
x = self.norm(x) # B L C
return x, x_downsample
# Decoder and skip connections
def forward_up_features(self, x, x_downsample):
for inx, layer_up in enumerate(self.layers_up):
if inx == 0:
x = layer_up(x)
else:
x = torch.cat([x, x_downsample[self.num_layers - 1 - inx]], -1)
x = self.concat_back_dim[inx](x)
x = layer_up(x)
x = self.norm_up(x) # B L C
return x
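`forward_features` records the input of every encoder stage in `x_downsample`, so entry `i` has `embed_dim * 2**i` channels at `1/2**i` of the patch grid. At each decoder step `inx >= 1`, the decoder concatenates with `x_downsample[num_layers - 1 - inx]`, which doubles the channel count, and `concat_back_dim[inx]` projects it back down. The hypothetical helper `skip_plan` below lists these pairings for the `SwinUnet` defaults.

```python
def skip_plan(img_size=224, patch_size=4, embed_dim=96, num_layers=4):
    grid = img_size // patch_size
    plan = []
    for inx in range(1, num_layers):
        skip = num_layers - 1 - inx  # index into x_downsample
        res = grid // 2 ** skip
        dim = embed_dim * 2 ** skip
        # (decoder step, skip index, token-grid side, channels after concat, channels after projection)
        plan.append((inx, skip, res, 2 * dim, dim))
    return plan

for row in skip_plan():
    print(row)
# (1, 2, 14, 768, 384)
# (2, 1, 28, 384, 192)
# (3, 0, 56, 192, 96)
```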
def up_x4(self, x):
H, W = self.patches_resolution
B, L, C = x.shape
assert L == H*W, "input feature has wrong size"
if self.final_upsample == "expand_first":
x = self.up(x)
x = x.view(B, 4*H, 4*W, -1)
x = x.permute(0, 3, 1, 2).contiguous() # B,C,H,W
x = self.output(x)
return x
def forward(self, x):
x, x_downsample = self.forward_features(x)
x = self.forward_up_features(x, x_downsample)
x = self.up_x4(x)
return x
def flops(self):
flops = 0
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += self.num_features * \
self.patches_resolution[0] * \
self.patches_resolution[1] // (2 ** self.num_layers)
flops += self.num_features * self.num_classes
return flops
class SwinUnet(nn.Module):
def __init__(self, num_classes, img_size, zero_head=False, vis=False):
super(SwinUnet, self).__init__()
self.num_classes = num_classes
self.zero_head = zero_head
self.swin_unet = SwinTransformerSys(img_size=img_size,
patch_size=4,
num_classes=num_classes,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
drop_path_rate=0.1,
ape=False,
patch_norm=True,
use_checkpoint=False)
def forward(self, x):
if x.size()[1] == 1:
x = x.repeat(1,3,1,1)
logits = self.swin_unet(x)
return logits
def load_from(self, config):
pretrained_path = config.MODEL.PRETRAIN_CKPT
if pretrained_path is not None:
print("pretrained_path:{}".format(pretrained_path))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pretrained_dict = torch.load(pretrained_path, map_location=device)
if "model" not in pretrained_dict:
print("---start load pretrained model by splitting---")
pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()}
for k in list(pretrained_dict.keys()):
if "output" in k:
print("delete key:{}".format(k))
del pretrained_dict[k]
msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False)
# print(msg)
return
pretrained_dict = pretrained_dict['model']
print("---start load pretrained model of swin encoder---")
model_dict = self.swin_unet.state_dict()
full_dict = copy.deepcopy(pretrained_dict)
for k, v in pretrained_dict.items():
if "layers." in k:
current_layer_num = 3-int(k[7:8])
current_k = "layers_up." + str(current_layer_num) + k[8:]
full_dict.update({current_k:v})
for k in list(full_dict.keys()):
if k in model_dict:
if full_dict[k].shape != model_dict[k].shape:
print("delete:{};shape pretrain:{};shape model:{}".format(k,full_dict[k].shape,model_dict[k].shape))
del full_dict[k]
msg = self.swin_unet.load_state_dict(full_dict, strict=False)
# print(msg)
else:
print("none pretrain")
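When the checkpoint contains a `model` entry, `load_from` copies every encoder key `layers.N.*` to `layers_up.(3-N).*`, so the decoder starts from the mirrored encoder weights. It then drops keys whose shapes disagree with the model. Keys without a decoder counterpart are simply ignored, because `load_state_dict` is called with `strict=False`. For example, keys that land on `layers_up.0` find nothing, since that entry is a `PatchExpand`. The dict-only sketch below reproduces the two steps, using shape tuples in place of tensors. `mirror_encoder_keys` and `drop_mismatched` are hypothetical names. Splitting on `.` is a variant of the original `k[7:8]` slice that also tolerates multi-digit stage indices, and the toy shapes are illustrative only.

```python
def mirror_encoder_keys(pretrained, num_layers=4):
    full = dict(pretrained)
    for k, v in pretrained.items():
        if k.startswith("layers."):
            _, stage, rest = k.split(".", 2)  # "layers", "N", "blocks.0...."
            full[f"layers_up.{num_layers - 1 - int(stage)}.{rest}"] = v
    return full

def drop_mismatched(full, model_shapes):
    # keep keys the model lacks (strict=False ignores them) and keys whose shape matches
    return {k: v for k, v in full.items() if k not in model_shapes or model_shapes[k] == v}

ckpt = {"layers.1.blocks.0.norm1.weight": (192,),
        "layers.1.blocks.0.attn.relative_position_bias_table": (169, 6)}
model = {"layers.1.blocks.0.norm1.weight": (192,),
         "layers_up.2.blocks.0.norm1.weight": (192,),
         # toy mismatch, e.g. a model built with a different window size
         "layers.1.blocks.0.attn.relative_position_bias_table": (225, 6),
         "layers_up.2.blocks.0.attn.relative_position_bias_table": (225, 6)}
print(sorted(drop_mismatched(mirror_encoder_keys(ckpt), model)))
# ['layers.1.blocks.0.norm1.weight', 'layers_up.2.blocks.0.norm1.weight']
```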
def swinunet(num_classes, img_size):
model = SwinUnet(num_classes, img_size=img_size)
return model
if __name__ == '__main__':
model = swinunet(10, 224)
model.eval()
input = torch.rand(2,1,224,224)
output = model(input)
output = output.data.cpu().numpy()
# print(output)
print(output.shape)
================================================
FILE: models/networks_2d/u2net.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
class REBNCONV(nn.Module):
def __init__(self,in_ch=3,out_ch=3,dirate=1):
super(REBNCONV,self).__init__()
self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)
def forward(self,x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
## upsample tensor 'src' to have the same spatial size as tensor 'tar'
def _upsample_like(src,tar):
src = F.interpolate(src,size=tar.shape[2:],mode='bilinear', align_corners=True)
return src
### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU7,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
hx6dup = _upsample_like(hx6d,hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module):#UNet06DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module):#UNet05DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module):#UNet04DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
return hx1d + hxin
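RSU4F is used at the two lowest-resolution stages, where further pooling would discard too much spatial detail. Instead of pooling and upsampling, it uses dilated 3x3 convolutions. Two facts follow from `REBNCONV` setting `padding=1*dirate`:

- each conv preserves spatial size;
- on a stride-1 path, each layer grows the receptive field by 2 * dilation.

The sketch below checks both facts for the longest path through the block. The helper names are hypothetical, and the arithmetic covers stride-1 convolutions only.

```python
def conv_out(n, kernel=3, dilation=1, padding=None, stride=1):
    """Output length of a conv along one spatial dim (PyTorch's Conv2d formula)."""
    if padding is None:
        padding = dilation  # REBNCONV uses padding = 1*dirate
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def receptive_field(dilations, kernel=3):
    """Receptive field of a chain of stride-1 convs with the given dilation rates."""
    rf = 1
    for d in dilations:
        rf += (kernel - 1) * d
    return rf


# Longest path through RSU4F:
# rebnconvin, rebnconv1, rebnconv2, rebnconv3, rebnconv4, rebnconv3d, rebnconv2d, rebnconv1d
rsu4f_path = [1, 1, 2, 4, 8, 4, 2, 1]
print([conv_out(20, dilation=d) for d in (1, 2, 4, 8)])  # [20, 20, 20, 20]
print(receptive_field(rsu4f_path))                        # 47
```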
##### U^2-Net ####
class U2NET(nn.Module):
def __init__(self,in_ch=3,out_ch=1):
super(U2NET,self).__init__()
self.stage1 = RSU7(in_ch,32,64)
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage2 = RSU6(64,32,128)
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage3 = RSU5(128,64,256)
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage4 = RSU4(256,128,512)
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage5 = RSU4F(512,256,512)
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage6 = RSU4F(512,256,512)
# decoder
self.stage5d = RSU4F(1024,256,512)
self.stage4d = RSU4(1024,128,256)
self.stage3d = RSU5(512,64,128)
self.stage2d = RSU6(256,32,64)
self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
def forward(self,x):
hx = x
#stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
#stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
#stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
#stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
#stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
#stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6,hx5)
#-------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
#side output
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2,d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3,d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4,d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5,d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6,d1)
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
return d0, d1, d2, d3, d4, d5, d6
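Each U2NET decoder stage consumes the concatenation of an upsampled deeper map and the same-level encoder map. Its `in_ch` must therefore equal the sum of those two `out_ch` values, and each `side*` conv must match the channel count of the map it reads.

The dictionaries below are transcribed by hand from `U2NET.__init__`, with the input channel count written as 3 (the default). `check_u2net_channels` is a hypothetical helper: it only checks the channel bookkeeping and is not a model test.

```python
# (in_ch, mid_ch, out_ch) per stage, transcribed from U2NET.__init__ with in_ch=3
encoder = {1: (3, 32, 64), 2: (64, 32, 128), 3: (128, 64, 256),
           4: (256, 128, 512), 5: (512, 256, 512), 6: (512, 256, 512)}
decoder = {5: (1024, 256, 512), 4: (1024, 128, 256), 3: (512, 64, 128),
           2: (256, 32, 64), 1: (128, 16, 64)}
# input channels of side1..side6
side_in = {1: 64, 2: 64, 3: 128, 4: 256, 5: 512, 6: 512}


def check_u2net_channels(encoder, decoder, side_in):
    """Return a list of human-readable mismatches (empty if everything lines up)."""
    problems = []
    for k in range(5, 0, -1):
        deeper = encoder[6] if k == 5 else decoder[k + 1]
        expected = deeper[2] + encoder[k][2]  # upsampled deeper map + skip
        if decoder[k][0] != expected:
            problems.append(f"stage{k}d expects {expected} input channels, got {decoder[k][0]}")
    for k, ch in side_in.items():
        source = encoder[6] if k == 6 else decoder[k]  # side6 reads hx6, others read hx{k}d
        if source[2] != ch:
            problems.append(f"side{k} expects {source[2]} channels, got {ch}")
    return problems


print(check_u2net_channels(encoder, decoder, side_in))  # []
```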
### U^2-Net small ###
class U2NETP(nn.Module):
def __init__(self,in_ch=3,out_ch=1):
super(U2NETP,self).__init__()
self.stage1 = RSU7(in_ch,16,64)
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage2 = RSU6(64,16,64)
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage3 = RSU5(64,16,64)
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage4 = RSU4(64,16,64)
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage5 = RSU4F(64,16,64)
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage6 = RSU4F(64,16,64)
# decoder
self.stage5d = RSU4F(128,16,64)
self.stage4d = RSU4(128,16,64)
self.stage3d = RSU5(128,16,64)
self.stage2d = RSU6(128,16,64)
self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
def forward(self,x):
hx = x
#stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
#stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
#stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
#stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
#stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
#stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6,hx5)
#decoder
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
#side output
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2,d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3,d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4,d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5,d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6,d1)
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
return d0, d1, d2, d3, d4, d5, d6
def u2net(in_channels, num_classes):
model = U2NET(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
def u2net_small(in_channels, num_classes):
model = U2NETP(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = u2net(1,10)
# model.eval()
# input = torch.rand(2,1,128,128)
# output = model(input)
# output = output[1].data.cpu().numpy()
# # print(output)
# print(output.shape)
================================================
FILE: models/networks_2d/unet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
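With `init_type='kaiming'`, `init_weights` calls `init.kaiming_normal_(..., a=0, mode='fan_in')`. This draws weights from a zero-mean normal distribution whose standard deviation depends only on the layer's fan-in. On this branch, the `gain` argument does not affect Conv/Linear weights, although BatchNorm weights are still drawn from N(1, gain).

The hypothetical helper below reproduces that standard deviation for a Conv2d, so the scale of the initial weights can be checked by hand. The formula is gain / sqrt(fan_in), with gain = sqrt(2 / (1 + a^2)).

```python
import math


def kaiming_normal_std(in_channels, kernel_size, a=0.0):
    """Std of kaiming_normal_(mode='fan_in') for a Conv2d with a square kernel.

    fan_in = in_channels * k * k; gain = sqrt(2 / (1 + a**2)) (a=0 is plain ReLU).
    """
    fan_in = in_channels * kernel_size * kernel_size
    gain = math.sqrt(2.0 / (1 + a ** 2))
    return gain / math.sqrt(fan_in)


# Second conv of conv_block(64, 128) in U_Net: Conv2d(128, 128, 3)
print(round(kaiming_normal_std(128, 3), 4))  # 0.0417
# Final 1x1 classifier Conv2d(64, num_classes, 1)
print(round(kaiming_normal_std(64, 1), 4))   # 0.1768
```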
class conv_block(nn.Module):
def __init__(self, ch_in, ch_out):
super(conv_block, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True),
nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.conv(x)
return x
class up_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(up_conv, self).__init__()
self.up = nn.Sequential(
nn.Upsample(scale_factor=2),
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.up(x)
return x
class Recurrent_block(nn.Module):
def __init__(self, ch_out, t=2):
super(Recurrent_block, self).__init__()
self.t = t
self.ch_out = ch_out
self.conv = nn.Sequential(
nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self, x):
for i in range(self.t):
if i == 0:
x1 = self.conv(x)
x1 = self.conv(x + x1)
return x1
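`Recurrent_block.forward` unrolls a recurrence that uses a single shared conv. It first computes x1 = f(x), then applies x1 = f(x + x1) once per step for t steps. A block with t=2 therefore applies the same weights three times.

The sketch below mirrors that loop, with a plain Python function standing in for the conv. The scalar `f` is purely illustrative; it makes the call count, and the re-injection of the input at each step, easy to see.

```python
def recurrent_unroll(f, x, t=2):
    """Same control flow as Recurrent_block.forward, with f in place of self.conv."""
    x1 = None
    for i in range(t):
        if i == 0:
            x1 = f(x)
        x1 = f(x + x1)  # the block input x is re-injected at every step
    return x1


calls = []


def f(v):
    """Toy stand-in for the shared conv: records its input and halves it."""
    calls.append(v)
    return 0.5 * v


out = recurrent_unroll(f, 1.0, t=2)
print(out, calls)  # 0.875 [1.0, 1.5, 1.75]
```

`RRCNN_block` stacks two such blocks after a 1x1 projection and adds the projection back as a residual. With t=2, each of the two recurrent convs is applied t + 1 = 3 times per forward pass.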
class RRCNN_block(nn.Module):
def __init__(self, ch_in, ch_out, t=2):
super(RRCNN_block, self).__init__()
self.RCNN = nn.Sequential(
Recurrent_block(ch_out, t=t),
Recurrent_block(ch_out, t=t)
)
self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)
def forward(self, x):
x = self.Conv_1x1(x)
x1 = self.RCNN(x)
return x + x1
class single_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(single_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.conv(x)
return x
class Attention_block(nn.Module):
def __init__(self, F_g, F_l, F_int):
super(Attention_block, self).__init__()
self.W_g = nn.Sequential(
nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(F_int)
)
self.W_x = nn.Sequential(
nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(F_int)
)
self.psi = nn.Sequential(
nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(1),
nn.Sigmoid()
)
self.relu = nn.ReLU(inplace=True)
def forward(self, g, x):
g1 = self.W_g(g)
x1 = self.W_x(x)
psi = self.relu(g1 + x1)
psi = self.psi(psi)
return x * psi
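`Attention_block` is an additive attention gate:

1. Project the gating signal `g` and the skip features `x` to `F_int` channels with 1x1 convs.
2. Add the two projections and apply a ReLU.
3. Reduce to a single channel and apply a sigmoid.
4. Use the resulting per-pixel weight `psi` in (0, 1) to scale `x`.

The numpy sketch below works on single (C, H, W) arrays and writes each 1x1 conv as a channel-wise matrix product. It assumes `g` and `x` already share H and W, as they do in AttU_Net when the input size is divisible by 16. It deliberately omits the BatchNorm layers and conv biases, so it is not a drop-in replacement for the module.

```python
import numpy as np


def attention_gate(g, x, W_g, W_x, w_psi):
    """Additive attention gate on single (C, H, W) feature maps (no BatchNorm, no biases)."""
    g1 = np.einsum("oc,chw->ohw", W_g, g)        # (F_int, H, W)
    x1 = np.einsum("oc,chw->ohw", W_x, x)        # (F_int, H, W)
    a = np.maximum(g1 + x1, 0.0)                 # ReLU
    logits = np.einsum("oc,chw->ohw", w_psi, a)  # (1, H, W)
    psi = 1.0 / (1.0 + np.exp(-logits))          # sigmoid -> per-pixel weight in (0, 1)
    return x * psi, psi                          # psi broadcasts over the channels of x


rng = np.random.default_rng(0)
F_g, F_l, F_int, H, W = 8, 8, 4, 5, 5
g = rng.standard_normal((F_g, H, W))
x = rng.standard_normal((F_l, H, W))
out, psi = attention_gate(g, x, rng.standard_normal((F_int, F_g)),
                          rng.standard_normal((F_int, F_l)),
                          rng.standard_normal((1, F_int)))
print(out.shape, psi.shape)  # (8, 5, 5) (1, 5, 5)
```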
class U_Net(nn.Module):
def __init__(self, in_channels=3, num_classes=1):
super(U_Net, self).__init__()
self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.Conv1 = conv_block(ch_in=in_channels, ch_out=64)
self.Conv2 = conv_block(ch_in=64, ch_out=128)
self.Conv3 = conv_block(ch_in=128, ch_out=256)
self.Conv4 = conv_block(ch_in=256, ch_out=512)
self.Conv5 = conv_block(ch_in=512, ch_out=1024)
self.Up5 = up_conv(ch_in=1024, ch_out=512)
self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
self.Up4 = up_conv(ch_in=512, ch_out=256)
self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
self.Up3 = up_conv(ch_in=256, ch_out=128)
self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
self.Up2 = up_conv(ch_in=128, ch_out=64)
self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
self.Conv_1x1 = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)
def forward(self, x):
# encoding path
x1 = self.Conv1(x)
x2 = self.Maxpool(x1)
x2 = self.Conv2(x2)
x3 = self.Maxpool(x2)
x3 = self.Conv3(x3)
x4 = self.Maxpool(x3)
x4 = self.Conv4(x4)
x5 = self.Maxpool(x4)
x5 = self.Conv5(x5)
# decoding + concat path
d5 = self.Up5(x5)
d5 = torch.cat((x4, d5), dim=1)
d5 = self.Up_conv5(d5)
d4 = self.Up4(d5)
d4 = torch.cat((x3, d4), dim=1)
d4 = self.Up_conv4(d4)
d3 = self.Up3(d4)
d3 = torch.cat((x2, d3), dim=1)
d3 = self.Up_conv3(d3)
d2 = self.Up2(d3)
d2 = torch.cat((x1, d2), dim=1)
d2 = self.Up_conv2(d2)
d1 = self.Conv_1x1(d2)
# outputs = []
# outputs.append(d1)
# return outputs
return d1
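U_Net and the recurrent and attention variants that follow it share two properties:

- they pool with plain `nn.MaxPool2d(kernel_size=2, stride=2)`, which floors odd sizes;
- they upsample with a fixed `nn.Upsample(scale_factor=2)`.

Unlike U^2-Net's `_upsample_like`, nothing resizes the decoder map to match the skip connection. The concatenation in the decoder therefore only succeeds when height and width are divisible by 2^4 = 16. The pure-Python check below reproduces that arithmetic; `unet_shapes_ok` is a hypothetical helper name.

```python
def unet_shapes_ok(n, depth=4):
    """True if every skip concatenation in U_Net lines up for a spatial size n."""
    enc = [n]
    for _ in range(depth):
        enc.append(enc[-1] // 2)  # nn.MaxPool2d(kernel_size=2, stride=2): floor division
    # Each decoder step doubles the deeper map (nn.Upsample(scale_factor=2)) and
    # concatenates it with the encoder map of the level above.
    return all(enc[i + 1] * 2 == enc[i] for i in range(depth))


print(unet_shapes_ok(128), unet_shapes_ok(100))  # True False
```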
class R2U_Net(nn.Module):
def __init__(self, in_channels=3, num_classes=1, t=2):
super(R2U_Net, self).__init__()
self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.Upsample = nn.Upsample(scale_factor=2)
self.RRCNN1 = RRCNN_block(ch_in=in_channels, ch_out=64, t=t)
self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)
self.Up5 = up_conv(ch_in=1024, ch_out=512)
self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
self.Up4 = up_conv(ch_in=512, ch_out=256)
self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
self.Up3 = up_conv(ch_in=256, ch_out=128)
self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
self.Up2 = up_conv(ch_in=128, ch_out=64)
self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)
self.Conv_1x1 = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)
def forward(self, x):
# encoding path
x1 = self.RRCNN1(x)
x2 = self.Maxpool(x1)
x2 = self.RRCNN2(x2)
x3 = self.Maxpool(x2)
x3 = self.RRCNN3(x3)
x4 = self.Maxpool(x3)
x4 = self.RRCNN4(x4)
x5 = self.Maxpool(x4)
x5 = self.RRCNN5(x5)
# decoding + concat path
d5 = self.Up5(x5)
d5 = torch.cat((x4, d5), dim=1)
d5 = self.Up_RRCNN5(d5)
d4 = self.Up4(d5)
d4 = torch.cat((x3, d4), dim=1)
d4 = self.Up_RRCNN4(d4)
d3 = self.Up3(d4)
d3 = torch.cat((x2, d3), dim=1)
d3 = self.Up_RRCNN3(d3)
d2 = self.Up2(d3)
d2 = torch.cat((x1, d2), dim=1)
d2 = self.Up_RRCNN2(d2)
d1 = self.Conv_1x1(d2)
# outputs = []
# outputs.append(d1)
# return outputs
return d1
class AttU_Net(nn.Module):
def __init__(self, in_channels=3, num_classes=1):
super(AttU_Net, self).__init__()
self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.Conv1 = conv_block(ch_in=in_channels, ch_out=64)
self.Conv2 = conv_block(ch_in=64, ch_out=128)
self.Conv3 = conv_block(ch_in=128, ch_out=256)
self.Conv4 = conv_block(ch_in=256, ch_out=512)
self.Conv5 = conv_block(ch_in=512, ch_out=1024)
self.Up5 = up_conv(ch_in=1024, ch_out=512)
self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
self.Up4 = up_conv(ch_in=512, ch_out=256)
self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
self.Up3 = up_conv(ch_in=256, ch_out=128)
self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
self.Up2 = up_conv(ch_in=128, ch_out=64)
self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
self.Conv_1x1 = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)
def forward(self, x):
# encoding path
x1 = self.Conv1(x)
x2 = self.Maxpool(x1)
x2 = self.Conv2(x2)
x3 = self.Maxpool(x2)
x3 = self.Conv3(x3)
x4 = self.Maxpool(x3)
x4 = self.Conv4(x4)
x5 = self.Maxpool(x4)
x5 = self.Conv5(x5)
# decoding + concat path
d5 = self.Up5(x5)
x4 = self.Att5(g=d5, x=x4)
d5 = torch.cat((x4, d5), dim=1)
d5 = self.Up_conv5(d5)
d4 = self.Up4(d5)
x3 = self.Att4(g=d4, x=x3)
d4 = torch.cat((x3, d4), dim=1)
d4 = self.Up_conv4(d4)
d3 = self.Up3(d4)
x2 = self.Att3(g=d3, x=x2)
d3 = torch.cat((x2, d3), dim=1)
d3 = self.Up_conv3(d3)
d2 = self.Up2(d3)
x1 = self.Att2(g=d2, x=x1)
d2 = torch.cat((x1, d2), dim=1)
d2 = self.Up_conv2(d2)
d1 = self.Conv_1x1(d2)
# outputs = []
# outputs.append(d1)
# return outputs
return d1
class R2AttU_Net(nn.Module):
def __init__(self, in_channels=3, num_classes=1, t=2):
super(R2AttU_Net, self).__init__()
self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.Upsample = nn.Upsample(scale_factor=2)
self.RRCNN1 = RRCNN_block(ch_in=in_channels, ch_out=64, t=t)
self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)
self.Up5 = up_conv(ch_in=1024, ch_out=512)
self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
self.Up4 = up_conv(ch_in=512, ch_out=256)
self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
self.Up3 = up_conv(ch_in=256, ch_out=128)
self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
self.Up2 = up_conv(ch_in=128, ch_out=64)
self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)
self.Conv_1x1 = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)
def forward(self, x):
# encoding path
x1 = self.RRCNN1(x)
x2 = self.Maxpool(x1)
x2 = self.RRCNN2(x2)
x3 = self.Maxpool(x2)
x3 = self.RRCNN3(x3)
x4 = self.Maxpool(x3)
x4 = self.RRCNN4(x4)
x5 = self.Maxpool(x4)
x5 = self.RRCNN5(x5)
# decoding + concat path
d5 = self.Up5(x5)
x4 = self.Att5(g=d5, x=x4)
d5 = torch.cat((x4, d5), dim=1)
d5 = self.Up_RRCNN5(d5)
d4 = self.Up4(d5)
x3 = self.Att4(g=d4, x=x3)
d4 = torch.cat((x3, d4), dim=1)
d4 = self.Up_RRCNN4(d4)
d3 = self.Up3(d4)
x2 = self.Att3(g=d3, x=x2)
d3 = torch.cat((x2, d3), dim=1)
d3 = self.Up_RRCNN3(d3)
d2 = self.Up2(d3)
x1 = self.Att2(g=d2, x=x1)
d2 = torch.cat((x1, d2), dim=1)
d2 = self.Up_RRCNN2(d2)
d1 = self.Conv_1x1(d2)
# outputs = []
# outputs.append(d1)
# return outputs
return d1
def unet(in_channels, num_classes):
model = U_Net(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
def r2_unet(in_channels, num_classes):
model = R2U_Net(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
def attention_unet(in_channels, num_classes):
model = AttU_Net(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
def r2_attention_unet(in_channels, num_classes):
model = R2AttU_Net(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = U_Net(1,10)
# model.eval()
# input = torch.rand(2,1,128,128)
# output = model(input)
# output = output[0].data.cpu().numpy()
# # print(output)
# print(output.shape)
================================================
FILE: models/networks_2d/unet_3plus.py
================================================
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import numpy as np
def weights_init_normal(m):
classname = m.__class__.__name__
#print(classname)
if classname.find('Conv') != -1:
init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('Linear') != -1:
init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)
def weights_init_xavier(m):
classname = m.__class__.__name__
#print(classname)
if classname.find('Conv') != -1:
init.xavier_normal_(m.weight.data, gain=1)
elif classname.find('Linear') != -1:
init.xavier_normal_(m.weight.data, gain=1)
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)
def weights_init_kaiming(m):
classname = m.__class__.__name__
#print(classname)
if classname.find('Conv') != -1:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif classname.find('Linear') != -1:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)
def weights_init_orthogonal(m):
classname = m.__class__.__name__
#print(classname)
if classname.find('Conv') != -1:
init.orthogonal_(m.weight.data, gain=1)
elif classname.find('Linear') != -1:
init.orthogonal_(m.weight.data, gain=1)
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)
def init_weights(net, init_type='normal'):
#print('initialization method [%s]' % init_type)
if init_type == 'normal':
net.apply(weights_init_normal)
elif init_type == 'xavier':
net.apply(weights_init_xavier)
elif init_type == 'kaiming':
net.apply(weights_init_kaiming)
elif init_type == 'orthogonal':
net.apply(weights_init_orthogonal)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
class unetConv2(nn.Module):
def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
super(unetConv2, self).__init__()
self.n = n
self.ks = ks
self.stride = stride
self.padding = padding
s = stride
p = padding
if is_batchnorm:
for i in range(1, n + 1):
conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
nn.BatchNorm2d(out_size),
nn.ReLU(inplace=True), )
setattr(self, 'conv%d' % i, conv)
in_size = out_size
else:
for i in range(1, n + 1):
conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
nn.ReLU(inplace=True), )
setattr(self, 'conv%d' % i, conv)
in_size = out_size
# initialise the blocks
for m in self.children():
init_weights(m, init_type='kaiming')
def forward(self, inputs):
x = inputs
for i in range(1, self.n + 1):
conv = getattr(self, 'conv%d' % i)
x = conv(x)
return x
'''
UNet 3+
'''
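UNet_3Plus, defined next, feeds every decoder stage from all five scales:

- shallower encoder maps are max-pooled down to the stage's resolution;
- the same-level encoder map is used as-is;
- deeper decoder maps, and the bottom encoder output `hd5`, are bilinearly upsampled.

Each branch is projected to `CatChannels = 16` channels, so a fused stage has `UpChannels = 5 * 16 = 80` channels. The hypothetical helper below lists the resampling factor for each branch of a given decoder stage. Its output can be compared against the `h*_PT_hd*` and `hd*_UT_hd*` layers in `__init__`.

```python
CAT_CHANNELS = 16  # filters[0] in UNet_3Plus


def full_scale_plan(stage, levels=5):
    """(source, operation, factor) for each input of decoder stage `stage` (1..4)."""
    plan = []
    for j in range(1, levels + 1):
        if j < stage:
            plan.append((f"h{j}", "pool", 2 ** (stage - j)))  # shallower encoder map
        elif j == stage:
            plan.append((f"h{j}", "cat", 1))                  # same scale, no resampling
        else:
            plan.append((f"hd{j}", "up", 2 ** (j - stage)))   # deeper decoder map / hd5
    return plan


print(full_scale_plan(4))
print(CAT_CHANNELS * len(full_scale_plan(4)))  # 80 == UpChannels
```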
class UNet_3Plus(nn.Module):
def __init__(self, in_channels, num_classes):
super(UNet_3Plus, self).__init__()
feature_scale = 4
is_deconv = True
is_batchnorm = True
self.is_deconv = is_deconv
self.in_channels = in_channels
self.is_batchnorm = is_batchnorm
self.feature_scale = feature_scale
filters = [16, 32, 64, 128, 256]
## -------------Encoder--------------
self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
self.maxpool4 = nn.MaxPool2d(kernel_size=2)
self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)
## -------------Decoder--------------
self.CatChannels = filters[0]
self.CatBlocks = 5
self.UpChannels = self.CatChannels * self.CatBlocks
'''stage 4d'''
# h1->320*320, hd4->40*40, max-pool by 8x
self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd4_relu = nn.ReLU(inplace=True)
# h2->160*160, hd4->40*40, max-pool by 4x
self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_PT_hd4_relu = nn.ReLU(inplace=True)
# h3->80*80, hd4->40*40, max-pool by 2x
self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h3_PT_hd4_relu = nn.ReLU(inplace=True)
# h4->40*40, hd4->40*40, Concatenation
self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)
self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd4->40*40, upsample by 2x
self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)
# fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # UpChannels = CatChannels * CatBlocks = 80
self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu4d_1 = nn.ReLU(inplace=True)
'''stage 3d'''
# h1->320*320, hd3->80*80, max-pool by 4x
self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd3_relu = nn.ReLU(inplace=True)
# h2->160*160, hd3->80*80, max-pool by 2x
self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_PT_hd3_relu = nn.ReLU(inplace=True)
# h3->80*80, hd3->80*80, Concatenation
self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd3->80*80, upsample by 2x
self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd3->80*80, upsample by 4x
self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)
# fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # UpChannels = 80
self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu3d_1 = nn.ReLU(inplace=True)
'''stage 2d '''
# h1->320*320, hd2->160*160, max-pool by 2x
self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd2_relu = nn.ReLU(inplace=True)
# h2->160*160, hd2->160*160, Concatenation
self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)
# hd3->80*80, hd2->160*160, upsample by 2x
self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd2->160*160, upsample by 4x
self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd2->160*160, upsample by 8x
self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)
# fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # UpChannels = 80
self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu2d_1 = nn.ReLU(inplace=True)
'''stage 1d'''
# h1->320*320, hd1->320*320, Concatenation
self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)
# hd2->160*160, hd1->320*320, upsample by 2x
self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)
# hd3->80*80, hd1->320*320, upsample by 4x
self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd1->320*320, upsample by 8x
self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd1->320*320, upsample by 16x
self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)
# fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # UpChannels = 80
self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu1d_1 = nn.ReLU(inplace=True)
# output
self.outconv1 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1)
# initialise weights
for m in self.modules():
if isinstance(m, nn.Conv2d):
init_weights(m, init_type='kaiming')
elif isinstance(m, nn.BatchNorm2d):
init_weights(m, init_type='kaiming')
def forward(self, inputs):
## -------------Encoder-------------
h1 = self.conv1(inputs) # h1->320*320*64
h2 = self.maxpool1(h1)
h2 = self.conv2(h2) # h2->160*160*128
h3 = self.maxpool2(h2)
h3 = self.conv3(h3) # h3->80*80*256
h4 = self.maxpool3(h3)
h4 = self.conv4(h4) # h4->40*40*512
h5 = self.maxpool4(h4)
hd5 = self.conv5(h5) # h5->20*20*1024
## -------------Decoder-------------
h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(
torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels
h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(
torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels
h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))
hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(
torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels
h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(
torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels
d1 = self.outconv1(hd1) # d1->320*320*n_classes
return d1
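# Every decoder stage above follows one rule: at stage d, the shallower encoder
# maps h1..h(d-1) are max-pooled by 2**(d - level), h(d) is used at its own scale,
# and the deeper maps hd(d+1)..hd5 are bilinearly upsampled by 2**(level - d);
# each branch is then projected to CatChannels by a 3x3 conv. The helper below is
# a hypothetical sketch (not part of XNet) that only enumerates this rule so it
# can be checked against the layer definitions.
```python
def full_scale_plan(stage, num_levels=5):
    """List (source, op, factor) for each of the inputs fused at decoder stage `stage`.

    Shallower encoder maps are max-pooled, the same-scale encoder map is not
    resampled, and deeper maps (hd5 from the bottleneck, or earlier decoder
    outputs) are bilinearly upsampled.
    """
    plan = []
    for level in range(1, num_levels + 1):
        if level < stage:
            plan.append((f"h{level}", "maxpool", 2 ** (stage - level)))
        elif level == stage:
            plan.append((f"h{level}", "identity", 1))
        else:
            plan.append((f"hd{level}", "upsample", 2 ** (level - stage)))
    return plan


for d in (4, 3, 2, 1):
    print(f"stage {d}d:", full_scale_plan(d))
```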
'''
UNet 3+ with deep supervision
'''
class UNet_3Plus_DeepSup(nn.Module):
def __init__(self, in_channels=3, num_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True):
super(UNet_3Plus_DeepSup, self).__init__()
self.is_deconv = is_deconv
self.in_channels = in_channels
self.is_batchnorm = is_batchnorm
self.feature_scale = feature_scale
filters = [32, 64, 128, 256, 512]
## -------------Encoder--------------
self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
self.maxpool4 = nn.MaxPool2d(kernel_size=2)
self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)
## -------------Decoder--------------
self.CatChannels = filters[0]
self.CatBlocks = 5
self.UpChannels = self.CatChannels * self.CatBlocks
'''stage 4d'''
# h1->320*320, hd4->40*40, Pooling 8 times
self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd4_relu = nn.ReLU(inplace=True)
# h2->160*160, hd4->40*40, Pooling 4 times
self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_PT_hd4_relu = nn.ReLU(inplace=True)
# h3->80*80, hd4->40*40, Pooling 2 times
self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h3_PT_hd4_relu = nn.ReLU(inplace=True)
# h4->40*40, hd4->40*40, Concatenation
self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)
self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd4->40*40, Upsample 2 times
self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)
# fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu4d_1 = nn.ReLU(inplace=True)
'''stage 3d'''
# h1->320*320, hd3->80*80, Pooling 4 times
self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd3_relu = nn.ReLU(inplace=True)
# h2->160*160, hd3->80*80, Pooling 2 times
self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_PT_hd3_relu = nn.ReLU(inplace=True)
# h3->80*80, hd3->80*80, Concatenation
self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd3->80*80, Upsample 2 times
self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd3->80*80, Upsample 4 times
self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)
# fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu3d_1 = nn.ReLU(inplace=True)
'''stage 2d '''
# h1->320*320, hd2->160*160, Pooling 2 times
self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd2_relu = nn.ReLU(inplace=True)
# h2->160*160, hd2->160*160, Concatenation
self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)
# hd3->80*80, hd2->160*160, Upsample 2 times
self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd2->160*160, Upsample 4 times
self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd2->160*160, Upsample 8 times
self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)
# fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu2d_1 = nn.ReLU(inplace=True)
'''stage 1d'''
# h1->320*320, hd1->320*320, Concatenation
self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)
# hd2->160*160, hd1->320*320, Upsample 2 times
self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)
# hd3->80*80, hd1->320*320, Upsample 4 times
self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd1->320*320, Upsample 8 times
self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd1->320*320, Upsample 16 times
self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)
# fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu1d_1 = nn.ReLU(inplace=True)
# -------------Bilinear Upsampling--------------
self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True) # defined but not used in forward()
self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
# DeepSup
self.outconv1 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1)
self.outconv2 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1)
self.outconv3 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1)
self.outconv4 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1)
self.outconv5 = nn.Conv2d(filters[4], num_classes, 3, padding=1)
# initialise weights
for m in self.modules():
if isinstance(m, nn.Conv2d):
init_weights(m, init_type='kaiming')
elif isinstance(m, nn.BatchNorm2d):
init_weights(m, init_type='kaiming')
def forward(self, inputs):
## -------------Encoder-------------
h1 = self.conv1(inputs) # h1->320*320*32
h2 = self.maxpool1(h1)
h2 = self.conv2(h2) # h2->160*160*64
h3 = self.maxpool2(h2)
h3 = self.conv3(h3) # h3->80*80*128
h4 = self.maxpool3(h3)
h4 = self.conv4(h4) # h4->40*40*256
h5 = self.maxpool4(h4)
hd5 = self.conv5(h5) # hd5->20*20*512
## -------------Decoder-------------
h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(
torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels
h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(
torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels
h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))
hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(
torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels
h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(
torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels
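d5 = self.outconv5(hd5)
d5 = self.upscore5(d5) # 1/16 scale -> full input resolution
d4 = self.outconv4(hd4)
d4 = self.upscore4(d4) # 1/8 scale -> full input resolution
d3 = self.outconv3(hd3)
d3 = self.upscore3(d3) # 1/4 scale -> full input resolution
d2 = self.outconv2(hd2)
d2 = self.upscore2(d2) # 1/2 scale -> full input resolution
d1 = self.outconv1(hd1) # already at full input resolution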
return d1, d2, d3, d4, d5
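# UNet_3Plus_DeepSup returns five logit maps, all upsampled to the input
# resolution, so every head can be supervised with the same label map. A minimal
# sketch of combining them, assuming plain cross-entropy with equal weights as a
# stand-in loss (the UNet 3+ paper and XNet's own training code may use different
# losses); `deep_supervision_loss` is a hypothetical name.
```python
import torch
import torch.nn.functional as F


def deep_supervision_loss(outputs, target, weights=None):
    """Weighted sum of per-head cross-entropy losses.

    outputs: sequence of logits, each (B, C, H, W) at the label resolution
    target:  (B, H, W) integer class map
    weights: optional per-head weights (defaults to equal weights summing to 1)
    """
    if weights is None:
        weights = [1.0 / len(outputs)] * len(outputs)
    losses = [w * F.cross_entropy(o, target) for w, o in zip(weights, outputs)]
    return torch.stack(losses).sum()
```
With equal weights, feeding five identical heads gives the same value as a single cross-entropy term, which makes the helper easy to sanity-check.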
'''
UNet 3+ with deep supervision and class-guided module
'''
class UNet_3Plus_DeepSup_CGM(nn.Module):
def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True):
super(UNet_3Plus_DeepSup_CGM, self).__init__()
self.is_deconv = is_deconv
self.in_channels = in_channels
self.is_batchnorm = is_batchnorm
self.feature_scale = feature_scale
filters = [64, 128, 256, 512, 1024]
## -------------Encoder--------------
self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
self.maxpool4 = nn.MaxPool2d(kernel_size=2)
self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)
## -------------Decoder--------------
self.CatChannels = filters[0]
self.CatBlocks = 5
self.UpChannels = self.CatChannels * self.CatBlocks
'''stage 4d'''
# h1->320*320, hd4->40*40, Pooling 8 times
self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd4_relu = nn.ReLU(inplace=True)
# h2->160*160, hd4->40*40, Pooling 4 times
self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_PT_hd4_relu = nn.ReLU(inplace=True)
# h3->80*80, hd4->40*40, Pooling 2 times
self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h3_PT_hd4_relu = nn.ReLU(inplace=True)
# h4->40*40, hd4->40*40, Concatenation
self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)
self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd4->40*40, Upsample 2 times
self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)
# fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu4d_1 = nn.ReLU(inplace=True)
'''stage 3d'''
# h1->320*320, hd3->80*80, Pooling 4 times
self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd3_relu = nn.ReLU(inplace=True)
# h2->160*160, hd3->80*80, Pooling 2 times
self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_PT_hd3_relu = nn.ReLU(inplace=True)
# h3->80*80, hd3->80*80, Concatenation
self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd3->80*80, Upsample 2 times
self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd3->80*80, Upsample 4 times
self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)
# fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu3d_1 = nn.ReLU(inplace=True)
'''stage 2d '''
# h1->320*320, hd2->160*160, Pooling 2 times
self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd2_relu = nn.ReLU(inplace=True)
# h2->160*160, hd2->160*160, Concatenation
self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)
# hd3->80*80, hd2->160*160, Upsample 2 times
self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd2->160*160, Upsample 4 times
self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd2->160*160, Upsample 8 times
self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)
# fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu2d_1 = nn.ReLU(inplace=True)
'''stage 1d'''
# h1->320*320, hd1->320*320, Concatenation
self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)
# hd2->160*160, hd1->320*320, Upsample 2 times
self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)
# hd3->80*80, hd1->320*320, Upsample 4 times
self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd1->320*320, Upsample 8 times
self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd1->320*320, Upsample 16 times
self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)
# fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu1d_1 = nn.ReLU(inplace=True)
# -------------Bilinear Upsampling--------------
self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True) # defined but not used in forward()
self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
# DeepSup
self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
self.outconv2 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
self.outconv3 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
self.outconv4 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
self.outconv5 = nn.Conv2d(filters[4], n_classes, 3, padding=1)
self.cls = nn.Sequential(
nn.Dropout(p=0.5),
nn.Conv2d(filters[4], 2, 1),
nn.AdaptiveMaxPool2d(1),
nn.Sigmoid())
# initialise weights
for m in self.modules():
if isinstance(m, nn.Conv2d):
init_weights(m, init_type='kaiming')
elif isinstance(m, nn.BatchNorm2d):
init_weights(m, init_type='kaiming')
def dotProduct(self, seg, cls):
B, N, H, W = seg.size()
seg = seg.view(B, N, H * W)
final = torch.einsum("ijk,ij->ijk", [seg, cls])
final = final.view(B, N, H, W)
return final
def forward(self, inputs):
## -------------Encoder-------------
h1 = self.conv1(inputs) # h1->320*320*64
h2 = self.maxpool1(h1)
h2 = self.conv2(h2) # h2->160*160*128
h3 = self.maxpool2(h2)
h3 = self.conv3(h3) # h3->80*80*256
h4 = self.maxpool3(h3)
h4 = self.conv4(h4) # h4->40*40*512
h5 = self.maxpool4(h4)
hd5 = self.conv5(h5) # h5->20*20*1024
# -------------Classification-------------
cls_branch = self.cls(hd5).squeeze(3).squeeze(2) # (B,N,1,1)->(B,N)
cls_branch_max = cls_branch.argmax(dim=1)
cls_branch_max = cls_branch_max[:, np.newaxis].float()
## -------------Decoder-------------
h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels
h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels
h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))
hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels
h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels
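d5 = self.outconv5(hd5)
d5 = self.upscore5(d5) # 1/16 scale -> full input resolution
d4 = self.outconv4(hd4)
d4 = self.upscore4(d4) # 1/8 scale -> full input resolution
d3 = self.outconv3(hd3)
d3 = self.upscore3(d3) # 1/4 scale -> full input resolution
d2 = self.outconv2(hd2)
d2 = self.upscore2(d2) # 1/2 scale -> full input resolution
d1 = self.outconv1(hd1) # already at full input resolution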
d1 = self.dotProduct(d1, cls_branch_max)
d2 = self.dotProduct(d2, cls_branch_max)
d3 = self.dotProduct(d3, cls_branch_max)
d4 = self.dotProduct(d4, cls_branch_max)
d5 = self.dotProduct(d5, cls_branch_max)
return d1, d2, d3, d4, d5
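# dotProduct multiplies every class map of sample b by a per-sample scalar taken
# from `cls`. With cls_branch_max of shape (B, 1) holding the argmax (0 or 1) of
# the two-way classifier, the whole prediction of a sample is zeroed when the
# classifier picks class 0. The same gating can be written with plain
# broadcasting; the sketch below (hypothetical `class_gate`) checks that it
# matches the einsum form for a (B, N) gate and shows the (B, 1) zeroing case.
```python
import torch


def class_gate(seg, cls):
    """Scale each class map in `seg` (B, N, H, W) by a per-sample gate.

    `cls` may be (B, N), one gate per class, or (B, 1), one gate shared by all
    classes (as with the argmax-based cls_branch_max).
    """
    b = seg.size(0)
    return seg * cls.view(b, -1, 1, 1)
```
Broadcasting avoids the reshape round trip through (B, N, H*W) and makes the shared (B, 1) case explicit.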
def unet_3plus(in_channels, num_classes):
model = UNet_3Plus(in_channels, num_classes)
return model
def unet_3plus_ds(in_channels, num_classes):
model = UNet_3Plus_DeepSup(in_channels, num_classes)
return model
def unet_3plus_ds_cgm(in_channels, num_classes):
model = UNet_3Plus_DeepSup_CGM(in_channels, num_classes)
return model
if __name__ == '__main__':
model = unet_3plus_ds_cgm(1,10)
model.eval()
input = torch.rand(2,1,128,128)
output = model(input)
output = output[0].data.cpu().numpy()
# print(output)
print(output.shape)
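# All five branches of a decoder stage are concatenated with torch.cat, so their
# spatial sizes must agree exactly. A sketch with hypothetical helpers, assuming
# unetConv2 preserves spatial size (3x3 convs with padding=1), that replays the
# size arithmetic: nn.MaxPool2d(kernel_size=2) in the encoder floors, the
# ceil_mode=True poolings on the skip paths round up, and
# nn.Upsample(scale_factor=s) multiplies by s. Under these assumptions H and W
# must be multiples of 16: 128 (used in the __main__ block) and 320 work, 100 does not.
```python
def encoder_sizes(size, levels=5):
    """Spatial size of h1..h5 for a square input of side `size`."""
    sizes = [size]
    for _ in range(levels - 1):
        sizes.append(sizes[-1] // 2)  # nn.MaxPool2d(kernel_size=2) floors
    return sizes


def decoder_is_consistent(size):
    """True if every decoder stage receives five equally sized branches."""
    s = encoder_sizes(size)
    for stage in range(1, 5):  # decoder stages hd1..hd4
        target = s[stage - 1]
        for level in range(1, 6):
            if level < stage:
                k = 2 ** (stage - level)
                got = -(-s[level - 1] // k)  # ceil_mode=True pooling rounds up
            elif level == stage:
                got = s[level - 1]
            else:
                got = s[level - 1] * 2 ** (level - stage)  # bilinear Upsample
            if got != target:
                return False
    return True


print(encoder_sizes(320))  # [320, 160, 80, 40, 20]
print(decoder_is_consistent(128), decoder_is_consistent(100))  # True False
```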
================================================
FILE: models/networks_2d/unet_cct.py
================================================
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.uniform import Uniform
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
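# init_func dispatches on the class *name*: any module whose name contains
# 'Conv' or 'Linear' (ConvTranspose2d included) gets the chosen initialisation,
# provided it has a `weight` attribute. That hasattr check is what keeps
# container modules such as ConvBlock (no `weight` of its own) out of the first
# branch. A pure-Python mirror of the dispatch (hypothetical `init_kind`; it
# takes the class name and a has_weight flag instead of a real module):
```python
def init_kind(classname, has_weight=True):
    """Return which branch of init_func a module with this class name would hit."""
    if has_weight and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
        return 'conv_or_linear'
    if classname.find('BatchNorm2d') != -1:
        return 'batchnorm2d'
    return None  # module is left untouched
```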
class ConvBlock(nn.Module):
"""two convolution layers with batch norm and leaky relu"""
def __init__(self, in_channels, out_channels, dropout_p):
super(ConvBlock, self).__init__()
self.conv_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(),
nn.Dropout(dropout_p),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU()
)
def forward(self, x):
return self.conv_conv(x)
class DownBlock(nn.Module):
"""Downsampling followed by ConvBlock"""
def __init__(self, in_channels, out_channels, dropout_p):
super(DownBlock, self).__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
ConvBlock(in_channels, out_channels, dropout_p)
)
def forward(self, x):
return self.maxpool_conv(x)
class UpBlock(nn.Module):
"""Upssampling followed by ConvBlock"""
def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,
bilinear=True):
super(UpBlock, self).__init__()
self.bilinear = bilinear
if bilinear:
self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
self.up = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(
in_channels1, in_channels2, kernel_size=2, stride=2)
self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)
def forward(self, x1, x2):
if self.bilinear:
x1 = self.conv1x1(x1)
x1 = self.up(x1)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
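# Channel flow through UpBlock: x1 (in_channels1) is first mapped to in_channels2
# channels, either by a 1x1 conv plus bilinear upsampling or by a stride-2
# ConvTranspose2d, then concatenated with the skip x2 (in_channels2), so the
# ConvBlock always sees 2 * in_channels2 inputs. A sketch (hypothetical helper)
# applying this to Decoder.up1..up4 with UNet_CCT's feature_chns:
```python
def decoder_plan(feature_chns=(16, 32, 64, 128, 256)):
    """Return (in_channels1, in_channels2, concat_channels, out_channels) for up1..up4."""
    plan = []
    for deep in range(4, 0, -1):  # up1: level 4 -> 3, ..., up4: level 1 -> 0
        c1, c2 = feature_chns[deep], feature_chns[deep - 1]
        # x1 is mapped c1 -> c2, then concatenated with the c2-channel skip
        plan.append((c1, c2, 2 * c2, c2))
    return plan


for row in decoder_plan():
    print(row)
```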
class Encoder(nn.Module):
def __init__(self, params):
super(Encoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
self.dropout = self.params['dropout']
assert (len(self.ft_chns) == 5)
self.in_conv = ConvBlock(
self.in_chns, self.ft_chns[0], self.dropout[0])
self.down1 = DownBlock(
self.ft_chns[0], self.ft_chns[1], self.dropout[1])
self.down2 = DownBlock(
self.ft_chns[1], self.ft_chns[2], self.dropout[2])
self.down3 = DownBlock(
self.ft_chns[2], self.ft_chns[3], self.dropout[3])
self.down4 = DownBlock(
self.ft_chns[3], self.ft_chns[4], self.dropout[4])
def forward(self, x):
x0 = self.in_conv(x)
x1 = self.down1(x0)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
return [x0, x1, x2, x3, x4]
class Decoder(nn.Module):
def __init__(self, params):
super(Decoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
assert (len(self.ft_chns) == 5)
self.up1 = UpBlock(
self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
self.up2 = UpBlock(
self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
self.up3 = UpBlock(
self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
self.up4 = UpBlock(
self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)
self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
kernel_size=3, padding=1)
def forward(self, feature):
x0 = feature[0]
x1 = feature[1]
x2 = feature[2]
x3 = feature[3]
x4 = feature[4]
x = self.up1(x4, x3)
x = self.up2(x, x2)
x = self.up3(x, x1)
x = self.up4(x, x0)
output = self.out_conv(x)
return output
def Dropout(x, p=0.3):
x = torch.nn.functional.dropout(x, p)
return x
def FeatureDropout(x):
attention = torch.mean(x, dim=1, keepdim=True)
max_val, _ = torch.max(attention.view(
x.size(0), -1), dim=1, keepdim=True)
threshold = max_val * np.random.uniform(0.7, 0.9)
threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention)
drop_mask = (attention < threshold).float()
x = x.mul(drop_mask)
return x
class FeatureNoise(nn.Module):
def __init__(self, uniform_range=0.3):
super(FeatureNoise, self).__init__()
self.uni_dist = Uniform(-uniform_range, uniform_range)
def feature_based_noise(self, x):
noise_vector = self.uni_dist.sample(
x.shape[1:]).to(x.device).unsqueeze(0)
x_noise = x.mul(noise_vector) + x
return x_noise
def forward(self, x):
x = self.feature_based_noise(x)
return x
class UNet_CCT(nn.Module):
def __init__(self, in_chns, class_num):
super(UNet_CCT, self).__init__()
params = {'in_chns': in_chns,
'feature_chns': [16, 32, 64, 128, 256],
'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
'class_num': class_num,
'bilinear': False,
'acti_func': 'relu'}
self.encoder = Encoder(params)
self.main_decoder = Decoder(params)
self.aux_decoder1 = Decoder(params)
self.aux_decoder2 = Decoder(params)
self.aux_decoder3 = Decoder(params)
def forward(self, x):
feature = self.encoder(x)
main_seg = self.main_decoder(feature)
aux1_feature = [FeatureNoise()(i) for i in feature]
aux_seg1 = self.aux_decoder1(aux1_feature)
aux2_feature = [Dropout(i) for i in feature]
aux_seg2 = self.aux_decoder2(aux2_feature)
aux3_feature = [FeatureDropout(i) for i in feature]
aux_seg3 = self.aux_decoder3(aux3_feature)
return main_seg, aux_seg1, aux_seg2, aux_seg3
def unet_cct(in_channels, num_classes):
model = UNet_CCT(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = unet_cct(1,10)
# model.eval()
# input = torch.rand(2,1,128,128)
# output, output1, output2, output3 = model(input)
# output = output.data.cpu().numpy()
# # print(output)
# print(output.shape)
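UNet_CCT passes the shared encoder features to three auxiliary decoders, each behind a different perturbation. The perturbations are multiplicative uniform noise (FeatureNoise), plain random dropout (Dropout), and attention-guided dropout that masks the most strongly activated positions (FeatureDropout). The sketch below re-implements the three perturbations as standalone functions for illustration. The function names are hypothetical, and the threshold ratio is drawn from torch's RNG instead of np.random. The sketch checks that every perturbation preserves the feature shape, which is what lets all decoders share one architecture.

```python
import torch
import torch.nn.functional as F

def feature_noise(x, uniform_range=0.3):
    # Multiplicative uniform noise, x + x * n with n ~ U(-r, r); one noise tensor of
    # shape x.shape[1:] is shared across the batch, as in FeatureNoise.feature_based_noise.
    noise = torch.empty(x.shape[1:], device=x.device).uniform_(-uniform_range, uniform_range)
    return x + x * noise.unsqueeze(0)

def random_dropout(x, p=0.3):
    # F.dropout defaults to training=True, so this perturbs even in eval mode,
    # just like the module-level Dropout helper.
    return F.dropout(x, p)

def feature_dropout(x, low=0.7, high=0.9):
    # Zero every spatial position whose channel-mean activation reaches a random
    # fraction (in [low, high)) of the per-sample maximum.
    attention = x.mean(dim=1, keepdim=True)                                # (N, 1, H, W)
    max_val = attention.view(x.size(0), -1).max(dim=1, keepdim=True).values  # (N, 1)
    ratio = torch.empty(1).uniform_(low, high).item()
    threshold = (max_val * ratio).view(x.size(0), 1, 1, 1)
    mask = (attention < threshold).float()
    return x * mask
```

Each perturbed copy can be fed to its own decoder. With these perturbations, the peak-attention location of every sample is always masked by `feature_dropout`.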
================================================
FILE: models/networks_2d/unet_plusplus.py
================================================
import torch
from torch import nn
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
class VGGBlock(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels):
super().__init__()
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
self.bn1 = nn.BatchNorm2d(middle_channels)
self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
return out
class NestedUNet(nn.Module):
def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
super().__init__()
nb_filter = [32, 64, 128, 256, 512]
self.deep_supervision = deep_supervision
self.pool = nn.MaxPool2d(2, 2)
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])
if self.deep_supervision:
self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
else:
self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
def forward(self, input):
x0_0 = self.conv0_0(input)
x1_0 = self.conv1_0(self.pool(x0_0))
x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
x2_0 = self.conv2_0(self.pool(x1_0))
x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
x3_0 = self.conv3_0(self.pool(x2_0))
x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
x4_0 = self.conv4_0(self.pool(x3_0))
x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
outputs = []
if self.deep_supervision:
output1 = self.final1(x0_1)
output2 = self.final2(x0_2)
output3 = self.final3(x0_3)
output4 = self.final4(x0_4)
outputs.append(output1)
outputs.append(output2)
outputs.append(output3)
outputs.append(output4)
return outputs
else:
output = self.final(x0_4)
# outputs.append(output)
# return outputs
return output
def unet_plusplus(in_channels, num_classes):
model = NestedUNet(num_classes=num_classes, input_channels=in_channels)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = unet_plusplus(3, 10)
# model.eval()
# input = torch.rand(1,3,128,128)
# output = model(input)
# output = output.data.cpu().numpy()
# print(output)
# print(output.shape)
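In NestedUNet, node conv{i}_{j} (for j >= 1) concatenates the j earlier nodes on level i with the upsampled output of level i + 1. Its input width is therefore nb_filter[i] * j + nb_filter[i + 1]. The small sketch below reproduces the in_channels values passed to each VGGBlock above. The helper name is hypothetical.

```python
def nested_unet_in_channels(nb_filter=(32, 64, 128, 256, 512)):
    """Input channel count of every decoder node conv{i}_{j} in NestedUNet.

    Node (i, j), j >= 1, sees j same-level maps x{i}_0 .. x{i}_{j-1}
    (nb_filter[i] channels each) plus one upsampled map from level i + 1.
    """
    depth = len(nb_filter)
    channels = {}
    for j in range(1, depth):
        for i in range(depth - j):
            channels[f"conv{i}_{j}"] = nb_filter[i] * j + nb_filter[i + 1]
    return channels
```

Changing `nb_filter` in the model therefore only requires that these sums stay consistent. The dense skip pattern is what makes the input width grow linearly with j.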
================================================
FILE: models/networks_2d/unet_urpc.py
================================================
from __future__ import division, print_function
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.uniform import Uniform
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
class ConvBlock(nn.Module):
"""two convolution layers with batch norm and leaky relu"""
def __init__(self, in_channels, out_channels, dropout_p):
super(ConvBlock, self).__init__()
self.conv_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(),
nn.Dropout(dropout_p),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU()
)
def forward(self, x):
return self.conv_conv(x)
class Encoder(nn.Module):
def __init__(self, params):
super(Encoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
self.dropout = self.params['dropout']
assert (len(self.ft_chns) == 5)
self.in_conv = ConvBlock(
self.in_chns, self.ft_chns[0], self.dropout[0])
self.down1 = DownBlock(
self.ft_chns[0], self.ft_chns[1], self.dropout[1])
self.down2 = DownBlock(
self.ft_chns[1], self.ft_chns[2], self.dropout[2])
self.down3 = DownBlock(
self.ft_chns[2], self.ft_chns[3], self.dropout[3])
self.down4 = DownBlock(
self.ft_chns[3], self.ft_chns[4], self.dropout[4])
def forward(self, x):
x0 = self.in_conv(x)
x1 = self.down1(x0)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
return [x0, x1, x2, x3, x4]
class DownBlock(nn.Module):
"""Downsampling followed by ConvBlock"""
def __init__(self, in_channels, out_channels, dropout_p):
super(DownBlock, self).__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
ConvBlock(in_channels, out_channels, dropout_p)
)
def forward(self, x):
return self.maxpool_conv(x)
class UpBlock(nn.Module):
"""Upsampling followed by ConvBlock"""
def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,
bilinear=True):
super(UpBlock, self).__init__()
self.bilinear = bilinear
if bilinear:
self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
self.up = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(
in_channels1, in_channels2, kernel_size=2, stride=2)
self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)
def forward(self, x1, x2):
if self.bilinear:
x1 = self.conv1x1(x1)
x1 = self.up(x1)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class FeatureNoise(nn.Module):
def __init__(self, uniform_range=0.3):
super(FeatureNoise, self).__init__()
self.uni_dist = Uniform(-uniform_range, uniform_range)
def feature_based_noise(self, x):
noise_vector = self.uni_dist.sample(
x.shape[1:]).to(x.device).unsqueeze(0)
x_noise = x.mul(noise_vector) + x
return x_noise
def forward(self, x):
x = self.feature_based_noise(x)
return x
def Dropout(x, p=0.3):
x = torch.nn.functional.dropout(x, p)
return x
def FeatureDropout(x):
attention = torch.mean(x, dim=1, keepdim=True)
max_val, _ = torch.max(attention.view(
x.size(0), -1), dim=1, keepdim=True)
threshold = max_val * np.random.uniform(0.7, 0.9)
threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention)
drop_mask = (attention < threshold).float()
x = x.mul(drop_mask)
return x
class Decoder_URPC(nn.Module):
def __init__(self, params):
super(Decoder_URPC, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
assert (len(self.ft_chns) == 5)
self.up1 = UpBlock(
self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
self.up2 = UpBlock(
self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
self.up3 = UpBlock(
self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
self.up4 = UpBlock(
self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)
self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
kernel_size=3, padding=1)
# self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class,
# kernel_size=3, padding=1)
self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class,
kernel_size=3, padding=1)
self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class,
kernel_size=3, padding=1)
self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class,
kernel_size=3, padding=1)
# self.feature_noise = FeatureNoise()
def forward(self, feature, shape):
x0 = feature[0]
x1 = feature[1]
x2 = feature[2]
x3 = feature[3]
x4 = feature[4]
x = self.up1(x4, x3)
if self.training:
# dp3_out_seg = self.out_conv_dp3(Dropout(x, p=0.5))
dp3_out_seg = self.out_conv_dp3(x)
else:
dp3_out_seg = self.out_conv_dp3(x)
dp3_out_seg = torch.nn.functional.interpolate(dp3_out_seg, shape)
x = self.up2(x, x2)
if self.training:
# dp2_out_seg = self.out_conv_dp2(FeatureDropout(x))
dp2_out_seg = self.out_conv_dp2(x)
else:
dp2_out_seg = self.out_conv_dp2(x)
dp2_out_seg = torch.nn.functional.interpolate(dp2_out_seg, shape)
x = self.up3(x, x1)
if self.training:
# dp1_out_seg = self.out_conv_dp1(self.feature_noise(x))
dp1_out_seg = self.out_conv_dp1(x)
else:
dp1_out_seg = self.out_conv_dp1(x)
dp1_out_seg = torch.nn.functional.interpolate(dp1_out_seg, shape)
x = self.up4(x, x0)
dp0_out_seg = self.out_conv(x)
return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg
class UNet_URPC(nn.Module):
def __init__(self, in_chns, class_num):
super(UNet_URPC, self).__init__()
params = {'in_chns': in_chns,
'feature_chns': [16, 32, 64, 128, 256],
'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
'class_num': class_num,
'bilinear': False,
'acti_func': 'relu'}
self.encoder = Encoder(params)
self.decoder = Decoder_URPC(params)
def forward(self, x):
shape = x.shape[2:]
feature = self.encoder(x)
dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg = self.decoder(
feature, shape)
return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg
def unet_urpc(in_channels, num_classes):
model = UNet_URPC(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = unet_urpc(1,10)
# model.eval()
# input = torch.rand(2,1,128,128)
# output, output1, output2, output3 = model(input)
# output = output1.data.cpu().numpy()
# # print(output)
# print(output.shape)
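Decoder_URPC emits side predictions after up1, up2 and up3, which are at 1/8, 1/4 and 1/2 of the input resolution for this 5-level encoder. It resizes each of them to the input shape with torch.nn.functional.interpolate in the default 'nearest' mode. Suppose one wants a single consensus map from such a pyramid. A minimal sketch under that assumption upsamples the outputs and averages their softmax probabilities. This is an illustration only, not the loss used by the repository's training script, and the helper name is hypothetical.

```python
import torch
import torch.nn.functional as F

def pyramid_mean_probability(side_logits, shape):
    """Upsample each side output to `shape` and average their softmax maps.

    Uses F.interpolate's default 'nearest' mode, as Decoder_URPC does, so a
    coarse prediction is simply replicated into blocks of pixels.
    """
    probs = [F.softmax(F.interpolate(logits, size=shape), dim=1) for logits in side_logits]
    return torch.stack(probs, dim=0).mean(dim=0)
```

Because nearest upsampling only copies values, the coarse side outputs stay blocky. Switching the interpolation to `mode='bilinear'` would smooth them, at the cost of departing from what the decoder currently does.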
================================================
FILE: models/networks_2d/wavesnet.py
================================================
import numpy as np
import math, pywt
import torch
import torch.nn as nn
from torch.nn import Module
from torch.autograd import Function
from collections import OrderedDict
from itertools import islice
import operator
class My_DownSampling_SC(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size = (1,1), stride = 2, padding = (0,0)):
super(My_DownSampling_SC, self).__init__()
self.conv = nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = stride, padding = padding)
def forward(self, input):
return self.conv(input), input
class My_DownSampling_MP(nn.Module):
def __init__(self, stride = 2, kernel_size = 2):
super(My_DownSampling_MP, self).__init__()
self.maxp = nn.MaxPool2d(kernel_size = kernel_size, stride = stride, return_indices = False)
def forward(self, input):
return self.maxp(input), input
class My_UpSampling_SC(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size = (1,1), stride = 2, padding = (0,0)):
super(My_UpSampling_SC, self).__init__()
self.conv = nn.ConvTranspose2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = stride, padding = padding)
def forward(self, input, feature_map):
return torch.cat((self.conv(input), feature_map), dim = 1)
class My_DownSampling_DWT(nn.Module):
def __init__(self, wavename = 'haar'):
super(My_DownSampling_DWT, self).__init__()
self.dwt = DWT_2D(wavename = wavename)
def forward(self, input):
LL, LH, HL, HH = self.dwt(input)
return LL, LH, HL, HH, input
class My_UpSampling_IDWT(nn.Module):
def __init__(self, wavename = 'haar'):
super(My_UpSampling_IDWT, self).__init__()
self.idwt = IDWT_2D(wavename = wavename)
def forward(self, LL, LH, HL, HH, feature_map):
return torch.cat((self.idwt(LL, LH, HL, HH), feature_map), dim = 1)
class My_Sequential(Module):
r"""A sequential container.
Modules will be added to it in the order they are passed in the constructor.
Alternatively, an ordered dict of modules can also be passed in.
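If a module returns several values (a tuple), only the first one is passed on to the next module; the rest are collected in self.output.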
"""
def __init__(self, *args):
super(My_Sequential, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def _get_item_by_idx(self, iterator, idx):
"""Get the idx-th item of the iterator"""
size = len(self)
idx = operator.index(idx)
if not -size <= idx < size:
raise IndexError('index {} is out of range'.format(idx))
idx %= size
return next(islice(iterator, idx, None))
def __getitem__(self, idx):
if isinstance(idx, slice):
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
else:
return self._get_item_by_idx(self._modules.values(), idx)
def __setitem__(self, idx, module):
key = self._get_item_by_idx(self._modules.keys(), idx)
return setattr(self, key, module)
def __delitem__(self, idx):
if isinstance(idx, slice):
for key in list(self._modules.keys())[idx]:
delattr(self, key)
else:
key = self._get_item_by_idx(self._modules.keys(), idx)
delattr(self, key)
def __len__(self):
return len(self._modules)
def __dir__(self):
keys = super(My_Sequential, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def forward(self, input):
self.output = []
for module in self._modules.values():
input = module(input)
if isinstance(input, tuple):
assert len(input) == 4 or len(input) == 2 or len(input) == 5
self.output.append(input[1:])
input = input[0]
if self.output != []:
return input, self.output
else:
return input
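My_Sequential lets an encoder stage return extra tensors, such as the pre-pooling feature map or the wavelet detail bands, without breaking the chain. The first element of a tuple output goes to the next module, and the remaining elements are collected. The minimal sketch below applies the same rule on top of nn.Sequential. TupleAwareSequential and PoolAndKeep are hypothetical names. Unlike the original, the collected values are returned rather than stored on self.output.

```python
import torch
import torch.nn as nn

class TupleAwareSequential(nn.Sequential):
    # Same forward rule as My_Sequential: a tuple output feeds only its first
    # element onward; the remaining elements are kept for a later decoder.
    def forward(self, x):
        stored = []
        for module in self:
            x = module(x)
            if isinstance(x, tuple):
                stored.append(x[1:])
                x = x[0]
        return (x, stored) if stored else x

class PoolAndKeep(nn.Module):
    # Like My_DownSampling_MP: returns the pooled map plus the un-pooled input.
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        return self.pool(x), x
```

The stored pre-pooling maps are exactly what a decoder such as My_Sequential_re later consumes as skip connections.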
class My_Sequential_re(Module):
r"""A sequential container.
Modules will be added to it in the order they are passed in the constructor.
Alternatively, an ordered dict of modules can also be passed in.
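The first positional input is passed through the modules in order; upsampling modules (My_UpSampling_IDWT, IDWT_2D, nn.MaxUnpool2d, My_UpSampling_SC) additionally consume the remaining positional inputs, in order.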
"""
def __init__(self, *args):
super(My_Sequential_re, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)
self.output = []
def _get_item_by_idx(self, iterator, idx):
"""Get the idx-th item of the iterator"""
size = len(self)
idx = operator.index(idx)
if not -size <= idx < size:
raise IndexError('index {} is out of range'.format(idx))
idx %= size
return next(islice(iterator, idx, None))
def __getitem__(self, idx):
if isinstance(idx, slice):
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
else:
return self._get_item_by_idx(self._modules.values(), idx)
def __setitem__(self, idx, module):
key = self._get_item_by_idx(self._modules.keys(), idx)
return setattr(self, key, module)
def __delitem__(self, idx):
if isinstance(idx, slice):
for key in list(self._modules.keys())[idx]:
delattr(self, key)
else:
key = self._get_item_by_idx(self._modules.keys(), idx)
delattr(self, key)
def __len__(self):
return len(self._modules)
def __dir__(self):
keys = super(My_Sequential_re, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def forward(self, *input):
LL = input[0]
index = 1
for module in self._modules.values():
if isinstance(module, My_UpSampling_IDWT):
LH = input[index]
HL = input[index + 1]
HH = input[index + 2]
feature_map = input[index + 3]
LL = module(LL, LH, HL, HH, feature_map = feature_map)
index += 4
elif isinstance(module, IDWT_2D) or 'idwt' in dir(module):
LH = input[index]
HL = input[index + 1]
HH = input[index + 2]
LL = module(LL, LH, HL, HH)
index += 3
elif isinstance(module, nn.MaxUnpool2d):
indices = input[index]
LL = module(input = LL, indices = indices)
#_, _, h, w = LL.size()
#LL = F.interpolate(LL, size = (2*h, 2*w), mode = 'bilinear', align_corners = True)
index += 1
elif isinstance(module, My_UpSampling_SC):
feature_map = input[index]
LL = module(input = LL, feature_map = feature_map)
index += 1
else:
LL = module(LL)
return LL
class DWTFunction_1D(Function):
@staticmethod
def forward(ctx, input, matrix_Low, matrix_High):
ctx.save_for_backward(matrix_Low, matrix_High)
L = torch.matmul(input, matrix_Low.t())
H = torch.matmul(input, matrix_High.t())
return L, H
@staticmethod
def backward(ctx, grad_L, grad_H):
matrix_L, matrix_H = ctx.saved_tensors
grad_input = torch.add(torch.matmul(grad_L, matrix_L), torch.matmul(grad_H, matrix_H))
return grad_input, None, None
class IDWTFunction_1D(Function):
@staticmethod
def forward(ctx, input_L, input_H, matrix_L, matrix_H):
ctx.save_for_backward(matrix_L, matrix_H)
output = torch.add(torch.matmul(input_L, matrix_L), torch.matmul(input_H, matrix_H))
return output
@staticmethod
def backward(ctx, grad_output):
matrix_L, matrix_H = ctx.saved_tensors
grad_L = torch.matmul(grad_output, matrix_L.t())
grad_H = torch.matmul(grad_output, matrix_H.t())
return grad_L, grad_H, None, None
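DWTFunction_1D hand-writes its backward pass. Since L = x M_L^T and H = x M_H^T, the input gradient is grad_L M_L + grad_H M_H. backward must return one entry per forward argument, with None for the fixed matrices. The sketch below assumes double-precision Haar matrices built by hand. It copies that math into a hypothetical MatrixDWT1D and verifies it numerically with torch.autograd.gradcheck.

```python
import torch
from torch.autograd import Function, gradcheck

class MatrixDWT1D(Function):
    # Same math as DWTFunction_1D: L = x M_L^T, H = x M_H^T along the last axis.
    @staticmethod
    def forward(ctx, x, m_low, m_high):
        ctx.save_for_backward(m_low, m_high)
        return x @ m_low.t(), x @ m_high.t()

    @staticmethod
    def backward(ctx, grad_low, grad_high):
        m_low, m_high = ctx.saved_tensors
        # The gradient of x M^T with respect to x, contracted with the incoming
        # gradient, is grad @ M. One value per forward input; None for the matrices.
        return grad_low @ m_low + grad_high @ m_high, None, None
```

gradcheck compares this analytic gradient with finite differences. It is a quick way to validate any of the hand-written DWT/IDWT backward passes in this file.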
class DWTFunction_2D(Function):
@staticmethod
def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1):
ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1)
L = torch.matmul(matrix_Low_0, input)
H = torch.matmul(matrix_High_0, input)
LL = torch.matmul(L, matrix_Low_1)
LH = torch.matmul(L, matrix_High_1)
HL = torch.matmul(H, matrix_Low_1)
HH = torch.matmul(H, matrix_High_1)
return LL, LH, HL, HH
@staticmethod
def backward(ctx, grad_LL, grad_LH, grad_HL, grad_HH):
matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1 = ctx.saved_tensors
grad_L = torch.add(torch.matmul(grad_LL, matrix_Low_1.t()), torch.matmul(grad_LH, matrix_High_1.t()))
grad_H = torch.add(torch.matmul(grad_HL, matrix_Low_1.t()), torch.matmul(grad_HH, matrix_High_1.t()))
grad_input = torch.add(torch.matmul(matrix_Low_0.t(), grad_L), torch.matmul(matrix_High_0.t(), grad_H))
return grad_input, None, None, None, None
class DWTFunction_2D_tiny(Function):
@staticmethod
def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1):
ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1)
L = torch.matmul(matrix_Low_0, input)
LL = torch.matmul(L, matrix_Low_1)
return LL
@staticmethod
def backward(ctx, grad_LL):
matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1 = ctx.saved_tensors
grad_L = torch.matmul(grad_LL, matrix_Low_1.t())
grad_input = torch.matmul(matrix_Low_0.t(), grad_L)
return grad_input, None, None, None, None
class IDWTFunction_2D(Function):
@staticmethod
def forward(ctx, input_LL, input_LH, input_HL, input_HH,
matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1):
ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1)
L = torch.add(torch.matmul(input_LL, matrix_Low_1.t()), torch.matmul(input_LH, matrix_High_1.t()))
H = torch.add(torch.matmul(input_HL, matrix_Low_1.t()), torch.matmul(input_HH, matrix_High_1.t()))
output = torch.add(torch.matmul(matrix_Low_0.t(), L), torch.matmul(matrix_High_0.t(), H))
return output
@staticmethod
def backward(ctx, grad_output):
matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1 = ctx.saved_tensors
grad_L = torch.matmul(matrix_Low_0, grad_output)
grad_H = torch.matmul(matrix_High_0, grad_output)
grad_LL = torch.matmul(grad_L, matrix_Low_1)
grad_LH = torch.matmul(grad_L, matrix_High_1)
grad_HL = torch.matmul(grad_H, matrix_Low_1)
grad_HH = torch.matmul(grad_H, matrix_High_1)
return grad_LL, grad_LH, grad_HL, grad_HH, None, None, None, None
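DWTFunction_2D applies the separable transform as matrix products. The height axis is filtered by the left matrices (matrix_Low_0, matrix_High_0) and the width axis by the right ones (matrix_Low_1, matrix_High_1). IDWTFunction_2D applies the transposes in reverse order. The NumPy sketch below uses hand-built Haar matrices, an assumption since the module builds them from pywt filters. It shows the four sub-bands and the perfect reconstruction, which holds because the Haar analysis matrices are orthogonal.

```python
import numpy as np

def haar_pair(n):
    # (n//2, n) Haar analysis matrices; row i covers samples 2i and 2i+1 (n even).
    s = 1.0 / np.sqrt(2.0)
    eye = np.eye(n // 2)
    return np.kron(eye, [[s, s]]), np.kron(eye, [[s, -s]])

def dwt_2d(x, low0, high0, low1, high1):
    # Products from DWTFunction_2D.forward; low1/high1 act on the width axis and
    # are already transposed, as DWT_2D does with matrix_h_1.
    L, H = low0 @ x, high0 @ x
    return L @ low1, L @ high1, H @ low1, H @ high1

def idwt_2d(LL, LH, HL, HH, low0, high0, low1, high1):
    # Products from IDWTFunction_2D.forward.
    L = LL @ low1.T + LH @ high1.T
    H = HL @ low1.T + HH @ high1.T
    return low0.T @ L + high0.T @ H
```

For Haar, each LL coefficient is half the sum of a 2x2 block. This is why the LL band behaves like a scaled average-pooled image.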
class DWTFunction_3D(Function):
@staticmethod
def forward(ctx, input,
matrix_Low_0, matrix_Low_1, matrix_Low_2,
matrix_High_0, matrix_High_1, matrix_High_2):
ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_Low_2,
matrix_High_0, matrix_High_1, matrix_High_2)
L = torch.matmul(matrix_Low_0, input)
H = torch.matmul(matrix_High_0, input)
LL = torch.matmul(L, matrix_Low_1).transpose(dim0 = 2, dim1 = 3)
LH = torch.matmul(L, matrix_High_1).transpose(dim0 = 2, dim1 = 3)
HL = torch.matmul(H, matrix_Low_1).transpose(dim0 = 2, dim1 = 3)
HH = torch.matmul(H, matrix_High_1).transpose(dim0 = 2, dim1 = 3)
LLL = torch.matmul(matrix_Low_2, LL).transpose(dim0 = 2, dim1 = 3)
LLH = torch.matmul(matrix_Low_2, LH).transpose(dim0 = 2, dim1 = 3)
LHL = torch.matmul(matrix_Low_2, HL).transpose(dim0 = 2, dim1 = 3)
LHH = torch.matmul(matrix_Low_2, HH).transpose(dim0 = 2, dim1 = 3)
HLL = torch.matmul(matrix_High_2, LL).transpose(dim0 = 2, dim1 = 3)
HLH = torch.matmul(matrix_High_2, LH).transpose(dim0 = 2, dim1 = 3)
HHL = torch.matmul(matrix_High_2, HL).transpose(dim0 = 2, dim1 = 3)
HHH = torch.matmul(matrix_High_2, HH).transpose(dim0 = 2, dim1 = 3)
return LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH
@staticmethod
def backward(ctx, grad_LLL, grad_LLH, grad_LHL, grad_LHH,
grad_HLL, grad_HLH, grad_HHL, grad_HHH):
matrix_Low_0, matrix_Low_1, matrix_Low_2, matrix_High_0, matrix_High_1, matrix_High_2 = ctx.saved_tensors
grad_LL = torch.add(torch.matmul(matrix_Low_2.t(), grad_LLL.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), grad_HLL.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)
grad_LH = torch.add(torch.matmul(matrix_Low_2.t(), grad_LLH.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), grad_HLH.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)
grad_HL = torch.add(torch.matmul(matrix_Low_2.t(), grad_LHL.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), grad_HHL.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)
grad_HH = torch.add(torch.matmul(matrix_Low_2.t(), grad_LHH.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), grad_HHH.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)
grad_L = torch.add(torch.matmul(grad_LL, matrix_Low_1.t()), torch.matmul(grad_LH, matrix_High_1.t()))
grad_H = torch.add(torch.matmul(grad_HL, matrix_Low_1.t()), torch.matmul(grad_HH, matrix_High_1.t()))
grad_input = torch.add(torch.matmul(matrix_Low_0.t(), grad_L), torch.matmul(matrix_High_0.t(), grad_H))
# one gradient per forward input: input plus six fixed matrices
return grad_input, None, None, None, None, None, None
class IDWTFunction_3D(Function):
@staticmethod
def forward(ctx, input_LLL, input_LLH, input_LHL, input_LHH,
input_HLL, input_HLH, input_HHL, input_HHH,
matrix_Low_0, matrix_Low_1, matrix_Low_2,
matrix_High_0, matrix_High_1, matrix_High_2):
ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_Low_2,
matrix_High_0, matrix_High_1, matrix_High_2)
input_LL = torch.add(torch.matmul(matrix_Low_2.t(), input_LLL.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), input_HLL.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)
input_LH = torch.add(torch.matmul(matrix_Low_2.t(), input_LLH.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), input_HLH.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)
input_HL = torch.add(torch.matmul(matrix_Low_2.t(), input_LHL.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), input_HHL.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)
input_HH = torch.add(torch.matmul(matrix_Low_2.t(), input_LHH.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), input_HHH.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)
input_L = torch.add(torch.matmul(input_LL, matrix_Low_1.t()), torch.matmul(input_LH, matrix_High_1.t()))
input_H = torch.add(torch.matmul(input_HL, matrix_Low_1.t()), torch.matmul(input_HH, matrix_High_1.t()))
output = torch.add(torch.matmul(matrix_Low_0.t(), input_L), torch.matmul(matrix_High_0.t(), input_H))
return output
@staticmethod
def backward(ctx, grad_output):
matrix_Low_0, matrix_Low_1, matrix_Low_2, matrix_High_0, matrix_High_1, matrix_High_2 = ctx.saved_tensors
grad_L = torch.matmul(matrix_Low_0, grad_output)
grad_H = torch.matmul(matrix_High_0, grad_output)
grad_LL = torch.matmul(grad_L, matrix_Low_1).transpose(dim0 = 2, dim1 = 3)
grad_LH = torch.matmul(grad_L, matrix_High_1).transpose(dim0 = 2, dim1 = 3)
grad_HL = torch.matmul(grad_H, matrix_Low_1).transpose(dim0 = 2, dim1 = 3)
grad_HH = torch.matmul(grad_H, matrix_High_1).transpose(dim0 = 2, dim1 = 3)
grad_LLL = torch.matmul(matrix_Low_2, grad_LL).transpose(dim0 = 2, dim1 = 3)
grad_LLH = torch.matmul(matrix_Low_2, grad_LH).transpose(dim0 = 2, dim1 = 3)
grad_LHL = torch.matmul(matrix_Low_2, grad_HL).transpose(dim0 = 2, dim1 = 3)
grad_LHH = torch.matmul(matrix_Low_2, grad_HH).transpose(dim0 = 2, dim1 = 3)
grad_HLL = torch.matmul(matrix_High_2, grad_LL).transpose(dim0 = 2, dim1 = 3)
grad_HLH = torch.matmul(matrix_High_2, grad_LH).transpose(dim0 = 2, dim1 = 3)
grad_HHL = torch.matmul(matrix_High_2, grad_HL).transpose(dim0 = 2, dim1 = 3)
grad_HHH = torch.matmul(matrix_High_2, grad_HH).transpose(dim0 = 2, dim1 = 3)
return grad_LLL, grad_LLH, grad_LHL, grad_LHH, grad_HLL, grad_HLH, grad_HHL, grad_HHH, None, None, None, None, None, None
class DWT_1D(Module):
"""
input: (N, C, L)
output: L -- (N, C, L/2)
H -- (N, C, L/2)
"""
def __init__(self, wavename):
"""
:param wavename: pywt wavelet name; its low-pass (rec_lo) and high-pass (rec_hi) filters
form the filter banks used for wavelet decomposition
"""
super(DWT_1D, self).__init__()
wavelet = pywt.Wavelet(wavename)
self.band_low = wavelet.rec_lo
self.band_high = wavelet.rec_hi
assert len(self.band_low) == len(self.band_high)
self.band_length = len(self.band_low)
assert self.band_length % 2 == 0
self.band_length_half = math.floor(self.band_length / 2)
def get_matrix(self):
"""
Build the transform matrices.
:return:
"""
L1 = self.input_height
L = math.floor(L1 / 2)
matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) )
matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )
end = None if self.band_length_half == 1 else (-self.band_length_half+1)
index = 0
for i in range(L):
for j in range(self.band_length):
matrix_h[i, index+j] = self.band_low[j]
index += 2
index = 0
for i in range(L1 - L):
for j in range(self.band_length):
matrix_g[i, index+j] = self.band_high[j]
index += 2
matrix_h = matrix_h[:,(self.band_length_half-1):end]
matrix_g = matrix_g[:,(self.band_length_half-1):end]
if torch.cuda.is_available():
self.matrix_low = torch.tensor(matrix_h).cuda()
self.matrix_high = torch.tensor(matrix_g).cuda()
else:
self.matrix_low = torch.tensor(matrix_h)
self.matrix_high = torch.tensor(matrix_g)
def forward(self, input):
assert len(input.size()) == 3
self.input_height = input.size()[-1]
#assert self.input_height > self.band_length
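self.get_matrix()
# torch.matmul does not promote dtypes, so match the float64 matrices to the input's dtype and device
return DWTFunction_1D.apply(input, self.matrix_low.to(input), self.matrix_high.to(input))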
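DWT_1D.get_matrix turns a filter into a banded matrix. Row i holds the taps starting at column 2*i, which gives stride-2 filtering. Then band_length // 2 - 1 columns are cropped from each side, so a length-n signal maps to n // 2 coefficients. The hypothetical helper below replicates that loop for one filter, assuming an even n, so its effect can be checked with the Haar filters.

```python
import math
import numpy as np

def banded_analysis_matrix(taps, n):
    """Replicates the matrix_h loop of DWT_1D.get_matrix for one filter.

    Row i places the taps starting at column 2*i of an (n//2, n + len(taps) - 2)
    matrix; len(taps)//2 - 1 columns are then cropped from each side.
    Assumes n is even.
    """
    band_length = len(taps)
    assert band_length % 2 == 0
    half = band_length // 2
    rows = math.floor(n / 2)
    m = np.zeros((rows, n + band_length - 2))
    for i in range(rows):
        m[i, 2 * i:2 * i + band_length] = taps
    end = None if half == 1 else -half + 1
    return m[:, half - 1:end]
```

For Haar, the low-pass row product gives scaled pairwise sums. Stacking the low and high matrices yields an orthogonal matrix, which is why IDWT_1D can invert the transform with a transpose-style product.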
class IDWT_1D(Module):
"""
input: L -- (N, C, L/2)
H -- (N, C, L/2)
output: (N, C, L)
"""
def __init__(self, wavename):
"""
:param wavename: pywt wavelet name; its low-pass (dec_lo) and high-pass (dec_hi) filters,
reversed, form the filter banks used for wavelet reconstruction
"""
super(IDWT_1D, self).__init__()
wavelet = pywt.Wavelet(wavename)
self.band_low = wavelet.dec_lo
self.band_high = wavelet.dec_hi
self.band_low.reverse()
self.band_high.reverse()
assert len(self.band_low) == len(self.band_high)
self.band_length = len(self.band_low)
assert self.band_length % 2 == 0
self.band_length_half = math.floor(self.band_length / 2)
def get_matrix(self):
"""
Build the transform matrices.
:return:
"""
L1 = self.input_height
L = math.floor(L1 / 2)
matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) )
matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )
end = None if self.band_length_half == 1 else (-self.band_length_half+1)
index = 0
for i in range(L):
for j in range(self.band_length):
matrix_h[i, index+j] = self.band_low[j]
index += 2
index = 0
for i in range(L1 - L):
for j in range(self.band_length):
matrix_g[i, index+j] = self.band_high[j]
index += 2
matrix_h = matrix_h[:,(self.band_length_half-1):end]
matrix_g = matrix_g[:,(self.band_length_half-1):end]
if torch.cuda.is_available():
self.matrix_low = torch.tensor(matrix_h).cuda()
self.matrix_high = torch.tensor(matrix_g).cuda()
else:
self.matrix_low = torch.tensor(matrix_h)
self.matrix_high = torch.tensor(matrix_g)
def forward(self, L, H):
assert len(L.size()) == len(H.size()) == 3
self.input_height = L.size()[-1] + H.size()[-1]
#assert self.input_height > self.band_length
self.get_matrix()
return IDWTFunction_1D.apply(L, H, self.matrix_low, self.matrix_high)
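IDWT_1D builds the same banded matrices from the reversed decomposition filters. For an orthogonal wavelet such as Haar, reversing pywt's `dec_lo` / `dec_hi` gives back `rec_lo` / `rec_hi`, so the matrices are identical and the inverse is simply their transpose. The sketch below assumes `IDWTFunction_1D` computes `L @ matrix_low + H @ matrix_high` (the function is not shown here); `haar_matrices`, `dwt_1d` and `idwt_1d` are hypothetical helpers.

```python
import numpy as np

def haar_matrices(length):
    """Hypothetical helper: Haar analysis matrices of shape (length//2, length)."""
    s = 1 / np.sqrt(2)
    low = np.zeros((length // 2, length))
    high = np.zeros((length // 2, length))
    for i in range(length // 2):
        low[i, 2 * i:2 * i + 2] = [s, s]
        high[i, 2 * i:2 * i + 2] = [s, -s]
    return low, high

def dwt_1d(x, low, high):
    return x @ low.T, x @ high.T

def idwt_1d(l, h, low, high):
    # the stacked matrix [low; high] is orthogonal, so its transpose inverts it
    return l @ low + h @ high

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8))       # (N, C, L), as in the IDWT_1D docstring
low, high = haar_matrices(8)
l, h = dwt_1d(x, low, high)              # each (N, C, L/2)
x_rec = idwt_1d(l, h, low, high)
print(np.abs(x - x_rec).max())
```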
class DWT_2D(Module):
"""
input: (N, C, H, W)
output -- LL: (N, C, H/2, W/2)
LH: (N, C, H/2, W/2)
HL: (N, C, H/2, W/2)
HH: (N, C, H/2, W/2)
"""
def __init__(self, wavename):
"""
        :param wavename: pywt wavelet name; it supplies self.band_low / self.band_high,
                         the low-pass / high-pass filters used for the wavelet decomposition
"""
super(DWT_2D, self).__init__()
wavelet = pywt.Wavelet(wavename)
self.band_low = wavelet.rec_lo
self.band_high = wavelet.rec_hi
assert len(self.band_low) == len(self.band_high)
self.band_length = len(self.band_low)
assert self.band_length % 2 == 0
self.band_length_half = math.floor(self.band_length / 2)
def get_matrix(self):
"""
        Generate the transform matrices.
:return:
"""
L1 = np.max((self.input_height, self.input_width))
L = math.floor(L1 / 2)
matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) )
matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )
end = None if self.band_length_half == 1 else (-self.band_length_half+1)
index = 0
for i in range(L):
for j in range(self.band_length):
matrix_h[i, index+j] = self.band_low[j]
index += 2
matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)]
matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)]
index = 0
for i in range(L1 - L):
for j in range(self.band_length):
matrix_g[i, index+j] = self.band_high[j]
index += 2
matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)]
matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)]
matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end]
matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end]
matrix_h_1 = np.transpose(matrix_h_1)
matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end]
matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end]
matrix_g_1 = np.transpose(matrix_g_1)
if torch.cuda.is_available():
self.matrix_low_0 = torch.Tensor(matrix_h_0).cuda()
self.matrix_low_1 = torch.Tensor(matrix_h_1).cuda()
self.matrix_high_0 = torch.Tensor(matrix_g_0).cuda()
self.matrix_high_1 = torch.Tensor(matrix_g_1).cuda()
else:
self.matrix_low_0 = torch.Tensor(matrix_h_0)
self.matrix_low_1 = torch.Tensor(matrix_h_1)
self.matrix_high_0 = torch.Tensor(matrix_g_0)
self.matrix_high_1 = torch.Tensor(matrix_g_1)
def forward(self, input):
assert isinstance(input, torch.Tensor)
assert len(input.size()) == 4
self.input_height = input.size()[-2]
self.input_width = input.size()[-1]
#assert self.input_height > self.band_length and self.input_width > self.band_length
self.get_matrix()
return DWTFunction_2D.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_high_0, self.matrix_high_1)
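DWT_2D applies the 1-D matrices along both spatial axes: `matrix_low_0` / `matrix_high_0` act on the height axis from the left, and the transposed `matrix_low_1` / `matrix_high_1` act on the width axis from the right. This sketch assumes `DWTFunction_2D` combines them as LL = H0 X H1, LH = H0 X G1, HL = G0 X H1, HH = G0 X G1, so the first letter refers to the height axis; that ordering is an assumption because the function is not shown here, and `haar_matrices` / `dwt_2d` are hypothetical helpers.

```python
import numpy as np

def haar_matrices(length):
    """Hypothetical helper: Haar analysis matrices of shape (length//2, length)."""
    s = 1 / np.sqrt(2)
    low = np.zeros((length // 2, length))
    high = np.zeros((length // 2, length))
    for i in range(length // 2):
        low[i, 2 * i:2 * i + 2] = [s, s]
        high[i, 2 * i:2 * i + 2] = [s, -s]
    return low, high

def dwt_2d(x):
    """Works on (..., H, W) arrays; matmul broadcasts over leading axes."""
    h, w = x.shape[-2:]
    low_0, high_0 = haar_matrices(h)       # height axis, applied from the left
    low_1, high_1 = haar_matrices(w)
    low_1, high_1 = low_1.T, high_1.T      # width axis, transposed as in get_matrix
    L, H = low_0 @ x, high_0 @ x
    return L @ low_1, L @ high_1, H @ low_1, H @ high_1   # LL, LH, HL, HH

x = np.arange(16, dtype=float).reshape(4, 4)
LL, LH, HL, HH = dwt_2d(x)
print(LL)
```

On this ramp image LL holds half the sum of each 2x2 block, while the detail bands are constant because the image changes linearly.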
class DWT_2D_tiny(Module):
"""
input: (N, C, H, W)
output -- LL: (N, C, H/2, W/2)
"""
def __init__(self, wavename):
"""
        :param wavename: pywt wavelet name; it supplies self.band_low / self.band_high,
                         the low-pass / high-pass filters used for the wavelet decomposition
"""
super(DWT_2D_tiny, self).__init__()
wavelet = pywt.Wavelet(wavename)
self.band_low = wavelet.rec_lo
self.band_high = wavelet.rec_hi
assert len(self.band_low) == len(self.band_high)
self.band_length = len(self.band_low)
assert self.band_length % 2 == 0
self.band_length_half = math.floor(self.band_length / 2)
def get_matrix(self):
"""
        Generate the transform matrices.
:return:
"""
L1 = np.max((self.input_height, self.input_width))
L = math.floor(L1 / 2)
matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) )
matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )
end = None if self.band_length_half == 1 else (-self.band_length_half+1)
index = 0
for i in range(L):
for j in range(self.band_length):
matrix_h[i, index+j] = self.band_low[j]
index += 2
matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)]
matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)]
index = 0
for i in range(L1 - L):
for j in range(self.band_length):
matrix_g[i, index+j] = self.band_high[j]
index += 2
matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)]
matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)]
matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end]
matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end]
matrix_h_1 = np.transpose(matrix_h_1)
matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end]
matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end]
matrix_g_1 = np.transpose(matrix_g_1)
if torch.cuda.is_available():
self.matrix_low_0 = torch.Tensor(matrix_h_0).cuda()
self.matrix_low_1 = torch.Tensor(matrix_h_1).cuda()
self.matrix_high_0 = torch.Tensor(matrix_g_0).cuda()
self.matrix_high_1 = torch.Tensor(matrix_g_1).cuda()
else:
self.matrix_low_0 = torch.Tensor(matrix_h_0)
self.matrix_low_1 = torch.Tensor(matrix_h_1)
self.matrix_high_0 = torch.Tensor(matrix_g_0)
self.matrix_high_1 = torch.Tensor(matrix_g_1)
def forward(self, input):
assert isinstance(input, torch.Tensor)
assert len(input.size()) == 4
self.input_height = input.size()[-2]
self.input_width = input.size()[-1]
self.get_matrix()
return DWTFunction_2D_tiny.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_high_0, self.matrix_high_1)
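DWT_2D_tiny builds all four matrices but, according to its docstring, returns only LL. For the Haar wavelet that band is a scaled 2x2 average pooling: each LL coefficient is (a + b + c + d) / 2 over a 2x2 block, i.e. twice the block mean. A NumPy sketch of that equivalence; `haar_ll` and `avg_pool_2x2` are hypothetical names.

```python
import numpy as np

def haar_ll(x):
    """Haar LL band of an (..., H, W) array with even H and W."""
    a = x[..., 0::2, 0::2]
    b = x[..., 0::2, 1::2]
    c = x[..., 1::2, 0::2]
    d = x[..., 1::2, 1::2]
    return (a + b + c + d) / 2           # (1/sqrt(2))**2 times the block sum

def avg_pool_2x2(x):
    h, w = x.shape[-2:]
    blocks = x.reshape(*x.shape[:-2], h // 2, 2, w // 2, 2)
    return blocks.mean(axis=(-3, -1))    # average inside each 2x2 window

x = np.random.default_rng(1).standard_normal((2, 3, 8, 6))
print(np.allclose(haar_ll(x), 2 * avg_pool_2x2(x)))
```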
class IDWT_2D(Module):
"""
input -- LL: (N, C, H/2, W/2)
LH: (N, C, H/2, W/2)
HL: (N, C, H/2, W/2)
HH: (N, C, H/2, W/2)
output: (N, C, H, W)
"""
def __init__(self, wavename):
"""
        :param wavename: pywt wavelet name; it supplies self.band_low / self.band_high,
                         the low-pass / high-pass filters required for the wavelet reconstruction
"""
super(IDWT_2D, self).__init__()
wavelet = pywt.Wavelet(wavename)
self.band_low = wavelet.dec_lo
self.band_low.reverse()
self.band_high = wavelet.dec_hi
self.band_high.reverse()
assert len(self.band_low) == len(self.band_high)
self.band_length = len(self.band_low)
assert self.band_length % 2 == 0
self.band_length_half = math.floor(self.band_length / 2)
def get_matrix(self):
"""
        Generate the transform matrices.
:return:
"""
L1 = np.max((self.input_height, self.input_width))
L = math.floor(L1 / 2)
matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) )
matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )
end = None if self.band_length_half == 1 else (-self.band_length_half+1)
index = 0
for i in range(L):
for j in range(self.band_length):
matrix_h[i, index+j] = self.band_low[j]
index += 2
matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)]
matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)]
index = 0
for i in range(L1 - L):
for j in range(self.band_length):
matrix_g[i, index+j] = self.band_high[j]
index += 2
matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)]
matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)]
matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end]
matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end]
matrix_h_1 = np.transpose(matrix_h_1)
matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end]
matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end]
matrix_g_1 = np.transpose(matrix_g_1)
if torch.cuda.is_available():
self.matrix_low_0 = torch.Tensor(matrix_h_0).cuda()
self.matrix_low_1 = torch.Tensor(matrix_h_1).cuda()
self.matrix_high_0 = torch.Tensor(matrix_g_0).cuda()
self.matrix_high_1 = torch.Tensor(matrix_g_1).cuda()
else:
self.matrix_low_0 = torch.Tensor(matrix_h_0)
self.matrix_low_1 = torch.Tensor(matrix_h_1)
self.matrix_high_0 = torch.Tensor(matrix_g_0)
self.matrix_high_1 = torch.Tensor(matrix_g_1)
def forward(self, LL, LH, HL, HH):
assert len(LL.size()) == len(LH.size()) == len(HL.size()) == len(HH.size()) == 4
self.input_height = LL.size()[-2] + HH.size()[-2]
self.input_width = LL.size()[-1] + HH.size()[-1]
#assert self.input_height > self.band_length and self.input_width > self.band_length
self.get_matrix()
return IDWTFunction_2D.apply(LL, LH, HL, HH, self.matrix_low_0, self.matrix_low_1, self.matrix_high_0, self.matrix_high_1)
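The inverse multiplies by the transposed matrices in reverse order. Because `get_matrix` slices the height and width matrices out of one template sized by L1 = max(H, W), non-square inputs are supported. A Haar round-trip sketch on a 4x6 input, assuming the LL/LH/HL/HH ordering in which the first letter refers to the height axis (an assumption, since `IDWTFunction_2D` is not shown); the helper names are hypothetical.

```python
import numpy as np

def haar_matrices(length):
    """Hypothetical helper: Haar analysis matrices of shape (length//2, length)."""
    s = 1 / np.sqrt(2)
    low = np.zeros((length // 2, length))
    high = np.zeros((length // 2, length))
    for i in range(length // 2):
        low[i, 2 * i:2 * i + 2] = [s, s]
        high[i, 2 * i:2 * i + 2] = [s, -s]
    return low, high

def dwt_2d(x):
    low_0, high_0 = haar_matrices(x.shape[-2])
    low_1, high_1 = (m.T for m in haar_matrices(x.shape[-1]))
    L, H = low_0 @ x, high_0 @ x
    return L @ low_1, L @ high_1, H @ low_1, H @ high_1

def idwt_2d(LL, LH, HL, HH):
    height = LL.shape[-2] + HH.shape[-2]     # same size bookkeeping as IDWT_2D.forward
    width = LL.shape[-1] + HH.shape[-1]
    low_0, high_0 = haar_matrices(height)
    low_1, high_1 = (m.T for m in haar_matrices(width))
    L = LL @ low_1.T + LH @ high_1.T         # undo the width transform
    H = HL @ low_1.T + HH @ high_1.T
    return low_0.T @ L + high_0.T @ H        # undo the height transform

x = np.random.default_rng(2).standard_normal((4, 6))
bands = dwt_2d(x)
print([b.shape for b in bands])
print(np.allclose(idwt_2d(*bands), x))
```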
class DWT_3D(Module):
"""
input: (N, C, D, H, W)
output: -- LLL (N, C, D/2, H/2, W/2)
-- LLH (N, C, D/2, H/2, W/2)
-- LHL (N, C, D/2, H/2, W/2)
-- LHH (N, C, D/2, H/2, W/2)
-- HLL (N, C, D/2, H/2, W/2)
-- HLH (N, C, D/2, H/2, W/2)
-- HHL (N, C, D/2, H/2, W/2)
-- HHH (N, C, D/2, H/2, W/2)
"""
def __init__(self, wavename):
"""
        :param wavename: pywt wavelet name; it supplies self.band_low / self.band_high,
                         the low-pass / high-pass filters used for the wavelet decomposition
"""
super(DWT_3D, self).__init__()
wavelet = pywt.Wavelet(wavename)
self.band_low = wavelet.rec_lo
self.band_high = wavelet.rec_hi
assert len(self.band_low) == len(self.band_high)
self.band_length = len(self.band_low)
assert self.band_length % 2 == 0
self.band_length_half = math.floor(self.band_length / 2)
def get_matrix(self):
"""
        Generate the transform matrices.
:return:
"""
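        # the template must also cover the depth axis, otherwise matrix_h_2 / matrix_g_2 are truncated
        L1 = np.max((self.input_height, self.input_width, self.input_depth))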
L = math.floor(L1 / 2)
matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) )
matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )
end = None if self.band_length_half == 1 else (-self.band_length_half+1)
index = 0
for i in range(L):
for j in range(self.band_length):
matrix_h[i, index+j] = self.band_low[j]
index += 2
matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)]
matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)]
matrix_h_2 = matrix_h[0:(math.floor(self.input_depth / 2)), 0:(self.input_depth + self.band_length - 2)]
index = 0
for i in range(L1 - L):
for j in range(self.band_length):
matrix_g[i, index+j] = self.band_high[j]
index += 2
matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)]
matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)]
matrix_g_2 = matrix_g[0:(self.input_depth - math.floor(self.input_depth / 2)),0:(self.input_depth + self.band_length - 2)]
matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end]
matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end]
matrix_h_1 = np.transpose(matrix_h_1)
matrix_h_2 = matrix_h_2[:,(self.band_length_half-1):end]
matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end]
matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end]
matrix_g_1 = np.transpose(matrix_g_1)
matrix_g_2 = matrix_g_2[:,(self.band_length_half-1):end]
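        # float32 to match typical network inputs (np.zeros produces float64)
        if torch.cuda.is_available():
            self.matrix_low_0 = torch.tensor(matrix_h_0, dtype=torch.float32).cuda()
            self.matrix_low_1 = torch.tensor(matrix_h_1, dtype=torch.float32).cuda()
            self.matrix_low_2 = torch.tensor(matrix_h_2, dtype=torch.float32).cuda()
            self.matrix_high_0 = torch.tensor(matrix_g_0, dtype=torch.float32).cuda()
            self.matrix_high_1 = torch.tensor(matrix_g_1, dtype=torch.float32).cuda()
            self.matrix_high_2 = torch.tensor(matrix_g_2, dtype=torch.float32).cuda()
        else:
            self.matrix_low_0 = torch.tensor(matrix_h_0, dtype=torch.float32)
            self.matrix_low_1 = torch.tensor(matrix_h_1, dtype=torch.float32)
            self.matrix_low_2 = torch.tensor(matrix_h_2, dtype=torch.float32)
            self.matrix_high_0 = torch.tensor(matrix_g_0, dtype=torch.float32)
            self.matrix_high_1 = torch.tensor(matrix_g_1, dtype=torch.float32)
            self.matrix_high_2 = torch.tensor(matrix_g_2, dtype=torch.float32)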
def forward(self, input):
assert len(input.size()) == 5
self.input_depth = input.size()[-3]
self.input_height = input.size()[-2]
self.input_width = input.size()[-1]
#assert self.input_height > self.band_length and self.input_width > self.band_length and self.input_depth > self.band_length
self.get_matrix()
return DWTFunction_3D.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_low_2,
self.matrix_high_0, self.matrix_high_1, self.matrix_high_2)
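DWT_3D extends the separable scheme to the depth axis and produces eight subbands. The sketch below builds one Haar matrix pair per axis from that axis's own length (rather than slicing a shared template), applies them with `np.einsum`, and returns the subbands keyed by name, assuming the letters run depth, height, width; the ordering inside `DWTFunction_3D` is not shown, so that is an assumption, and `haar_matrices` / `dwt_3d` are hypothetical helpers. For a constant volume all energy lands in LLL, whose value is 2*sqrt(2) times the constant.

```python
import numpy as np

def haar_matrices(length):
    """Hypothetical helper: Haar analysis matrices of shape (length//2, length)."""
    s = 1 / np.sqrt(2)
    low = np.zeros((length // 2, length))
    high = np.zeros((length // 2, length))
    for i in range(length // 2):
        low[i, 2 * i:2 * i + 2] = [s, s]
        high[i, 2 * i:2 * i + 2] = [s, -s]
    return low, high

def dwt_3d(x):
    """Haar DWT of a (D, H, W) volume; returns {'LLL': ..., ..., 'HHH': ...}."""
    d, h, w = x.shape
    per_axis = [dict(zip("LH", haar_matrices(n))) for n in (d, h, w)]
    bands = {}
    for kd, md in per_axis[0].items():
        xd = np.einsum("ad,dhw->ahw", md, x)          # transform depth
        for kh, mh in per_axis[1].items():
            xh = np.einsum("bh,ahw->abw", mh, xd)     # transform height
            for kw, mw in per_axis[2].items():
                bands[kd + kh + kw] = np.einsum("cw,abw->abc", mw, xh)  # transform width
    return bands

bands = dwt_3d(np.full((4, 6, 8), 3.0))
print(sorted(bands), bands["LLL"].shape)
```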
class IDWT_3D(Module):
"""
input: -- LLL (N, C, D/2, H/2, W/2)
-- LLH (N, C, D/2, H/2, W/2)
-- LHL (N, C, D/2, H/2, W/2)
-- LHH (N, C, D/2, H/2, W/2)
-- HLL (N, C, D/2, H/2, W/2)
-- HLH (N, C, D/2, H/2, W/2)
-- HHL (N, C, D/2, H/2, W/2)
-- HHH (N, C, D/2, H/2, W/2)
output: (N, C, D, H, W)
"""
def __init__(self, wavename):
"""
        :param wavename: pywt wavelet name; it supplies self.band_low / self.band_high,
                         the low-pass / high-pass filters used for the wavelet reconstruction
"""
super(IDWT_3D, self).__init__()
wavelet = pywt.Wavelet(wavename)
self.band_low = wavelet.dec_lo
self.band_high = wavelet.dec_hi
self.band_low.reverse()
self.band_high.reverse()
assert len(self.band_low) == len(self.band_high)
self.band_length = len(self.band_low)
assert self.band_length % 2 == 0
self.band_length_half = math.floor(self.band_length / 2)
def get_matrix(self):
"""
        Generate the transform matrices.
:return:
"""
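        # the template must also cover the depth axis, otherwise matrix_h_2 / matrix_g_2 are truncated
        L1 = np.max((self.input_height, self.input_width, self.input_depth))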
L = math.floor(L1 / 2)
matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) )
matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )
end = None if self.band_length_half == 1 else (-self.band_length_half+1)
index = 0
for i in range(L):
for j in range(self.band_length):
matrix_h[i, index+j] = self.band_low[j]
index += 2
matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)]
matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)]
matrix_h_2 = matrix_h[0:(math.floor(self.input_depth / 2)), 0:(self.input_depth + self.band_length - 2)]
index = 0
for i in range(L1 - L):
for j in range(self.band_length):
matrix_g[i, index+j] = self.band_high[j]
index += 2
matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)]
matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)]
matrix_g_2 = matrix_g[0:(self.input_depth - math.floor(self.input_depth / 2)),0:(self.input_depth + self.band_length - 2)]
matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end]
matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end]
matrix_h_1 = np.transpose(matrix_h_1)
matrix_h_2 = matrix_h_2[:,(self.band_length_half-1):end]
matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end]
matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end]
matrix_g_1 = np.transpose(matrix_g_1)
matrix_g_2 = matrix_g_2[:,(self.band_length_half-1):end]
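        # float32 to match typical network inputs (np.zeros produces float64)
        if torch.cuda.is_available():
            self.matrix_low_0 = torch.tensor(matrix_h_0, dtype=torch.float32).cuda()
            self.matrix_low_1 = torch.tensor(matrix_h_1, dtype=torch.float32).cuda()
            self.matrix_low_2 = torch.tensor(matrix_h_2, dtype=torch.float32).cuda()
            self.matrix_high_0 = torch.tensor(matrix_g_0, dtype=torch.float32).cuda()
            self.matrix_high_1 = torch.tensor(matrix_g_1, dtype=torch.float32).cuda()
            self.matrix_high_2 = torch.tensor(matrix_g_2, dtype=torch.float32).cuda()
        else:
            self.matrix_low_0 = torch.tensor(matrix_h_0, dtype=torch.float32)
            self.matrix_low_1 = torch.tensor(matrix_h_1, dtype=torch.float32)
            self.matrix_low_2 = torch.tensor(matrix_h_2, dtype=torch.float32)
            self.matrix_high_0 = torch.tensor(matrix_g_0, dtype=torch.float32)
            self.matrix_high_1 = torch.tensor(matrix_g_1, dtype=torch.float32)
            self.matrix_high_2 = torch.tensor(matrix_g_2, dtype=torch.float32)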
def forward(self, LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH):
assert len(LLL.size()) == len(LLH.size()) == len(LHL.size()) == len(LHH.size()) == 5
assert len(HLL.size()) == len(HLH.size()) == len(HHL.size()) == len(HHH.size()) == 5
self.input_depth = LLL.size()[-3] + HHH.size()[-3]
self.input_height = LLL.size()[-2] + HHH.size()[-2]
self.input_width = LLL.size()[-1] + HHH.size()[-1]
#assert self.input_height > self.band_length and self.input_width > self.band_length and self.input_depth > self.band_length
self.get_matrix()
return IDWTFunction_3D.apply(LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH,
self.matrix_low_0, self.matrix_low_1, self.matrix_low_2,
self.matrix_high_0, self.matrix_high_1, self.matrix_high_2)
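IDWT_3D recovers the volume size from LLL and HHH and rebuilds the matrices; reconstruction then sums, over all eight subbands, each band multiplied by the transposed matrices along every axis. A Haar round-trip sketch on an (N, C, D, H, W) array; `haar_matrices` and `along` are hypothetical helpers, and the per-axis matrices are built from each axis's own length rather than from a shared template.

```python
import itertools
import numpy as np

def haar_matrices(length):
    """Hypothetical helper: Haar analysis matrices of shape (length//2, length)."""
    s = 1 / np.sqrt(2)
    low = np.zeros((length // 2, length))
    high = np.zeros((length // 2, length))
    for i in range(length // 2):
        low[i, 2 * i:2 * i + 2] = [s, s]
        high[i, 2 * i:2 * i + 2] = [s, -s]
    return low, high

def along(m, x, axis):
    """Multiply matrix m into the given axis of x."""
    return np.moveaxis(np.tensordot(m, x, axes=([1], [axis])), 0, axis)

def dwt_3d(x):
    mats = [dict(zip("LH", haar_matrices(n))) for n in x.shape[-3:]]
    bands = {}
    for name in ("".join(p) for p in itertools.product("LH", repeat=3)):
        y = x
        for offset, letter in enumerate(name):        # depth, height, width
            y = along(mats[offset][letter], y, x.ndim - 3 + offset)
        bands[name] = y
    return bands

def idwt_3d(bands):
    lll, hhh = bands["LLL"], bands["HHH"]
    sizes = [a + b for a, b in zip(lll.shape[-3:], hhh.shape[-3:])]  # as in IDWT_3D.forward
    mats = [dict(zip("LH", haar_matrices(n))) for n in sizes]
    out = 0
    for name, y in bands.items():
        for offset, letter in enumerate(name):
            y = along(mats[offset][letter].T, y, y.ndim - 3 + offset)
        out = out + y                                  # sum of all eight contributions
    return out

x = np.random.default_rng(4).standard_normal((1, 2, 4, 6, 8))
bands = dwt_3d(x)
print(bands["HLH"].shape, np.allclose(idwt_3d(bands), x))
```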
# if __name__ == '__main__':
# import pywt, cv2
# from datetime import datetime
#
# wavelet = pywt.Wavelet('bior1.1')
# h = wavelet.rec_lo
# g = wavelet.rec_hi
# h_ = wavelet.dec_lo
# g_ = wavelet.dec_hi
# h_.reverse()
# g_.reverse()
#
# #"""
# image_full_name = '/home/liqiufu/Pictures/standard_test_images/lena_color_512.tif'
# image = cv2.imread(image_full_name, flags = 1)
# image = image[0:512,0:512,:]
# print(image.shape)
# height, width, channel = image.shape
# #image = image.reshape((1,height,width))
# t0 = datetime.now()
# for index in range(1):
# m0 = DWT_2D(wavename = 'haar')
#
# m1 = IDWT_2D(wavename = 'haar')
# print(isinstance(m1, IDWT_2D))
# t1 = datetime.now()
class SegNet_VGG(nn.Module):
def __init__(self, features, num_classes = 21, init_weights = True, wavename = None):
super(SegNet_VGG, self).__init__()
self.features = features[0]
self.decoders = features[1]
self.classifier_seg = nn.Sequential(
#nn.Conv2d(64, 64, kernel_size = 3, padding = 1),
#nn.ReLU(True),
nn.Conv2d(64, num_classes, kernel_size = 1, padding = 0),
)
if init_weights:
self._initialize_weights()
def forward(self, x):
xx = self.features(x)
x, [(indices_1,), (indices_2,), (indices_3,), (indices_4,), (indices_5,)] = xx
x = self.decoders(x, indices_5, indices_4, indices_3, indices_2, indices_1)
x = self.classifier_seg(x)
return x
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
if(m.in_channels != m.out_channels or m.out_channels != m.groups or m.bias is not None):
                    # skip bias-free depthwise convs (in == out == groups), assumed to be fixed downsampling layers; ordinary convs never match this pattern
nn.init.kaiming_normal_(m.weight, mode = 'fan_out', nonlinearity = 'relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
else:
print('Not initializing')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def __str__(self):
return 'SegNet_VGG'
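SegNet_VGG's encoder uses `nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)` (see `make_layers`), and the decoder uses `nn.MaxUnpool2d`: unpooling writes each pooled maximum back at its recorded position and fills the other three positions of the 2x2 window with zeros. A NumPy sketch of that behaviour on one feature map; it is a simplified stand-in, not the PyTorch kernels (PyTorch stores flat indices into the input plane, whereas this sketch stores the position inside each window), and all function names are hypothetical.

```python
import numpy as np

def to_blocks(x):
    """(H, W) -> (H/2, W/2, 4): the four values of every 2x2 window."""
    h, w = x.shape
    return x.reshape(h // 2, 2, w // 2, 2).transpose(0, 2, 1, 3).reshape(h // 2, w // 2, 4)

def from_blocks(blocks):
    """Inverse of to_blocks."""
    hh, ww, _ = blocks.shape
    return blocks.reshape(hh, ww, 2, 2).transpose(0, 2, 1, 3).reshape(2 * hh, 2 * ww)

def max_pool_with_indices(x):
    blocks = to_blocks(x)
    return blocks.max(axis=-1), blocks.argmax(axis=-1)   # pooled values, in-window argmax

def max_unpool(pooled, indices):
    blocks = np.zeros(pooled.shape + (4,))
    np.put_along_axis(blocks, indices[..., None], pooled[..., None], axis=-1)
    return from_blocks(blocks)

x = np.array([[1., 5., 2., 0.],
              [3., 4., 8., 1.],
              [0., 2., 6., 7.],
              [9., 1., 3., 2.]])
pooled, idx = max_pool_with_indices(x)
print(max_unpool(pooled, idx))
```

Only the maxima survive the round trip; the other 12 values are lost, which is the information the decoder has to re-synthesize.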
class WSegNet_VGG(nn.Module):
def __init__(self, features, num_classes, init_weights = True, wavename = None):
super(WSegNet_VGG, self).__init__()
self.features = features[0]
self.decoders = features[1]
self.classifier_seg = nn.Sequential(
#nn.Conv2d(64, 64, kernel_size = 3, padding = 1),
#nn.ReLU(True),
nn.Conv2d(64, num_classes, kernel_size = 1, padding = 0),
)
if init_weights:
self._initialize_weights()
def forward(self, x):
xx = self.features(x)
x, [(LH1,HL1,HH1), (LH2,HL2,HH2,), (LH3,HL3,HH3,), (LH4,HL4,HH4,), (LH5,HL5,HH5,)] = xx
x = self.decoders(x, LH5,HL5,HH5, LH4,HL4,HH4, LH3,HL3,HH3, LH2,HL2,HH2, LH1,HL1,HH1)
x = self.classifier_seg(x)
return x
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
if(m.in_channels != m.out_channels or m.out_channels != m.groups or m.bias is not None):
                    # skip bias-free depthwise convs (in == out == groups), assumed to be fixed downsampling layers; ordinary convs never match this pattern
nn.init.kaiming_normal_(m.weight, mode = 'fan_out', nonlinearity = 'relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
else:
print('Not initializing')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def __str__(self):
return 'WSegNet_VGG'
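WSegNet_VGG swaps the pooling pair for DWT_2D / IDWT_2D: as the unpacking in `forward` suggests, the encoder keeps LL as the downsampled feature map and hands LH, HL and HH to the matching decoder stage. A Haar sketch on 2x2 windows showing that, given those stored bands, the upsampling step is lossless, while LL alone only recovers block means. The subband formulas assume the first letter refers to the height axis, and `haar_dwt` / `haar_idwt` are hypothetical names.

```python
import numpy as np

def haar_dwt(x):
    a, b = x[0::2, 0::2], x[0::2, 1::2]   # top-left, top-right of each 2x2 window
    c, d = x[1::2, 0::2], x[1::2, 1::2]   # bottom-left, bottom-right
    ll = (a + b + c + d) / 2
    lh = (a - b + c - d) / 2              # low along height, high along width
    hl = (a + b - c - d) / 2              # high along height, low along width
    hh = (a - b - c + d) / 2
    return ll, lh, hl, hh

def haar_idwt(ll, lh, hl, hh):
    out = np.empty((2 * ll.shape[0], 2 * ll.shape[1]))
    out[0::2, 0::2] = (ll + lh + hl + hh) / 2
    out[0::2, 1::2] = (ll - lh + hl - hh) / 2
    out[1::2, 0::2] = (ll + lh - hl - hh) / 2
    out[1::2, 1::2] = (ll - lh - hl + hh) / 2
    return out

x = np.random.default_rng(3).standard_normal((6, 8))
ll, lh, hl, hh = haar_dwt(x)
zeros = np.zeros_like(lh)
print(np.allclose(haar_idwt(ll, lh, hl, hh), x))          # with the stored bands
print(np.allclose(haar_idwt(ll, zeros, zeros, zeros), x))  # LL only
```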
def make_layers(cfg, batch_norm = False):
encoder = []
in_channels = 3
for v in cfg:
if v != 'M':
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
encoder += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
encoder += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
elif v == 'M':
encoder += [nn.MaxPool2d(kernel_size = 2, stride = 2, return_indices = True)]
encoder = My_Sequential(*encoder)
decoder = []
    cfg = cfg[::-1]  # reversed copy, so the shared module-level cfg lists are not mutated between calls
out_channels_final = 64
for index, v in enumerate(cfg):
if index != len(cfg) - 1:
out_channels = cfg[index + 1]
else:
out_channels = out_channels_final
if out_channels == 'M':
out_channels = cfg[index + 2]
if v == 'M':
decoder += [nn.MaxUnpool2d(kernel_size = 2, stride = 2)]
else:
conv2d = nn.Conv2d(v, out_channels, kernel_size = 3, padding = 1)
if batch_norm:
decoder += [conv2d, nn.BatchNorm2d(out_channels), nn.ReLU(inplace = True)]
else:
decoder += [conv2d, nn.ReLU(inplace = True)]
decoder = My_Sequential_re(*decoder)
return encoder, decoder
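`make_layers` mirrors the encoder: it walks the reversed cfg, turns every `'M'` into an unpooling layer, and gives each convolution the channel count of the next non-`'M'` entry (64 for the last one). A pure-Python sketch of that channel bookkeeping; `decoder_plan` is a hypothetical helper, and it takes a reversed copy with `cfg[::-1]` so the caller's list is left untouched.

```python
def decoder_plan(cfg, out_channels_final=64):
    """Return the decoder layers make_layers would create, as tuples."""
    rev = cfg[::-1]                       # reversed copy; cfg itself is not modified
    plan = []
    for index, v in enumerate(rev):
        if index != len(rev) - 1:
            out_channels = rev[index + 1]
        else:
            out_channels = out_channels_final
        if out_channels == 'M':           # skip over an unpooling entry
            out_channels = rev[index + 2]
        plan.append(('unpool',) if v == 'M' else ('conv', v, out_channels))
    return plan

cfg_a = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
for layer in decoder_plan(cfg_a):
    print(layer)
```

The plan starts with an unpooling step (mirroring the last encoder pool) and ends with a 64 to 64 convolution that feeds `classifier_seg`.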
def make_w_layers(cfg, in_channels=3, batch_norm = False, wavename = 'haar'):
encoder = []
for v in cfg:
if v != 'M':
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
encoder += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
encoder += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
elif v == 'M':
encoder += [DWT_2D(wavename = wavename)]
encoder = My_Sequential(*encoder)
decoder = []
    cfg = cfg[::-1]  # reversed copy, so the shared module-level cfg lists are not mutated between calls
out_channels_final = 64
for index, v in enumerate(cfg):
if index != len(cfg) - 1:
out_channels = cfg[index + 1]
else:
out_channels = out_channels_final
if out_channels == 'M':
out_channels = cfg[index + 2]
if v == 'M':
decoder += [IDWT_2D(wavename = wavename)]
else:
conv2d = nn.Conv2d(v, out_channels, kernel_size = 3, padding = 1)
if batch_norm:
decoder += [conv2d, nn.BatchNorm2d(out_channels), nn.ReLU(inplace = True)]
else:
decoder += [conv2d, nn.ReLU(inplace = True)]
decoder = My_Sequential_re(*decoder)
return encoder, decoder
cfg = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], # 11 layers
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], # 13 layers
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],  # 16 layers; entries are encoder output channels and decoder input channels
'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], # 19 layers
}
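The layer counts in the comments follow VGG naming: convolutions in each configuration plus the three fully connected layers of the original VGG classifier (SegNet_VGG itself has no fully connected layers). Every configuration also has five `'M'` entries, so the encoder downsamples by 2**5 = 32, which makes input sizes divisible by 32 a safe choice. A quick check with the lists copied from above:

```python
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

for name, layers in cfg.items():
    n_conv = sum(1 for v in layers if v != 'M')
    n_pool = layers.count('M')
    # convolutions, VGG depth (convs + 3 FC), total downsampling factor
    print(name, n_conv, n_conv + 3, 2 ** n_pool)
```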
def segnet_vgg11(pretrained = False, **kwargs):
"""VGG 11-layer model (configuration "A")
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = SegNet_VGG(make_layers(cfg['A']), **kwargs)
return model
def segnet_vgg11_bn(pretrained=False, **kwargs):
"""VGG 11-layer model (configuration "A") with batch normalization
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = SegNet_VGG(make_layers(cfg['A'], batch_norm = True), **kwargs)
return model
def segnet_vgg13(pretrained=False, **kwargs):
"""VGG 13-layer model (configuration "B")
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = SegNet_VGG(make_layers(cfg['B']), **kwargs)
return model
def segnet_vgg13_bn(pretrained=False, **kwargs):
"""VGG 13-layer model (configuration "B") with batch normalization
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = SegNet_VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
return model
def segnet_vgg16(pretrained=False, **kwargs):
"""VGG 16-layer model (configuration "D")
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = SegNet_VGG(make_layers(cfg['D']), **kwargs)
return model
def segnet_vgg16_bn(pretrained=False, **kwargs):
"""VGG 16-layer model (configuration "D") with batch normalization
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = SegNet_VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
return model
def segnet_vgg19(pretrained=False, **kwargs):
"""VGG 19-layer model (configuration "E")
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = SegNet_VGG(make_layers(cfg['E']), **kwargs)
return model
def segnet_vgg19_bn(pretrained=False, **kwargs):
"""VGG 19-layer model (configuration 'E') with batch normalization
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = SegNet_VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
return model
"""================================================================================="""
def wsegnet_vgg11(pretrained = False, wavename = 'haar', **kwargs):
"""VGG 11-layer model (configuration "A")
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = WSegNet_VGG(make_w_layers(cfg['A'], wavename = wavename), **kwargs)
return model
def wsegnet_vgg11_bn(pretrained=False, wavename = 'haar', **kwargs):
"""VGG 11-layer model (configuration "A") with batch normalization
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = WSegNet_VGG(make_w_layers(cfg['A'], batch_norm = True, wavename = wavename), **kwargs)
return model
def wsegnet_vgg13(pretrained=False, wavename = 'haar', **kwargs):
"""VGG 13-layer model (configuration "B")
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = WSegNet_VGG(make_w_layers(cfg['B'], wavename = wavename), **kwargs)
return model
def wsegnet_vgg13_bn(pretrained=False, wavename = 'haar', **kwargs):
"""VGG 13-layer model (configuration "B") with batch normalization
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = WSegNet_VGG(make_w_layers(cfg['B'], batch_norm=True, wavename = wavename), **kwargs)
return model
def wsegnet_vgg16(pretrained=False, wavename = 'haar', **kwargs):
"""VGG 16-layer model (configuration "D")
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = WSegNet_VGG(make_w_layers(cfg['D'], wavename = wavename), **kwargs)
return model
def wsegnet_vgg16_bn(in_channels, num_classes, pretrained=False, wavename = 'haar', **kwargs):
"""VGG 16-layer model (configuration "D") with batch normalization
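    Args:
        in_channels (int): number of input image channels
        num_classes (int): number of segmentation classes
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
        wavename (str): pywt wavelet name used by the DWT_2D / IDWT_2D layers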
"""
if pretrained:
kwargs['init_weights'] = False
model = WSegNet_VGG(make_w_layers(cfg['D'], in_channels, batch_norm=True, wavename = wavename), num_classes, **kwargs)
return model
def wsegnet_vgg19(pretrained=False, wavename = 'haar', **kwargs):
"""VGG 19-layer model (configuration "E")
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = WSegNet_VGG(make_w_layers(cfg['E'], wavename = wavename), **kwargs)
return model
def wsegnet_vgg19_bn(pretrained=False, wavename = 'haar', **kwargs):
"""VGG 19-layer model (configuration 'E') with batch normalization
Args:
        pretrained (bool): If True, only sets init_weights=False (skips weight initialization); no pretrained weights are loaded
"""
if pretrained:
kwargs['init_weights'] = False
model = WSegNet_VGG(make_w_layers(cfg['E'], batch_norm=True, wavename = wavename), **kwargs)
return model
# if __name__ == '__main__':
# from loss.loss_function import segmentation_loss
# criterion = segmentation_loss('dice', False)
# mask = torch.ones(2, 128, 128).long()
# model = wsegnet_vgg16_bn(1, 5)
# model.train()
# input1 = torch.rand(2, 1, 128, 128)
# y = model(input1)
# loss_train = criterion(y, mask)
# loss_train.backward()
# # print(output)
# print(y.data.cpu().numpy().shape)
# print(loss_train)
================================================
FILE: models/networks_2d/wds.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import functools
from torch.distributions.uniform import Uniform
import numpy as np
class basic_block(nn.Module):
def __init__(self, ch_in, ch_out):
super(basic_block, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False),
nn.ReLU(inplace=True))
def forward(self, x):
x = self.block(x)
return x
class WDS(nn.Module):
def __init__(self, in_channels, num_classes):
super(WDS, self).__init__()
# branch1
self.b1_1 = basic_block(in_channels, 64)
self.b1_2 = basic_block(64, 64)
self.b1_3 = basic_block(64, 64)
self.b1_4 = basic_block(64, 64)
self.b1_5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.b1_6 = basic_block(64, 128)
self.b1_7 = basic_block(128, 128)
self.b1_8 = basic_block(128, 128)
self.b1_9 = basic_block(128, 128)
self.b1_10 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# branch2
self.b2_1 = basic_block(in_channels, 64)
self.b2_2 = basic_block(64, 64)
self.b2_3 = basic_block(64, 64)
self.b2_4 = basic_block(64, 64)
self.b2_5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.b2_6 = basic_block(64, 128)
self.b2_7 = basic_block(128, 128)
self.b2_8 = basic_block(128, 128)
self.b2_9 = basic_block(128, 128)
self.b2_10 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# branch3
self.b3_1 = basic_block(in_channels, 64)
self.b3_2 = basic_block(64, 64)
self.b3_3 = basic_block(64, 64)
self.b3_4 = basic_block(64, 64)
self.b3_5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.b3_6 = basic_block(64, 128)
self.b3_7 = basic_block(128, 128)
self.b3_8 = basic_block(128, 128)
self.b3_9 = basic_block(128, 128)
self.b3_10 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# branch4
self.b4_1 = basic_block(in_channels, 64)
self.b4_2 = basic_block(64, 64)
self.b4_3 = basic_block(64, 64)
self.b4_4 = basic_block(64, 64)
self.b4_5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.b4_6 = basic_block(64, 128)
self.b4_7 = basic_block(128, 128)
self.b4_8 = basic_block(128, 128)
self.b4_9 = basic_block(128, 128)
self.b4_10 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# output
self.output_layer = nn.Sequential(
nn.Conv2d(128*4, 128, kernel_size=3, stride=1, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1, bias=False),
)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, LL, LH, HL, HH):
# H, W = 2*LL.shape[2], 2*LL.shape[3]
H, W = LL.shape[2], LL.shape[3]
LL = self.b1_1(LL)
LL = self.b1_2(LL)
LL = self.b1_3(LL)
LL = self.b1_4(LL)
LL = self.b1_5(LL)
LL = self.b1_6(LL)
LL = self.b1_7(LL)
LL = self.b1_8(LL)
LL = self.b1_9(LL)
LL = self.b1_10(LL)
LH = self.b2_1(LH)
LH = self.b2_2(LH)
LH = self.b2_3(LH)
LH = self.b2_4(LH)
LH = self.b2_5(LH)
LH = self.b2_6(LH)
LH = self.b2_7(LH)
LH = self.b2_8(LH)
LH = self.b2_9(LH)
LH = self.b2_10(LH)
HL = self.b3_1(HL)
HL = self.b3_2(HL)
HL = self.b3_3(HL)
HL = self.b3_4(HL)
HL = self.b3_5(HL)
HL = self.b3_6(HL)
HL = self.b3_7(HL)
HL = self.b3_8(HL)
HL = self.b3_9(HL)
HL = self.b3_10(HL)
HH = self.b4_1(HH)
HH = self.b4_2(HH)
HH = self.b4_3(HH)
HH = self.b4_4(HH)
HH = self.b4_5(HH)
HH = self.b4_6(HH)
HH = self.b4_7(HH)
HH = self.b4_8(HH)
HH = self.b4_9(HH)
HH = self.b4_10(HH)
x = torch.cat((LL, LH, HL, HH), dim=1)
x = self.output_layer(x)
x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
return x
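In WDS each of the four subband branches passes through two `MaxPool2d(2, stride=2, ceil_mode=True)` layers, so a branch output has ceil(ceil(H/2)/2) rows; the four 128-channel maps are concatenated into the 128*4 = 512 channels expected by `output_layer`, and `F.interpolate` restores the size of the LL input. A pure-Python sketch of that size bookkeeping; `pool_out` and `wds_sizes` are hypothetical helpers.

```python
import math

def pool_out(size, kernel=2, stride=2):
    """Output length of MaxPool2d(kernel, stride, ceil_mode=True) without padding."""
    out = math.ceil((size - kernel) / stride) + 1
    if (out - 1) * stride >= size:    # a window may not start beyond the input
        out -= 1
    return out

def wds_sizes(h, w, branch_channels=128, n_branches=4):
    after_two_pools = (pool_out(pool_out(h)), pool_out(pool_out(w)))
    concat_channels = branch_channels * n_branches
    return after_two_pools, concat_channels, (h, w)   # final size matches the LL input

print(wds_sizes(128, 128))
print(wds_sizes(127, 90))
```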
if __name__ == '__main__':
from loss.loss_function import segmentation_loss
criterion = segmentation_loss('dice', False)
mask = torch.ones(2, 128, 128).long()
model = WDS(1, 5)
model.train()
input1 = torch.rand(2, 1, 128, 128)
y = model(input1, input1, input1, input1)
loss_train = criterion(y, mask)
loss_train.backward()
# print(output)
print(y.data.cpu().numpy().shape)
print(loss_train)
================================================
FILE: models/networks_2d/xnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import functools
from torch.distributions.uniform import Uniform
import numpy as np
BatchNorm2d = nn.BatchNorm2d
relu_inplace = True
BN_MOMENTUM = 0.1
# BN_MOMENTUM = 0.01
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
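`conv3x3` sets `padding=dilation`, which keeps the spatial size unchanged at stride 1 for any dilation, while the 3x3, stride-2, padding-1 convolution in `down_conv` maps H to ceil(H/2). Both follow the standard `nn.Conv2d` output-size formula floor((H + 2p - d(k - 1) - 1) / s) + 1, checked below in plain Python (`conv_out` is a hypothetical helper):

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of nn.Conv2d along one axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# conv3x3(..., dilation=d) uses padding=d: size is preserved at stride 1
print([conv_out(64, 3, padding=d, dilation=d) for d in (1, 2, 4)])
# down_conv: 3x3, stride 2, padding 1 -> ceil(H / 2)
print([conv_out(h, 3, stride=2, padding=1) for h in (128, 127, 33)])
# conv1x1 / transition_conv: kernel 1, no padding
print(conv_out(50, 1))
```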
class up_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(up_conv, self).__init__()
self.up = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
BatchNorm2d(ch_out, momentum=BN_MOMENTUM),
nn.ReLU(inplace=relu_inplace)
)
def forward(self, x):
x = self.up(x)
return x
class down_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(down_conv, self).__init__()
self.down = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=2, padding=1, bias=False),
BatchNorm2d(ch_out, momentum=BN_MOMENTUM),
nn.ReLU(inplace=relu_inplace)
)
def forward(self, x):
x = self.down(x)
return x
class same_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(same_conv, self).__init__()
self.same = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False),
BatchNorm2d(ch_out, momentum=BN_MOMENTUM),
nn.ReLU(inplace=relu_inplace))
def forward(self, x):
x = self.same(x)
return x
class transition_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(transition_conv, self).__init__()
self.transition = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=False),
BatchNorm2d(ch_out, momentum=BN_MOMENTUM),
nn.ReLU(inplace=relu_inplace))
def forward(self, x):
x = self.transition(x)
return x
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=relu_inplace)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes, momentum=BN_MOMENTUM)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out = out + identity
out = self.relu(out)
return out
class DoubleBasicBlock(nn.Module):
def __init__(self, inplanes, planes, downsample=None):
super(DoubleBasicBlock, self).__init__()
self.DBB = nn.Sequential(
BasicBlock(inplanes=inplanes, planes=planes, downsample=downsample),
BasicBlock(inplanes=planes, planes=planes)
)
def forward(self, x):
out = self.DBB(x)
return out
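# A minimal sketch (not part of the original XNet code) that counts the
# learnable parameters of BasicBlock and DoubleBasicBlock as defined above:
# every 3x3 conv has in*out*9 weights (bias=False), every BatchNorm2d adds a
# weight and a bias per channel, and the optional 1x1 conv + BN projection
# passed as `downsample` when the input is a concatenation (e.g. l1c+l1c -> l1c)
# adds in*out + 2*out. The totals should agree with
# sum(p.numel() for p in block.parameters()). Function names are hypothetical.
def basic_block_params(inplanes, planes, projection=False):
    convs = inplanes * planes * 9 + planes * planes * 9  # conv1 + conv2
    bns = 2 * (2 * planes)                               # bn1 + bn2 (weight, bias)
    proj = (inplanes * planes + 2 * planes) if projection else 0
    return convs + bns + proj

def double_basic_block_params(inplanes, planes, projection=False):
    # First BasicBlock maps inplanes -> planes (optionally projected);
    # the second keeps planes -> planes without a projection.
    return basic_block_params(inplanes, planes, projection) + basic_block_params(planes, planes)
# Usage: basic_block_params(64, 64) -> 73984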
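# XNet: two U-Net-style branches. Branch 1 reads input1 (in_channels channels);
# branch 2 reads input2 through conv3x3(1, l1c), i.e. a single channel. The
# branches exchange features at encoder levels 4 (1/8 resolution) and 5
# (1/16 resolution). forward() returns one num_classes-channel map per branch
# at the input resolution.
class XNet(nn.Module):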
def __init__(self, in_channels, num_classes):
super(XNet, self).__init__()
l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024
# branch1
# branch1_layer1
self.b1_1_1 = nn.Sequential(
conv3x3(in_channels, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b1_1_2_down = down_conv(l1c, l2c)
self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch1_layer2
self.b1_2_1 = DoubleBasicBlock(l2c, l2c)
self.b1_2_2_down = down_conv(l2c, l3c)
self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b1_2_4_up = up_conv(l2c, l1c)
# branch1_layer3
self.b1_3_1 = DoubleBasicBlock(l3c, l3c)
self.b1_3_2_down = down_conv(l3c, l4c)
self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b1_3_4_up = up_conv(l3c, l2c)
# branch1_layer4
self.b1_4_1 = DoubleBasicBlock(l4c, l4c)
self.b1_4_2_down = down_conv(l4c, l5c)
self.b1_4_2 = DoubleBasicBlock(l4c, l4c)
self.b1_4_3_down = down_conv(l4c, l4c)
self.b1_4_3_same = same_conv(l4c, l4c)
self.b1_4_4_transition = transition_conv(l4c+l5c+l4c, l4c)
self.b1_4_5 = DoubleBasicBlock(l4c, l4c)
self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b1_4_7_up = up_conv(l4c, l3c)
# branch1_layer5
self.b1_5_1 = DoubleBasicBlock(l5c, l5c)
self.b1_5_2_up = up_conv(l5c, l5c)
self.b1_5_2_same = same_conv(l5c, l5c)
self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)
self.b1_5_4 = DoubleBasicBlock(l5c, l5c)
self.b1_5_5_up = up_conv(l5c, l4c)
# branch2
# branch2_layer1
self.b2_1_1 = nn.Sequential(
conv3x3(1, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b2_1_2_down = down_conv(l1c, l2c)
self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch2_layer2
self.b2_2_1 = DoubleBasicBlock(l2c, l2c)
self.b2_2_2_down = down_conv(l2c, l3c)
self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b2_2_4_up = up_conv(l2c, l1c)
# branch2_layer3
self.b2_3_1 = DoubleBasicBlock(l3c, l3c)
self.b2_3_2_down = down_conv(l3c, l4c)
self.b2_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b2_3_4_up = up_conv(l3c, l2c)
# branch2_layer4
self.b2_4_1 = DoubleBasicBlock(l4c, l4c)
self.b2_4_2_down = down_conv(l4c, l5c)
self.b2_4_2 = DoubleBasicBlock(l4c, l4c)
self.b2_4_3_down = down_conv(l4c, l4c)
self.b2_4_3_same = same_conv(l4c, l4c)
self.b2_4_4_transition = transition_conv(l4c+l5c+l4c, l4c)
self.b2_4_5 = DoubleBasicBlock(l4c, l4c)
self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b2_4_7_up = up_conv(l4c, l3c)
# branch2_layer5
self.b2_5_1 = DoubleBasicBlock(l5c, l5c)
self.b2_5_2_up = up_conv(l5c, l5c)
self.b2_5_2_same = same_conv(l5c, l5c)
self.b2_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)
self.b2_5_4 = DoubleBasicBlock(l5c, l5c)
self.b2_5_5_up = up_conv(l5c, l4c)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABNSync):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABN):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, input1, input2):
# encode
# branch1
x1_1 = self.b1_1_1(input1)
x1_2 = self.b1_1_2_down(x1_1)
x1_2 = self.b1_2_1(x1_2)
x1_3 = self.b1_2_2_down(x1_2)
x1_3 = self.b1_3_1(x1_3)
x1_4_1 = self.b1_3_2_down(x1_3)
x1_4_1 = self.b1_4_1(x1_4_1)
x1_4_2 = self.b1_4_2(x1_4_1)
x1_4_3_down = self.b1_4_3_down(x1_4_2)
x1_4_3_same = self.b1_4_3_same(x1_4_2)
x1_5_1 = self.b1_4_2_down(x1_4_1)
x1_5_1 = self.b1_5_1(x1_5_1)
x1_5_2_up = self.b1_5_2_up(x1_5_1)
x1_5_2_same = self.b1_5_2_same(x1_5_1)
# branch2
x2_1 = self.b2_1_1(input2)
x2_2 = self.b2_1_2_down(x2_1)
x2_2 = self.b2_2_1(x2_2)
x2_3 = self.b2_2_2_down(x2_2)
x2_3 = self.b2_3_1(x2_3)
x2_4_1 = self.b2_3_2_down(x2_3)
x2_4_1 = self.b2_4_1(x2_4_1)
x2_4_2 = self.b2_4_2(x2_4_1)
x2_4_3_down = self.b2_4_3_down(x2_4_2)
x2_4_3_same = self.b2_4_3_same(x2_4_2)
x2_5_1 = self.b2_4_2_down(x2_4_1)
x2_5_1 = self.b2_5_1(x2_5_1)
x2_5_2_up = self.b2_5_2_up(x2_5_1)
x2_5_2_same = self.b2_5_2_same(x2_5_1)
# merge
# branch1
x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same, x2_4_3_down), dim=1)
x1_5_3 = self.b1_5_3_transition(x1_5_3)
x1_5_3 = self.b1_5_4(x1_5_3)
x1_5_3 = self.b1_5_5_up(x1_5_3)
x1_4_4 = torch.cat((x1_4_3_same, x2_4_3_same, x2_5_2_up), dim=1)
x1_4_4 = self.b1_4_4_transition(x1_4_4)
x1_4_4 = self.b1_4_5(x1_4_4)
x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)
x1_4_4 = self.b1_4_6(x1_4_4)
x1_4_4 = self.b1_4_7_up(x1_4_4)
# branch2
x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same, x1_4_3_down), dim=1)
x2_5_3 = self.b2_5_3_transition(x2_5_3)
x2_5_3 = self.b2_5_4(x2_5_3)
x2_5_3 = self.b2_5_5_up(x2_5_3)
x2_4_4 = torch.cat((x2_4_3_same, x1_4_3_same, x1_5_2_up), dim=1)
x2_4_4 = self.b2_4_4_transition(x2_4_4)
x2_4_4 = self.b2_4_5(x2_4_4)
x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)
x2_4_4 = self.b2_4_6(x2_4_4)
x2_4_4 = self.b2_4_7_up(x2_4_4)
# decode
# branch1
x1_3 = torch.cat((x1_3, x1_4_4), dim=1)
x1_3 = self.b1_3_3(x1_3)
x1_3 = self.b1_3_4_up(x1_3)
x1_2 = torch.cat((x1_2, x1_3), dim=1)
x1_2 = self.b1_2_3(x1_2)
x1_2 = self.b1_2_4_up(x1_2)
x1_1 = torch.cat((x1_1, x1_2), dim=1)
x1_1 = self.b1_1_3(x1_1)
x1_1 = self.b1_1_4(x1_1)
# branch2
x2_3 = torch.cat((x2_3, x2_4_4), dim=1)
x2_3 = self.b2_3_3(x2_3)
x2_3 = self.b2_3_4_up(x2_3)
x2_2 = torch.cat((x2_2, x2_3), dim=1)
x2_2 = self.b2_2_3(x2_2)
x2_2 = self.b2_2_4_up(x2_2)
x2_1 = torch.cat((x2_1, x2_2), dim=1)
x2_1 = self.b2_1_3(x2_1)
x2_1 = self.b2_1_4(x2_1)
return x1_1, x2_1
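# A minimal sketch (not part of the original XNet code) of the feature-map
# side length at each of the five levels. down_conv (3x3, stride 2, padding 1)
# maps a side of n to ceil(n / 2), while up_conv always doubles it, so the
# torch.cat calls in XNet.forward() above only line up when every level above
# the bottleneck has an even side, i.e. when each input side (H and W) is a
# multiple of 16. Function names are hypothetical.
def level_sizes(side, levels=5):
    sizes = [side]
    for _ in range(levels - 1):
        sizes.append((sizes[-1] + 1) // 2)  # ceil(n / 2)
    return sizes

def shapes_align(side, levels=5):
    sizes = level_sizes(side, levels)
    # An upsampled deeper map (2 * deeper) must match the shallower map it is concatenated with.
    return all(2 * deeper == shallower for shallower, deeper in zip(sizes, sizes[1:]))
# Usage: level_sizes(128) -> [128, 64, 32, 16, 8]; shapes_align(100) -> False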
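# XNet_1_1_m: same encoder and decoder as XNet, but the two branches exchange
# features only at level 5 (1/16 resolution), where each branch concatenates
# its own and the other branch's level-5 features.
class XNet_1_1_m(nn.Module):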
def __init__(self, in_channels, num_classes):
super(XNet_1_1_m, self).__init__()
l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024
# branch1
# branch1_layer1
self.b1_1_1 = nn.Sequential(
conv3x3(in_channels, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b1_1_2_down = down_conv(l1c, l2c)
self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch1_layer2
self.b1_2_1 = DoubleBasicBlock(l2c, l2c)
self.b1_2_2_down = down_conv(l2c, l3c)
self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b1_2_4_up = up_conv(l2c, l1c)
# branch1_layer3
self.b1_3_1 = DoubleBasicBlock(l3c, l3c)
self.b1_3_2_down = down_conv(l3c, l4c)
self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b1_3_4_up = up_conv(l3c, l2c)
# branch1_layer4
self.b1_4_1 = DoubleBasicBlock(l4c, l4c)
self.b1_4_2_down = down_conv(l4c, l5c)
self.b1_4_3 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b1_4_4_up = up_conv(l4c, l3c)
# branch1_layer5
self.b1_5_1 = DoubleBasicBlock(l5c, l5c)
self.b1_5_2_same = same_conv(l5c, l5c)
self.b1_5_3_transition = transition_conv(l5c+l5c, l5c)
self.b1_5_4 = DoubleBasicBlock(l5c, l5c)
self.b1_5_5_up = up_conv(l5c, l4c)
# branch2
# branch2_layer1
self.b2_1_1 = nn.Sequential(
conv3x3(1, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b2_1_2_down = down_conv(l1c, l2c)
self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch2_layer2
self.b2_2_1 = DoubleBasicBlock(l2c, l2c)
self.b2_2_2_down = down_conv(l2c, l3c)
self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b2_2_4_up = up_conv(l2c, l1c)
# branch2_layer3
self.b2_3_1 = DoubleBasicBlock(l3c, l3c)
self.b2_3_2_down = down_conv(l3c, l4c)
self.b2_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b2_3_4_up = up_conv(l3c, l2c)
# branch2_layer4
self.b2_4_1 = DoubleBasicBlock(l4c, l4c)
self.b2_4_2_down = down_conv(l4c, l5c)
self.b2_4_3 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b2_4_4_up = up_conv(l4c, l3c)
# branch2_layer5
self.b2_5_1 = DoubleBasicBlock(l5c, l5c)
self.b2_5_2_same = same_conv(l5c, l5c)
self.b2_5_3_transition = transition_conv(l5c+l5c, l5c)
self.b2_5_4 = DoubleBasicBlock(l5c, l5c)
self.b2_5_5_up = up_conv(l5c, l4c)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABNSync):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABN):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, input1, input2):
# encode
# branch1
x1_1 = self.b1_1_1(input1)
x1_2 = self.b1_1_2_down(x1_1)
x1_2 = self.b1_2_1(x1_2)
x1_3 = self.b1_2_2_down(x1_2)
x1_3 = self.b1_3_1(x1_3)
x1_4 = self.b1_3_2_down(x1_3)
x1_4 = self.b1_4_1(x1_4)
x1_5_1 = self.b1_4_2_down(x1_4)
x1_5_1 = self.b1_5_1(x1_5_1)
x1_5_2_same = self.b1_5_2_same(x1_5_1)
# branch2
x2_1 = self.b2_1_1(input2)
x2_2 = self.b2_1_2_down(x2_1)
x2_2 = self.b2_2_1(x2_2)
x2_3 = self.b2_2_2_down(x2_2)
x2_3 = self.b2_3_1(x2_3)
x2_4 = self.b2_3_2_down(x2_3)
x2_4 = self.b2_4_1(x2_4)
x2_5_1 = self.b2_4_2_down(x2_4)
x2_5_1 = self.b2_5_1(x2_5_1)
x2_5_2_same = self.b2_5_2_same(x2_5_1)
# merge
# branch1
x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same), dim=1)
x1_5_3 = self.b1_5_3_transition(x1_5_3)
x1_5_3 = self.b1_5_4(x1_5_3)
x1_5_3 = self.b1_5_5_up(x1_5_3)
# branch2
x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same), dim=1)
x2_5_3 = self.b2_5_3_transition(x2_5_3)
x2_5_3 = self.b2_5_4(x2_5_3)
x2_5_3 = self.b2_5_5_up(x2_5_3)
# decode
# branch1
x1_4 = torch.cat((x1_4, x1_5_3), dim=1)
x1_4 = self.b1_4_3(x1_4)
x1_4 = self.b1_4_4_up(x1_4)
x1_3 = torch.cat((x1_3, x1_4), dim=1)
x1_3 = self.b1_3_3(x1_3)
x1_3 = self.b1_3_4_up(x1_3)
x1_2 = torch.cat((x1_2, x1_3), dim=1)
x1_2 = self.b1_2_3(x1_2)
x1_2 = self.b1_2_4_up(x1_2)
x1_1 = torch.cat((x1_1, x1_2), dim=1)
x1_1 = self.b1_1_3(x1_1)
x1_1 = self.b1_1_4(x1_1)
# branch2
x2_4 = torch.cat((x2_4, x2_5_3), dim=1)
x2_4 = self.b2_4_3(x2_4)
x2_4 = self.b2_4_4_up(x2_4)
x2_3 = torch.cat((x2_3, x2_4), dim=1)
x2_3 = self.b2_3_3(x2_3)
x2_3 = self.b2_3_4_up(x2_3)
x2_2 = torch.cat((x2_2, x2_3), dim=1)
x2_2 = self.b2_2_3(x2_2)
x2_2 = self.b2_2_4_up(x2_2)
x2_1 = torch.cat((x2_1, x2_2), dim=1)
x2_1 = self.b2_1_3(x2_1)
x2_1 = self.b2_1_4(x2_1)
return x1_1, x2_1
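# A minimal sketch (not part of the original XNet code) of the standard
# deviation that nn.init.kaiming_normal_(m.weight, mode='fan_out',
# nonlinearity='relu') uses in the initialization loop above:
# std = gain / sqrt(fan_out), with gain = sqrt(2) for ReLU and
# fan_out = out_channels * kernel_h * kernel_w for a Conv2d weight.
# The function name is hypothetical.
import math

def kaiming_fan_out_std(out_channels, kernel_size):
    fan_out = out_channels * kernel_size * kernel_size
    return math.sqrt(2.0) / math.sqrt(fan_out)
# Usage: kaiming_fan_out_std(64, 3) -> about 0.0589 (fan_out = 576)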
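# XNet_1_2_m: branch 1 fuses only at level 5 (1/16 resolution), concatenating
# its own level-5 features with branch 2's level-5 features and branch 2's
# level-4 features downsampled once. Branch 2 fuses at level 5 (both branches'
# level-5 features) and at level 4 (1/8 resolution: its own level-4 features
# plus branch 1's level-5 features upsampled once).
class XNet_1_2_m(nn.Module):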
def __init__(self, in_channels, num_classes):
super(XNet_1_2_m, self).__init__()
l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024
# branch1
# branch1_layer1
self.b1_1_1 = nn.Sequential(
conv3x3(in_channels, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b1_1_2_down = down_conv(l1c, l2c)
self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch1_layer2
self.b1_2_1 = DoubleBasicBlock(l2c, l2c)
self.b1_2_2_down = down_conv(l2c, l3c)
self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b1_2_4_up = up_conv(l2c, l1c)
# branch1_layer3
self.b1_3_1 = DoubleBasicBlock(l3c, l3c)
self.b1_3_2_down = down_conv(l3c, l4c)
self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b1_3_4_up = up_conv(l3c, l2c)
# branch1_layer4
self.b1_4_1 = DoubleBasicBlock(l4c, l4c)
self.b1_4_2_down = down_conv(l4c, l5c)
self.b1_4_3 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b1_4_4_up = up_conv(l4c, l3c)
# branch1_layer5
self.b1_5_1 = DoubleBasicBlock(l5c, l5c)
self.b1_5_2_up = up_conv(l5c, l5c)
self.b1_5_2_same = same_conv(l5c, l5c)
self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)
self.b1_5_4 = DoubleBasicBlock(l5c, l5c)
self.b1_5_5_up = up_conv(l5c, l4c)
# branch2
# branch2_layer1
self.b2_1_1 = nn.Sequential(
conv3x3(1, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b2_1_2_down = down_conv(l1c, l2c)
self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch2_layer2
self.b2_2_1 = DoubleBasicBlock(l2c, l2c)
self.b2_2_2_down = down_conv(l2c, l3c)
self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b2_2_4_up = up_conv(l2c, l1c)
# branch2_layer3
self.b2_3_1 = DoubleBasicBlock(l3c, l3c)
self.b2_3_2_down = down_conv(l3c, l4c)
self.b2_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b2_3_4_up = up_conv(l3c, l2c)
# branch2_layer4
self.b2_4_1 = DoubleBasicBlock(l4c, l4c)
self.b2_4_2_down = down_conv(l4c, l5c)
self.b2_4_2 = DoubleBasicBlock(l4c, l4c)
self.b2_4_3_down = down_conv(l4c, l4c)
self.b2_4_3_same = same_conv(l4c, l4c)
self.b2_4_4_transition = transition_conv(l4c+l5c, l4c)
self.b2_4_5 = DoubleBasicBlock(l4c, l4c)
self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b2_4_7_up = up_conv(l4c, l3c)
# branch2_layer5
self.b2_5_1 = DoubleBasicBlock(l5c, l5c)
self.b2_5_2_same = same_conv(l5c, l5c)
self.b2_5_3_transition = transition_conv(l5c+l5c, l5c)
self.b2_5_4 = DoubleBasicBlock(l5c, l5c)
self.b2_5_5_up = up_conv(l5c, l4c)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABNSync):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABN):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, input1, input2):
# encode
# branch1
x1_1 = self.b1_1_1(input1)
x1_2 = self.b1_1_2_down(x1_1)
x1_2 = self.b1_2_1(x1_2)
x1_3 = self.b1_2_2_down(x1_2)
x1_3 = self.b1_3_1(x1_3)
x1_4 = self.b1_3_2_down(x1_3)
x1_4 = self.b1_4_1(x1_4)
x1_5_1 = self.b1_4_2_down(x1_4)
x1_5_1 = self.b1_5_1(x1_5_1)
x1_5_2_up = self.b1_5_2_up(x1_5_1)
x1_5_2_same = self.b1_5_2_same(x1_5_1)
# branch2
x2_1 = self.b2_1_1(input2)
x2_2 = self.b2_1_2_down(x2_1)
x2_2 = self.b2_2_1(x2_2)
x2_3 = self.b2_2_2_down(x2_2)
x2_3 = self.b2_3_1(x2_3)
x2_4_1 = self.b2_3_2_down(x2_3)
x2_4_1 = self.b2_4_1(x2_4_1)
x2_4_2 = self.b2_4_2(x2_4_1)
x2_4_3_down = self.b2_4_3_down(x2_4_2)
x2_4_3_same = self.b2_4_3_same(x2_4_2)
x2_5_1 = self.b2_4_2_down(x2_4_1)
x2_5_1 = self.b2_5_1(x2_5_1)
x2_5_2_same = self.b2_5_2_same(x2_5_1)
# merge
# branch1
x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same, x2_4_3_down), dim=1)
x1_5_3 = self.b1_5_3_transition(x1_5_3)
x1_5_3 = self.b1_5_4(x1_5_3)
x1_5_3 = self.b1_5_5_up(x1_5_3)
# branch2
x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same), dim=1)
x2_5_3 = self.b2_5_3_transition(x2_5_3)
x2_5_3 = self.b2_5_4(x2_5_3)
x2_5_3 = self.b2_5_5_up(x2_5_3)
x2_4_4 = torch.cat((x2_4_3_same, x1_5_2_up), dim=1)
x2_4_4 = self.b2_4_4_transition(x2_4_4)
x2_4_4 = self.b2_4_5(x2_4_4)
x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)
x2_4_4 = self.b2_4_6(x2_4_4)
x2_4_4 = self.b2_4_7_up(x2_4_4)
# decode
# branch1
x1_4 = torch.cat((x1_4, x1_5_3), dim=1)
x1_4 = self.b1_4_3(x1_4)
x1_4 = self.b1_4_4_up(x1_4)
x1_3 = torch.cat((x1_3, x1_4), dim=1)
x1_3 = self.b1_3_3(x1_3)
x1_3 = self.b1_3_4_up(x1_3)
x1_2 = torch.cat((x1_2, x1_3), dim=1)
x1_2 = self.b1_2_3(x1_2)
x1_2 = self.b1_2_4_up(x1_2)
x1_1 = torch.cat((x1_1, x1_2), dim=1)
x1_1 = self.b1_1_3(x1_1)
x1_1 = self.b1_1_4(x1_1)
# branch2
x2_3 = torch.cat((x2_3, x2_4_4), dim=1)
x2_3 = self.b2_3_3(x2_3)
x2_3 = self.b2_3_4_up(x2_3)
x2_2 = torch.cat((x2_2, x2_3), dim=1)
x2_2 = self.b2_2_3(x2_2)
x2_2 = self.b2_2_4_up(x2_2)
x2_1 = torch.cat((x2_1, x2_2), dim=1)
x2_1 = self.b2_1_3(x2_1)
x2_1 = self.b2_1_4(x2_1)
return x1_1, x2_1
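# XNet_2_1_m: branch 1 fuses at level 5 (1/16 resolution: both branches'
# level-5 features) and at level 4 (1/8 resolution: its own level-4 features
# plus branch 2's level-5 features upsampled once). Branch 2 fuses only at
# level 5, concatenating its own level-5 features with branch 1's level-5
# features and branch 1's level-4 features downsampled once.
class XNet_2_1_m(nn.Module):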
def __init__(self, in_channels, num_classes):
super(XNet_2_1_m, self).__init__()
l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024
# branch1
# branch1_layer1
self.b1_1_1 = nn.Sequential(
conv3x3(in_channels, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b1_1_2_down = down_conv(l1c, l2c)
self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch1_layer2
self.b1_2_1 = DoubleBasicBlock(l2c, l2c)
self.b1_2_2_down = down_conv(l2c, l3c)
self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b1_2_4_up = up_conv(l2c, l1c)
# branch1_layer3
self.b1_3_1 = DoubleBasicBlock(l3c, l3c)
self.b1_3_2_down = down_conv(l3c, l4c)
self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b1_3_4_up = up_conv(l3c, l2c)
# branch1_layer4
self.b1_4_1 = DoubleBasicBlock(l4c, l4c)
self.b1_4_2_down = down_conv(l4c, l5c)
self.b1_4_2 = DoubleBasicBlock(l4c, l4c)
self.b1_4_3_down = down_conv(l4c, l4c)
self.b1_4_3_same = same_conv(l4c, l4c)
self.b1_4_4_transition = transition_conv(l4c+l5c, l4c)
self.b1_4_5 = DoubleBasicBlock(l4c, l4c)
self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b1_4_7_up = up_conv(l4c, l3c)
# branch1_layer5
self.b1_5_1 = DoubleBasicBlock(l5c, l5c)
self.b1_5_2_same = same_conv(l5c, l5c)
self.b1_5_3_transition = transition_conv(l5c+l5c, l5c)
self.b1_5_4 = DoubleBasicBlock(l5c, l5c)
self.b1_5_5_up = up_conv(l5c, l4c)
# branch2
# branch2_layer1
self.b2_1_1 = nn.Sequential(
conv3x3(1, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b2_1_2_down = down_conv(l1c, l2c)
self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch2_layer2
self.b2_2_1 = DoubleBasicBlock(l2c, l2c)
self.b2_2_2_down = down_conv(l2c, l3c)
self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b2_2_4_up = up_conv(l2c, l1c)
# branch2_layer3
self.b2_3_1 = DoubleBasicBlock(l3c, l3c)
self.b2_3_2_down = down_conv(l3c, l4c)
self.b2_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b2_3_4_up = up_conv(l3c, l2c)
# branch2_layer4
self.b2_4_1 = DoubleBasicBlock(l4c, l4c)
self.b2_4_2_down = down_conv(l4c, l5c)
self.b2_4_3 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b2_4_4_up = up_conv(l4c, l3c)
# branch2_layer5
self.b2_5_1 = DoubleBasicBlock(l5c, l5c)
self.b2_5_2_up = up_conv(l5c, l5c)
self.b2_5_2_same = same_conv(l5c, l5c)
self.b2_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)
self.b2_5_4 = DoubleBasicBlock(l5c, l5c)
self.b2_5_5_up = up_conv(l5c, l4c)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABNSync):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABN):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, input1, input2):
# encode
# branch1
x1_1 = self.b1_1_1(input1)
x1_2 = self.b1_1_2_down(x1_1)
x1_2 = self.b1_2_1(x1_2)
x1_3 = self.b1_2_2_down(x1_2)
x1_3 = self.b1_3_1(x1_3)
x1_4_1 = self.b1_3_2_down(x1_3)
x1_4_1 = self.b1_4_1(x1_4_1)
x1_4_2 = self.b1_4_2(x1_4_1)
x1_4_3_down = self.b1_4_3_down(x1_4_2)
x1_4_3_same = self.b1_4_3_same(x1_4_2)
x1_5_1 = self.b1_4_2_down(x1_4_1)
x1_5_1 = self.b1_5_1(x1_5_1)
x1_5_2_same = self.b1_5_2_same(x1_5_1)
# branch2
x2_1 = self.b2_1_1(input2)
x2_2 = self.b2_1_2_down(x2_1)
x2_2 = self.b2_2_1(x2_2)
x2_3 = self.b2_2_2_down(x2_2)
x2_3 = self.b2_3_1(x2_3)
x2_4 = self.b2_3_2_down(x2_3)
x2_4 = self.b2_4_1(x2_4)
x2_5_1 = self.b2_4_2_down(x2_4)
x2_5_1 = self.b2_5_1(x2_5_1)
x2_5_2_up = self.b2_5_2_up(x2_5_1)
x2_5_2_same = self.b2_5_2_same(x2_5_1)
# merge
# branch1
x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same), dim=1)
x1_5_3 = self.b1_5_3_transition(x1_5_3)
x1_5_3 = self.b1_5_4(x1_5_3)
x1_5_3 = self.b1_5_5_up(x1_5_3)
x1_4_4 = torch.cat((x1_4_3_same, x2_5_2_up), dim=1)
x1_4_4 = self.b1_4_4_transition(x1_4_4)
x1_4_4 = self.b1_4_5(x1_4_4)
x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)
x1_4_4 = self.b1_4_6(x1_4_4)
x1_4_4 = self.b1_4_7_up(x1_4_4)
# branch2
x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same, x1_4_3_down), dim=1)
x2_5_3 = self.b2_5_3_transition(x2_5_3)
x2_5_3 = self.b2_5_4(x2_5_3)
x2_5_3 = self.b2_5_5_up(x2_5_3)
# decode
# branch1
x1_3 = torch.cat((x1_3, x1_4_4), dim=1)
x1_3 = self.b1_3_3(x1_3)
x1_3 = self.b1_3_4_up(x1_3)
x1_2 = torch.cat((x1_2, x1_3), dim=1)
x1_2 = self.b1_2_3(x1_2)
x1_2 = self.b1_2_4_up(x1_2)
x1_1 = torch.cat((x1_1, x1_2), dim=1)
x1_1 = self.b1_1_3(x1_1)
x1_1 = self.b1_1_4(x1_1)
# branch2
x2_4 = torch.cat((x2_4, x2_5_3), dim=1)
x2_4 = self.b2_4_3(x2_4)
x2_4 = self.b2_4_4_up(x2_4)
x2_3 = torch.cat((x2_3, x2_4), dim=1)
x2_3 = self.b2_3_3(x2_3)
x2_3 = self.b2_3_4_up(x2_3)
x2_2 = torch.cat((x2_2, x2_3), dim=1)
x2_2 = self.b2_2_3(x2_2)
x2_2 = self.b2_2_4_up(x2_2)
x2_1 = torch.cat((x2_1, x2_2), dim=1)
x2_1 = self.b2_1_3(x2_1)
x2_1 = self.b2_1_4(x2_1)
return x1_1, x2_1
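# XNet_2_3_m: branch 1 fuses at levels 4 and 5, additionally receiving branch 2's
# level-3 features downsampled once (at level 4) and twice (at level 5).
# Branch 2 fuses at levels 3, 4 and 5; at level 3 (1/4 resolution) it receives
# branch 1's level-4 features upsampled once and level-5 features upsampled twice.
class XNet_2_3_m(nn.Module):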
def __init__(self, in_channels, num_classes):
super(XNet_2_3_m, self).__init__()
l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024
# branch1
# branch1_layer1
self.b1_1_1 = nn.Sequential(
conv3x3(in_channels, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b1_1_2_down = down_conv(l1c, l2c)
self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch1_layer2
self.b1_2_1 = DoubleBasicBlock(l2c, l2c)
self.b1_2_2_down = down_conv(l2c, l3c)
self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b1_2_4_up = up_conv(l2c, l1c)
# branch1_layer3
self.b1_3_1 = DoubleBasicBlock(l3c, l3c)
self.b1_3_2_down = down_conv(l3c, l4c)
self.b1_3_3 = DoubleBasicBlock(l3c + l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c + l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b1_3_4_up = up_conv(l3c, l2c)
# branch1_layer4
self.b1_4_1 = DoubleBasicBlock(l4c, l4c)
self.b1_4_2_down = down_conv(l4c, l5c)
self.b1_4_2 = DoubleBasicBlock(l4c, l4c)
self.b1_4_3_down = down_conv(l4c, l4c)
self.b1_4_3_same = same_conv(l4c, l4c)
self.b1_4_3_up = up_conv(l4c, l4c)
self.b1_4_4_transition = transition_conv(l4c+l5c+l4c+l3c, l4c)
self.b1_4_5 = DoubleBasicBlock(l4c, l4c)
self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b1_4_7_up = up_conv(l4c, l3c)
# branch1_layer5
self.b1_5_1 = DoubleBasicBlock(l5c, l5c)
self.b1_5_2_up = up_conv(l5c, l5c)
self.b1_5_2_up_up = up_conv(l5c, l5c)
self.b1_5_2_same = same_conv(l5c, l5c)
self.b1_5_3_transition = transition_conv(l5c+l5c+l4c+l3c, l5c)
self.b1_5_4 = DoubleBasicBlock(l5c, l5c)
self.b1_5_5_up = up_conv(l5c, l4c)
# branch2
# branch2_layer1
self.b2_1_1 = nn.Sequential(
conv3x3(1, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b2_1_2_down = down_conv(l1c, l2c)
self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch2_layer2
self.b2_2_1 = DoubleBasicBlock(l2c, l2c)
self.b2_2_2_down = down_conv(l2c, l3c)
self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b2_2_4_up = up_conv(l2c, l1c)
# branch2_layer3
self.b2_3_1 = DoubleBasicBlock(l3c, l3c)
self.b2_3_2_down = down_conv(l3c, l4c)
self.b2_3_2 = DoubleBasicBlock(l3c, l3c)
self.b2_3_3_down = down_conv(l3c, l3c)
self.b2_3_3_down_down = down_conv(l3c, l3c)
self.b2_3_3_same = same_conv(l3c, l3c)
self.b2_3_4_transition = transition_conv(l3c+l5c+l4c, l3c)
self.b2_3_5 = DoubleBasicBlock(l3c, l3c)
self.b2_3_6 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b2_3_7_up = up_conv(l3c, l2c)
# branch2_layer4
self.b2_4_1 = DoubleBasicBlock(l4c, l4c)
self.b2_4_2_down = down_conv(l4c, l5c)
self.b2_4_2 = DoubleBasicBlock(l4c, l4c)
self.b2_4_3_down = down_conv(l4c, l4c)
self.b2_4_3_same = same_conv(l4c, l4c)
self.b2_4_4_transition = transition_conv(l4c+l5c+l4c, l4c)
self.b2_4_5 = DoubleBasicBlock(l4c, l4c)
self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b2_4_7_up = up_conv(l4c, l3c)
# branch2_layer5
self.b2_5_1 = DoubleBasicBlock(l5c, l5c)
self.b2_5_2_up = up_conv(l5c, l5c)
self.b2_5_2_same = same_conv(l5c, l5c)
self.b2_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)
self.b2_5_4 = DoubleBasicBlock(l5c, l5c)
self.b2_5_5_up = up_conv(l5c, l4c)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABNSync):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABN):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, input1, input2):
# encode
# branch1
x1_1 = self.b1_1_1(input1)
x1_2 = self.b1_1_2_down(x1_1)
x1_2 = self.b1_2_1(x1_2)
x1_3 = self.b1_2_2_down(x1_2)
x1_3 = self.b1_3_1(x1_3)
x1_4_1 = self.b1_3_2_down(x1_3)
x1_4_1 = self.b1_4_1(x1_4_1)
x1_4_2 = self.b1_4_2(x1_4_1)
x1_4_3_down = self.b1_4_3_down(x1_4_2)
x1_4_3_same = self.b1_4_3_same(x1_4_2)
x1_4_3_up = self.b1_4_3_up(x1_4_2)
x1_5_1 = self.b1_4_2_down(x1_4_1)
x1_5_1 = self.b1_5_1(x1_5_1)
x1_5_2_up = self.b1_5_2_up(x1_5_1)
x1_5_2_up_up = self.b1_5_2_up_up(x1_5_2_up)
x1_5_2_same = self.b1_5_2_same(x1_5_1)
# branch2
x2_1 = self.b2_1_1(input2)
x2_2 = self.b2_1_2_down(x2_1)
x2_2 = self.b2_2_1(x2_2)
x2_3_1 = self.b2_2_2_down(x2_2)
x2_3_1 = self.b2_3_1(x2_3_1)
x2_3_2 = self.b2_3_2(x2_3_1)
x2_3_3_down = self.b2_3_3_down(x2_3_2)
x2_3_3_down_down = self.b2_3_3_down_down(x2_3_3_down)
x2_3_3_same = self.b2_3_3_same(x2_3_2)
x2_4_1 = self.b2_3_2_down(x2_3_1)
x2_4_1 = self.b2_4_1(x2_4_1)
x2_4_2 = self.b2_4_2(x2_4_1)
x2_4_3_down = self.b2_4_3_down(x2_4_2)
x2_4_3_same = self.b2_4_3_same(x2_4_2)
x2_5_1 = self.b2_4_2_down(x2_4_1)
x2_5_1 = self.b2_5_1(x2_5_1)
x2_5_2_up = self.b2_5_2_up(x2_5_1)
x2_5_2_same = self.b2_5_2_same(x2_5_1)
# merge
# branch1
x1_5_3 = torch.cat((x1_5_2_same, x2_3_3_down_down, x2_4_3_down, x2_5_2_same), dim=1)
x1_5_3 = self.b1_5_3_transition(x1_5_3)
x1_5_3 = self.b1_5_4(x1_5_3)
x1_5_3 = self.b1_5_5_up(x1_5_3)
x1_4_4 = torch.cat((x1_4_3_same, x2_3_3_down, x2_4_3_same, x2_5_2_up), dim=1)
x1_4_4 = self.b1_4_4_transition(x1_4_4)
x1_4_4 = self.b1_4_5(x1_4_4)
x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)
x1_4_4 = self.b1_4_6(x1_4_4)
x1_4_4 = self.b1_4_7_up(x1_4_4)
# branch2
x2_5_3 = torch.cat((x2_5_2_same, x1_4_3_down, x1_5_2_same), dim=1)
x2_5_3 = self.b2_5_3_transition(x2_5_3)
x2_5_3 = self.b2_5_4(x2_5_3)
x2_5_3 = self.b2_5_5_up(x2_5_3)
x2_4_4 = torch.cat((x2_4_3_same, x1_4_3_same, x1_5_2_up), dim=1)
x2_4_4 = self.b2_4_4_transition(x2_4_4)
x2_4_4 = self.b2_4_5(x2_4_4)
x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)
x2_4_4 = self.b2_4_6(x2_4_4)
x2_4_4 = self.b2_4_7_up(x2_4_4)
x2_3_4 = torch.cat((x2_3_3_same, x1_4_3_up, x1_5_2_up_up), dim=1)
x2_3_4 = self.b2_3_4_transition(x2_3_4)
x2_3_4 = self.b2_3_5(x2_3_4)
x2_3_4 = torch.cat((x2_3_4, x2_4_4), dim=1)
x2_3_4 = self.b2_3_6(x2_3_4)
x2_3_4 = self.b2_3_7_up(x2_3_4)
# decode
# branch1
x1_3 = torch.cat((x1_3, x1_4_4), dim=1)
x1_3 = self.b1_3_3(x1_3)
x1_3 = self.b1_3_4_up(x1_3)
x1_2 = torch.cat((x1_2, x1_3), dim=1)
x1_2 = self.b1_2_3(x1_2)
x1_2 = self.b1_2_4_up(x1_2)
x1_1 = torch.cat((x1_1, x1_2), dim=1)
x1_1 = self.b1_1_3(x1_1)
x1_1 = self.b1_1_4(x1_1)
# branch2
x2_2 = torch.cat((x2_2, x2_3_4), dim=1)
x2_2 = self.b2_2_3(x2_2)
x2_2 = self.b2_2_4_up(x2_2)
x2_1 = torch.cat((x2_1, x2_2), dim=1)
x2_1 = self.b2_1_3(x2_1)
x2_1 = self.b2_1_4(x2_1)
return x1_1, x2_1
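Every tensor that crosses between the two branches is resampled to the resolution of the level it is merged into. The suffix says how: `_same` keeps the level, `_down`/`_down_down` move one or two levels deeper, and `_up`/`_up_up` move one or two levels shallower. The sketch below tracks the side length per level and checks one merge from the class above. It assumes, as the names suggest but this excerpt does not show, that `down_conv` halves and `up_conv` doubles the spatial size, and that `same_conv` preserves it. The helper names are hypothetical.

```python
def level_size(input_size: int, level: int) -> int:
    """Side length at encoder level `level` (1 = full resolution), assuming each down_conv halves it."""
    return input_size >> (level - 1)


def resample(size: int, suffix: str) -> int:
    """Apply the resampling implied by a tensor-name suffix (assumed semantics)."""
    shift = {"same": 0, "down": 1, "down_down": 2, "up": -1, "up_up": -2}[suffix]
    return size >> shift if shift >= 0 else size << -shift


def check_merge(input_size, target_level, inputs):
    """inputs: list of (source_level, suffix); every tensor must land at the target level's size."""
    target = level_size(input_size, target_level)
    return all(resample(level_size(input_size, lvl), sfx) == target for lvl, sfx in inputs)


# Inputs to torch.cat for x1_5_3 in the class above:
# x1_5_2_same (level 5), x2_3_3_down_down (level 3), x2_4_3_down (level 4), x2_5_2_same (level 5)
x1_5_3_inputs = [(5, "same"), (3, "down_down"), (4, "down"), (5, "same")]
print(level_size(256, 5))                  # 16
print(check_merge(256, 5, x1_5_3_inputs))  # True
```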
class XNet_3_2_m(nn.Module):
def __init__(self, in_channels, num_classes):
super(XNet_3_2_m, self).__init__()
l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024
# branch1
# branch1_layer1
self.b1_1_1 = nn.Sequential(
conv3x3(in_channels, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b1_1_2_down = down_conv(l1c, l2c)
self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch1_layer2
self.b1_2_1 = DoubleBasicBlock(l2c, l2c)
self.b1_2_2_down = down_conv(l2c, l3c)
self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b1_2_4_up = up_conv(l2c, l1c)
# branch1_layer3
self.b1_3_1 = DoubleBasicBlock(l3c, l3c)
self.b1_3_2_down = down_conv(l3c, l4c)
self.b1_3_2 = DoubleBasicBlock(l3c, l3c)
self.b1_3_3_down = down_conv(l3c, l3c)
self.b1_3_3_down_down = down_conv(l3c, l3c)
self.b1_3_3_same = same_conv(l3c, l3c)
self.b1_3_4_transition = transition_conv(l3c+l5c+l4c, l3c)
self.b1_3_5 = DoubleBasicBlock(l3c, l3c)
self.b1_3_6 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b1_3_7_up = up_conv(l3c, l2c)
# branch1_layer4
self.b1_4_1 = DoubleBasicBlock(l4c, l4c)
self.b1_4_2_down = down_conv(l4c, l5c)
self.b1_4_2 = DoubleBasicBlock(l4c, l4c)
self.b1_4_3_down = down_conv(l4c, l4c)
self.b1_4_3_same = same_conv(l4c, l4c)
self.b1_4_4_transition = transition_conv(l4c+l5c+l4c, l4c)
self.b1_4_5 = DoubleBasicBlock(l4c, l4c)
self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b1_4_7_up = up_conv(l4c, l3c)
# branch1_layer5
self.b1_5_1 = DoubleBasicBlock(l5c, l5c)
self.b1_5_2_up = up_conv(l5c, l5c)
self.b1_5_2_same = same_conv(l5c, l5c)
self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)
self.b1_5_4 = DoubleBasicBlock(l5c, l5c)
self.b1_5_5_up = up_conv(l5c, l4c)
# branch2
# branch2_layer1
self.b2_1_1 = nn.Sequential(
conv3x3(1, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b2_1_2_down = down_conv(l1c, l2c)
self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch2_layer2
self.b2_2_1 = DoubleBasicBlock(l2c, l2c)
self.b2_2_2_down = down_conv(l2c, l3c)
self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b2_2_4_up = up_conv(l2c, l1c)
# branch2_layer3
self.b2_3_1 = DoubleBasicBlock(l3c, l3c)
self.b2_3_2_down = down_conv(l3c, l4c)
self.b2_3_3 = DoubleBasicBlock(l3c + l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c + l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b2_3_4_up = up_conv(l3c, l2c)
# branch2_layer4
self.b2_4_1 = DoubleBasicBlock(l4c, l4c)
self.b2_4_2_down = down_conv(l4c, l5c)
self.b2_4_2 = DoubleBasicBlock(l4c, l4c)
self.b2_4_3_down = down_conv(l4c, l4c)
self.b2_4_3_same = same_conv(l4c, l4c)
self.b2_4_3_up = up_conv(l4c, l4c)
self.b2_4_4_transition = transition_conv(l4c+l5c+l4c+l3c, l4c)
self.b2_4_5 = DoubleBasicBlock(l4c, l4c)
self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b2_4_7_up = up_conv(l4c, l3c)
# branch2_layer5
self.b2_5_1 = DoubleBasicBlock(l5c, l5c)
self.b2_5_2_up = up_conv(l5c, l5c)
self.b2_5_2_up_up = up_conv(l5c, l5c)
self.b2_5_2_same = same_conv(l5c, l5c)
self.b2_5_3_transition = transition_conv(l5c+l5c+l4c+l3c, l5c)
self.b2_5_4 = DoubleBasicBlock(l5c, l5c)
self.b2_5_5_up = up_conv(l5c, l4c)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABNSync):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABN):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, input1, input2):
# encode
# branch1
x1_1 = self.b1_1_1(input1)
x1_2 = self.b1_1_2_down(x1_1)
x1_2 = self.b1_2_1(x1_2)
x1_3_1 = self.b1_2_2_down(x1_2)
x1_3_1 = self.b1_3_1(x1_3_1)
x1_3_2 = self.b1_3_2(x1_3_1)
x1_3_3_down = self.b1_3_3_down(x1_3_2)
x1_3_3_down_down = self.b1_3_3_down_down(x1_3_3_down)
x1_3_3_same = self.b1_3_3_same(x1_3_2)
x1_4_1 = self.b1_3_2_down(x1_3_1)
x1_4_1 = self.b1_4_1(x1_4_1)
x1_4_2 = self.b1_4_2(x1_4_1)
x1_4_3_down = self.b1_4_3_down(x1_4_2)
x1_4_3_same = self.b1_4_3_same(x1_4_2)
x1_5_1 = self.b1_4_2_down(x1_4_1)
x1_5_1 = self.b1_5_1(x1_5_1)
x1_5_2_up = self.b1_5_2_up(x1_5_1)
x1_5_2_same = self.b1_5_2_same(x1_5_1)
# branch2
x2_1 = self.b2_1_1(input2)
x2_2 = self.b2_1_2_down(x2_1)
x2_2 = self.b2_2_1(x2_2)
x2_3 = self.b2_2_2_down(x2_2)
x2_3 = self.b2_3_1(x2_3)
x2_4_1 = self.b2_3_2_down(x2_3)
x2_4_1 = self.b2_4_1(x2_4_1)
x2_4_2 = self.b2_4_2(x2_4_1)
x2_4_3_down = self.b2_4_3_down(x2_4_2)
x2_4_3_same = self.b2_4_3_same(x2_4_2)
x2_4_3_up = self.b2_4_3_up(x2_4_2)
x2_5_1 = self.b2_4_2_down(x2_4_1)
x2_5_1 = self.b2_5_1(x2_5_1)
x2_5_2_up = self.b2_5_2_up(x2_5_1)
x2_5_2_up_up = self.b2_5_2_up_up(x2_5_2_up)
x2_5_2_same = self.b2_5_2_same(x2_5_1)
# merge
# branch1
x1_5_3 = torch.cat((x1_5_2_same, x2_4_3_down, x2_5_2_same), dim=1)
x1_5_3 = self.b1_5_3_transition(x1_5_3)
x1_5_3 = self.b1_5_4(x1_5_3)
x1_5_3 = self.b1_5_5_up(x1_5_3)
x1_4_4 = torch.cat((x1_4_3_same, x2_4_3_same, x2_5_2_up), dim=1)
x1_4_4 = self.b1_4_4_transition(x1_4_4)
x1_4_4 = self.b1_4_5(x1_4_4)
x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)
x1_4_4 = self.b1_4_6(x1_4_4)
x1_4_4 = self.b1_4_7_up(x1_4_4)
x1_3_4 = torch.cat((x1_3_3_same, x2_4_3_up, x2_5_2_up_up), dim=1)
x1_3_4 = self.b1_3_4_transition(x1_3_4)
x1_3_4 = self.b1_3_5(x1_3_4)
x1_3_4 = torch.cat((x1_3_4, x1_4_4), dim=1)
x1_3_4 = self.b1_3_6(x1_3_4)
x1_3_4 = self.b1_3_7_up(x1_3_4)
# branch2
x2_5_3 = torch.cat((x2_5_2_same, x1_3_3_down_down, x1_4_3_down, x1_5_2_same), dim=1)
x2_5_3 = self.b2_5_3_transition(x2_5_3)
x2_5_3 = self.b2_5_4(x2_5_3)
x2_5_3 = self.b2_5_5_up(x2_5_3)
x2_4_4 = torch.cat((x2_4_3_same, x1_3_3_down, x1_4_3_same, x1_5_2_up), dim=1)
x2_4_4 = self.b2_4_4_transition(x2_4_4)
x2_4_4 = self.b2_4_5(x2_4_4)
x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)
x2_4_4 = self.b2_4_6(x2_4_4)
x2_4_4 = self.b2_4_7_up(x2_4_4)
# decode
# branch1
x1_2 = torch.cat((x1_2, x1_3_4), dim=1)
x1_2 = self.b1_2_3(x1_2)
x1_2 = self.b1_2_4_up(x1_2)
x1_1 = torch.cat((x1_1, x1_2), dim=1)
x1_1 = self.b1_1_3(x1_1)
x1_1 = self.b1_1_4(x1_1)
# branch2
x2_3 = torch.cat((x2_3, x2_4_4), dim=1)
x2_3 = self.b2_3_3(x2_3)
x2_3 = self.b2_3_4_up(x2_3)
x2_2 = torch.cat((x2_2, x2_3), dim=1)
x2_2 = self.b2_2_3(x2_2)
x2_2 = self.b2_2_4_up(x2_2)
x2_1 = torch.cat((x2_1, x2_2), dim=1)
x2_1 = self.b2_1_3(x2_1)
x2_1 = self.b2_1_4(x2_1)
return x1_1, x2_1
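Each `transition_conv(in, out)` has to be declared with the sum of the channel counts of the tensors concatenated in `forward`. A small standalone check like the following (hypothetical, not part of the repository) cross-checks those sums for `XNet_3_2_m`. The widths come from the class, and the tensor-to-width table is read off the `torch.cat` calls and the conv layers that produce each tensor.

```python
# Channel widths used by XNet_3_2_m.
l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024

# Width of each tensor that enters a merge (read off the producing down/same/up convs).
channels = {
    "x1_3_3_same": l3c, "x1_3_3_down": l3c, "x1_3_3_down_down": l3c,
    "x1_4_3_same": l4c, "x1_4_3_down": l4c,
    "x1_5_2_same": l5c, "x1_5_2_up": l5c,
    "x2_4_3_same": l4c, "x2_4_3_down": l4c, "x2_4_3_up": l4c,
    "x2_5_2_same": l5c, "x2_5_2_up": l5c, "x2_5_2_up_up": l5c,
}

# Declared transition_conv input width, paired with the cat() inputs used in forward.
merges = {
    "b1_3_4_transition": (l3c + l5c + l4c, ["x1_3_3_same", "x2_4_3_up", "x2_5_2_up_up"]),
    "b1_4_4_transition": (l4c + l5c + l4c, ["x1_4_3_same", "x2_4_3_same", "x2_5_2_up"]),
    "b1_5_3_transition": (l5c + l5c + l4c, ["x1_5_2_same", "x2_4_3_down", "x2_5_2_same"]),
    "b2_4_4_transition": (l4c + l5c + l4c + l3c,
                          ["x2_4_3_same", "x1_3_3_down", "x1_4_3_same", "x1_5_2_up"]),
    "b2_5_3_transition": (l5c + l5c + l4c + l3c,
                          ["x2_5_2_same", "x1_3_3_down_down", "x1_4_3_down", "x1_5_2_same"]),
}


def mismatches(merges, channels):
    """Return the transition layers whose declared width differs from the concatenated width."""
    bad = {}
    for name, (declared, inputs) in merges.items():
        actual = sum(channels[t] for t in inputs)
        if actual != declared:
            bad[name] = (declared, actual)
    return bad


print(mismatches(merges, channels))  # {}
```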
class XNet_3_3_m(nn.Module):
def __init__(self, in_channels, num_classes):
super(XNet_3_3_m, self).__init__()
l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024
# branch1
# branch1_layer1
self.b1_1_1 = nn.Sequential(
conv3x3(in_channels, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b1_1_2_down = down_conv(l1c, l2c)
self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch1_layer2
self.b1_2_1 = DoubleBasicBlock(l2c, l2c)
self.b1_2_2_down = down_conv(l2c, l3c)
self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b1_2_4_up = up_conv(l2c, l1c)
# branch1_layer3
self.b1_3_1 = DoubleBasicBlock(l3c, l3c)
self.b1_3_2_down = down_conv(l3c, l4c)
self.b1_3_2 = DoubleBasicBlock(l3c, l3c)
self.b1_3_3_down = down_conv(l3c, l3c)
self.b1_3_3_down_down = down_conv(l3c, l3c)
self.b1_3_3_same = same_conv(l3c, l3c)
self.b1_3_4_transition = transition_conv(l3c+l5c+l4c+l3c, l3c)
self.b1_3_5 = DoubleBasicBlock(l3c, l3c)
self.b1_3_6 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b1_3_7_up = up_conv(l3c, l2c)
# branch1_layer4
self.b1_4_1 = DoubleBasicBlock(l4c, l4c)
self.b1_4_2_down = down_conv(l4c, l5c)
self.b1_4_2 = DoubleBasicBlock(l4c, l4c)
self.b1_4_3_down = down_conv(l4c, l4c)
self.b1_4_3_same = same_conv(l4c, l4c)
self.b1_4_3_up = up_conv(l4c, l4c)
self.b1_4_4_transition = transition_conv(l4c+l5c+l4c+l3c, l4c)
self.b1_4_5 = DoubleBasicBlock(l4c, l4c)
self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b1_4_7_up = up_conv(l4c, l3c)
# branch1_layer5
self.b1_5_1 = DoubleBasicBlock(l5c, l5c)
self.b1_5_2_up = up_conv(l5c, l5c)
self.b1_5_2_up_up = up_conv(l5c, l5c)
self.b1_5_2_same = same_conv(l5c, l5c)
self.b1_5_3_transition = transition_conv(l5c+l5c+l4c+l3c, l5c)
self.b1_5_4 = DoubleBasicBlock(l5c, l5c)
self.b1_5_5_up = up_conv(l5c, l4c)
# branch2
# branch2_layer1
self.b2_1_1 = nn.Sequential(
conv3x3(1, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b2_1_2_down = down_conv(l1c, l2c)
self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch2_layer2
self.b2_2_1 = DoubleBasicBlock(l2c, l2c)
self.b2_2_2_down = down_conv(l2c, l3c)
self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b2_2_4_up = up_conv(l2c, l1c)
# branch2_layer3
self.b2_3_1 = DoubleBasicBlock(l3c, l3c)
self.b2_3_2_down = down_conv(l3c, l4c)
self.b2_3_2 = DoubleBasicBlock(l3c, l3c)
self.b2_3_3_down = down_conv(l3c, l3c)
self.b2_3_3_down_down = down_conv(l3c, l3c)
self.b2_3_3_same = same_conv(l3c, l3c)
self.b2_3_4_transition = transition_conv(l3c+l5c+l4c+l3c, l3c)
self.b2_3_5 = DoubleBasicBlock(l3c, l3c)
self.b2_3_6 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b2_3_7_up = up_conv(l3c, l2c)
# branch2_layer4
self.b2_4_1 = DoubleBasicBlock(l4c, l4c)
self.b2_4_2_down = down_conv(l4c, l5c)
self.b2_4_2 = DoubleBasicBlock(l4c, l4c)
self.b2_4_3_down = down_conv(l4c, l4c)
self.b2_4_3_same = same_conv(l4c, l4c)
self.b2_4_3_up = up_conv(l4c, l4c)
self.b2_4_4_transition = transition_conv(l4c+l5c+l4c+l3c, l4c)
self.b2_4_5 = DoubleBasicBlock(l4c, l4c)
self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b2_4_7_up = up_conv(l4c, l3c)
# branch2_layer5
self.b2_5_1 = DoubleBasicBlock(l5c, l5c)
self.b2_5_2_up = up_conv(l5c, l5c)
self.b2_5_2_up_up = up_conv(l5c, l5c)
self.b2_5_2_same = same_conv(l5c, l5c)
self.b2_5_3_transition = transition_conv(l5c+l5c+l4c+l3c, l5c)
self.b2_5_4 = DoubleBasicBlock(l5c, l5c)
self.b2_5_5_up = up_conv(l5c, l4c)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABNSync):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABN):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, input1, input2):
# encode
# branch1
x1_1 = self.b1_1_1(input1)
x1_2 = self.b1_1_2_down(x1_1)
x1_2 = self.b1_2_1(x1_2)
x1_3_1 = self.b1_2_2_down(x1_2)
x1_3_1 = self.b1_3_1(x1_3_1)
x1_3_2 = self.b1_3_2(x1_3_1)
x1_3_3_down = self.b1_3_3_down(x1_3_2)
x1_3_3_down_down = self.b1_3_3_down_down(x1_3_3_down)
x1_3_3_same = self.b1_3_3_same(x1_3_2)
x1_4_1 = self.b1_3_2_down(x1_3_1)
x1_4_1 = self.b1_4_1(x1_4_1)
x1_4_2 = self.b1_4_2(x1_4_1)
x1_4_3_down = self.b1_4_3_down(x1_4_2)
x1_4_3_same = self.b1_4_3_same(x1_4_2)
x1_4_3_up = self.b1_4_3_up(x1_4_2)
x1_5_1 = self.b1_4_2_down(x1_4_1)
x1_5_1 = self.b1_5_1(x1_5_1)
x1_5_2_up = self.b1_5_2_up(x1_5_1)
x1_5_2_up_up = self.b1_5_2_up_up(x1_5_2_up)
x1_5_2_same = self.b1_5_2_same(x1_5_1)
# branch2
x2_1 = self.b2_1_1(input2)
x2_2 = self.b2_1_2_down(x2_1)
x2_2 = self.b2_2_1(x2_2)
x2_3_1 = self.b2_2_2_down(x2_2)
x2_3_1 = self.b2_3_1(x2_3_1)
x2_3_2 = self.b2_3_2(x2_3_1)
x2_3_3_down = self.b2_3_3_down(x2_3_2)
x2_3_3_down_down = self.b2_3_3_down_down(x2_3_3_down)
x2_3_3_same = self.b2_3_3_same(x2_3_2)
x2_4_1 = self.b2_3_2_down(x2_3_1)
x2_4_1 = self.b2_4_1(x2_4_1)
x2_4_2 = self.b2_4_2(x2_4_1)
x2_4_3_down = self.b2_4_3_down(x2_4_2)
x2_4_3_same = self.b2_4_3_same(x2_4_2)
x2_4_3_up = self.b2_4_3_up(x2_4_2)
x2_5_1 = self.b2_4_2_down(x2_4_1)
x2_5_1 = self.b2_5_1(x2_5_1)
x2_5_2_up = self.b2_5_2_up(x2_5_1)
x2_5_2_up_up = self.b2_5_2_up_up(x2_5_2_up)
x2_5_2_same = self.b2_5_2_same(x2_5_1)
# merge
# branch1
x1_5_3 = torch.cat((x1_5_2_same, x2_3_3_down_down, x2_4_3_down, x2_5_2_same), dim=1)
x1_5_3 = self.b1_5_3_transition(x1_5_3)
x1_5_3 = self.b1_5_4(x1_5_3)
x1_5_3 = self.b1_5_5_up(x1_5_3)
x1_4_4 = torch.cat((x1_4_3_same, x2_3_3_down, x2_4_3_same, x2_5_2_up), dim=1)
x1_4_4 = self.b1_4_4_transition(x1_4_4)
x1_4_4 = self.b1_4_5(x1_4_4)
x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)
x1_4_4 = self.b1_4_6(x1_4_4)
x1_4_4 = self.b1_4_7_up(x1_4_4)
x1_3_4 = torch.cat((x1_3_3_same, x2_3_3_same, x2_4_3_up, x2_5_2_up_up), dim=1)
x1_3_4 = self.b1_3_4_transition(x1_3_4)
x1_3_4 = self.b1_3_5(x1_3_4)
x1_3_4 = torch.cat((x1_3_4, x1_4_4), dim=1)
x1_3_4 = self.b1_3_6(x1_3_4)
x1_3_4 = self.b1_3_7_up(x1_3_4)
# branch2
x2_5_3 = torch.cat((x2_5_2_same, x1_3_3_down_down, x1_4_3_down, x1_5_2_same), dim=1)
x2_5_3 = self.b2_5_3_transition(x2_5_3)
x2_5_3 = self.b2_5_4(x2_5_3)
x2_5_3 = self.b2_5_5_up(x2_5_3)
x2_4_4 = torch.cat((x2_4_3_same, x1_3_3_down, x1_4_3_same, x1_5_2_up), dim=1)
x2_4_4 = self.b2_4_4_transition(x2_4_4)
x2_4_4 = self.b2_4_5(x2_4_4)
x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)
x2_4_4 = self.b2_4_6(x2_4_4)
x2_4_4 = self.b2_4_7_up(x2_4_4)
x2_3_4 = torch.cat((x2_3_3_same, x1_3_3_same, x1_4_3_up, x1_5_2_up_up), dim=1)
x2_3_4 = self.b2_3_4_transition(x2_3_4)
x2_3_4 = self.b2_3_5(x2_3_4)
x2_3_4 = torch.cat((x2_3_4, x2_4_4), dim=1)
x2_3_4 = self.b2_3_6(x2_3_4)
x2_3_4 = self.b2_3_7_up(x2_3_4)
# decode
# branch1
x1_2 = torch.cat((x1_2, x1_3_4), dim=1)
x1_2 = self.b1_2_3(x1_2)
x1_2 = self.b1_2_4_up(x1_2)
x1_1 = torch.cat((x1_1, x1_2), dim=1)
x1_1 = self.b1_1_3(x1_1)
x1_1 = self.b1_1_4(x1_1)
# branch2
x2_2 = torch.cat((x2_2, x2_3_4), dim=1)
x2_2 = self.b2_2_3(x2_2)
x2_2 = self.b2_2_4_up(x2_2)
x2_1 = torch.cat((x2_1, x2_2), dim=1)
x2_1 = self.b2_1_3(x2_1)
x2_1 = self.b2_1_4(x2_1)
return x1_1, x2_1
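The XNet variants in this section share one initialization loop. Every `nn.Conv2d` gets Kaiming-normal weights with `mode='fan_out', nonlinearity='relu'`, and every `BatchNorm2d` starts at weight 1 and bias 0. For a Conv2d weight of shape (out, in, kH, kW), fan_out is out * kH * kW and the ReLU gain is sqrt(2), so the sampling std is sqrt(2 / fan_out). The helper below computes that value. It assumes `conv3x3` (defined earlier in the file, not shown here) builds an `nn.Conv2d` with a 3x3 kernel.

```python
import math


def kaiming_normal_std_fan_out(out_channels: int, kernel_h: int, kernel_w: int) -> float:
    """Std used by kaiming_normal_(mode='fan_out', nonlinearity='relu') for a Conv2d weight
    of shape (out_channels, in_channels, kernel_h, kernel_w)."""
    gain = math.sqrt(2.0)                          # recommended gain for ReLU
    fan_out = out_channels * kernel_h * kernel_w   # in_channels does not enter fan_out
    return gain / math.sqrt(fan_out)


# A 3x3 conv with 64 output channels (assumed to be what conv3x3(..., l1c) builds).
print(round(kaiming_normal_std_fan_out(64, 3, 3), 4))  # 0.0589
# The 1x1 output head nn.Conv2d(l1c, num_classes) with num_classes=10.
print(round(kaiming_normal_std_fan_out(10, 1, 1), 4))  # 0.4472
```

Because fan_out ignores the input width, wide transition convs such as `transition_conv(l5c+l5c+l4c+l3c, l5c)` are initialized with the same std as a narrow conv of equal output width and kernel size.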
class XNet_sb(nn.Module):
def __init__(self, in_channels, num_classes):
super(XNet_sb, self).__init__()
l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024
# branch1
# branch1_layer1
self.b1_1_1 = nn.Sequential(
conv3x3(in_channels, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b1_1_2_down = down_conv(l1c, l2c)
self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch1_layer2
self.b1_2_1 = DoubleBasicBlock(l2c, l2c)
self.b1_2_2_down = down_conv(l2c, l3c)
self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
self.b1_2_4_up = up_conv(l2c, l1c)
# branch1_layer3
self.b1_3_1 = DoubleBasicBlock(l3c, l3c)
self.b1_3_2_down = down_conv(l3c, l4c)
self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
self.b1_3_4_up = up_conv(l3c, l2c)
# branch1_layer4
self.b1_4_1 = DoubleBasicBlock(l4c, l4c)
self.b1_4_2_down = down_conv(l4c, l5c)
self.b1_4_2 = DoubleBasicBlock(l4c, l4c)
# self.b1_4_3_down = down_conv(l4c, l4c)
# self.b1_4_3_same = same_conv(l4c, l4c)
# self.b1_4_4_transition = transition_conv(l4c, l4c)
self.b1_4_5 = DoubleBasicBlock(l4c, l4c)
self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
self.b1_4_7_up = up_conv(l4c, l3c)
# branch1_layer5
self.b1_5_1 = DoubleBasicBlock(l5c, l5c)
# self.b1_5_2_up = up_conv(l5c, l5c)
# self.b1_5_2_same = same_conv(l5c, l5c)
# self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)
self.b1_5_4 = DoubleBasicBlock(l5c, l5c)
self.b1_5_5_up = up_conv(l5c, l4c)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABNSync):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABN):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, input1):
# encode
# branch1
x1_1 = self.b1_1_1(input1)
x1_2 = self.b1_1_2_down(x1_1)
x1_2 = self.b1_2_1(x1_2)
x1_3 = self.b1_2_2_down(x1_2)
x1_3 = self.b1_3_1(x1_3)
x1_4_1 = self.b1_3_2_down(x1_3)
x1_4_1 = self.b1_4_1(x1_4_1)
x1_4_2 = self.b1_4_2(x1_4_1)
x1_4_2 = self.b1_4_5(x1_4_2)
# x1_4_3_down = self.b1_4_3_down(x1_4_2)
# x1_4_3_same = self.b1_4_3_same(x1_4_2)
x1_5_1 = self.b1_4_2_down(x1_4_1)
x1_5_1 = self.b1_5_1(x1_5_1)
x1_5_1 = self.b1_5_4(x1_5_1)
x1_5_1 = self.b1_5_5_up(x1_5_1)
# x1_5_2_up = self.b1_5_2_up(x1_5_1)
# x1_5_2_same = self.b1_5_2_same(x1_5_1)
# decode
# branch1
x1_4_2 = torch.cat((x1_4_2, x1_5_1), dim=1)
x1_4_2 = self.b1_4_6(x1_4_2)
x1_4_2 = self.b1_4_7_up(x1_4_2)
x1_3 = torch.cat((x1_3, x1_4_2), dim=1)
x1_3 = self.b1_3_3(x1_3)
x1_3 = self.b1_3_4_up(x1_3)
x1_2 = torch.cat((x1_2, x1_3), dim=1)
x1_2 = self.b1_2_3(x1_2)
x1_2 = self.b1_2_4_up(x1_2)
x1_1 = torch.cat((x1_1, x1_2), dim=1)
x1_1 = self.b1_1_3(x1_1)
x1_1 = self.b1_1_4(x1_1)
return x1_1
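`XNet_sb` keeps only branch 1 and drops all cross-branch fusion, leaving a five-level encoder-decoder. The trace below lists the channel width and side length per level for a 256x256 input, plus the width of each decoder concatenation. Each concatenation width must equal the first argument of the `DoubleBasicBlock` that consumes it. As above, it assumes `down_conv`/`up_conv` halve/double the spatial size. The function is a hypothetical helper.

```python
def trace_xnet_sb(input_size=256, widths=(64, 128, 256, 512, 1024)):
    """Return encoder (channels, side length) per level and decoder concat widths, deepest first."""
    encoder = [(c, input_size >> i) for i, c in enumerate(widths)]
    # Each decoder step concatenates the level-k skip (widths[k] channels) with the deeper
    # feature, which up_conv has already reduced to widths[k] channels.
    concat_widths = [widths[k] + widths[k] for k in range(len(widths) - 2, -1, -1)]
    return encoder, concat_widths


encoder, concat_widths = trace_xnet_sb()
print(encoder)        # [(64, 256), (128, 128), (256, 64), (512, 32), (1024, 16)]
print(concat_widths)  # [1024, 512, 256, 128] -> inputs of b1_4_6, b1_3_3, b1_2_3, b1_1_3
```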
# if __name__ == '__main__':
# model = XNet(1, 10)
# total = sum([param.nelement() for param in model.parameters()])
# from thop import profile, clever_format
#
# input = torch.randn(1, 1, 128, 128)
# flops, params = profile(model, inputs=(input, input, ))
# macs, params = clever_format([flops, params], "%.3f")
# print(macs)
# print(params)
# print(total)
# model.eval()
# input1 = torch.rand(2, 1, 256, 256)  # channel count must equal in_channels (1 for XNet(1, 10))
# input2 = torch.rand(2, 1, 256, 256)
# with torch.no_grad():
#     x1_1, x2_1 = model(input1, input2)
# output1 = x1_1.detach().cpu().numpy()
# output2 = x2_1.detach().cpu().numpy()
# print(output1.shape)
# print(output2.shape)
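The commented test counts parameters with `sum(param.nelement() for param in model.parameters())`. For the layer types used here, the counts follow simple formulas. A Conv2d holds out * in * kH * kW weights plus out biases when `bias=True`, which is the `nn.Conv2d` default. A BatchNorm2d learns only weight and bias, because its running statistics are buffers rather than parameters. The helpers below are illustrative.

```python
def conv2d_params(in_ch, out_ch, k_h, k_w, bias=True):
    """Elements in a Conv2d's weight (and bias), as counted by param.nelement()."""
    return out_ch * in_ch * k_h * k_w + (out_ch if bias else 0)


def batchnorm2d_params(num_features):
    """BatchNorm2d learns weight and bias; running_mean/running_var are buffers, not parameters."""
    return 2 * num_features


# Output head of each branch: nn.Conv2d(l1c, num_classes, kernel_size=1) with l1c=64, num_classes=10.
print(conv2d_params(64, 10, 1, 1))  # 650
print(batchnorm2d_params(64))       # 128
```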
================================================
FILE: models/networks_3d/__init__.py
================================================
================================================
FILE: models/networks_3d/conresnet.py
================================================
import torch.nn as nn
from torch.nn import functional as F
import torch
import numpy as np
from torch.nn import init
# from loss.loss_function import segmentation_loss
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
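`init_weights` selects a rule by substring-matching the module's class name. That has two consequences here. Any class whose name contains `Conv` gets the chosen `init_type`, which covers the custom weight-standardized `Conv3d` defined below and transposed convolutions. The BatchNorm rule only matches `BatchNorm2d`, so `BatchNorm3d` and the `GroupNorm` layers that ConResNet actually uses keep PyTorch's defaults. The function below mirrors just that branch selection. It assumes the module has a `weight` attribute, which the original also checks.

```python
def init_rule(classname: str) -> str:
    """Mirror the branch selection in init_weights.init_func (module assumed to have .weight)."""
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        return 'weight init by init_type, bias -> 0'
    elif classname.find('BatchNorm2d') != -1:
        return 'weight ~ N(1, gain), bias -> 0'
    return 'left at PyTorch default'


for name in ['Conv3d', 'ConvTranspose3d', 'Linear', 'BatchNorm2d', 'BatchNorm3d', 'GroupNorm']:
    print(name, '->', init_rule(name))
```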
class Conv3d(nn.Conv3d):
def __init__(self, in_channels, out_channels, kernel_size, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), groups=1, bias=False):
super(Conv3d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
def forward(self, x):
weight = self.weight
weight_mean = weight.mean(dim=(1, 2, 3, 4), keepdim=True)  # per-output-filter mean
weight = weight - weight_mean
std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1, 1)
weight = weight / std.expand_as(weight)
return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
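`Conv3d.forward` applies weight standardization to each output filter before convolving. It subtracts the filter's mean and divides by sqrt(var + 1e-12), where `torch.var` uses Bessel's correction (divides by n - 1) by default. The plain-Python sketch below performs the same arithmetic on one flattened filter. The function name is illustrative.

```python
import math


def standardize_filter(weights, eps=1e-12):
    """Weight standardization for one flattened output filter, mirroring Conv3d.forward:
    subtract the mean, then divide by sqrt(unbiased variance + eps)."""
    n = len(weights)
    mean = sum(weights) / n
    centered = [w - mean for w in weights]
    var = sum(c * c for c in centered) / (n - 1)  # unbiased, like torch.var's default
    std = math.sqrt(var + eps)
    return [c / std for c in centered]


w_hat = standardize_filter([0.5, 1.5, 2.5, 3.5])
print([round(v, 4) for v in w_hat])  # [-1.1619, -0.3873, 0.3873, 1.1619]
```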
def conv3x3x3(in_planes, out_planes, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1,1,1), dilation=(1,1,1), bias=False,
weight_std=False):
"""3x3x3 convolution with padding; returns the weight-standardized Conv3d above when weight_std=True."""
if weight_std:
return Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
bias=bias)
else:
return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias)
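The sizes that `conv3x3x3` produces follow the per-axis output formula from the `nn.Conv3d` documentation: out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1. The helper below (hypothetical name) evaluates it for the three configurations used in this file. These are the default size-preserving conv, the stride-2 downsampling convs, and the dilation-2 convs in the last ConResNet stage.

```python
def conv_out_size(size, kernel=3, stride=1, padding=1, dilation=1):
    """Output length along one axis of a 3-D convolution (formula from the nn.Conv3d docs)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


print(conv_out_size(64))                        # 64: conv3x3x3 defaults preserve size
print(conv_out_size(64, stride=2))              # 32: stride-2 convs halve it
print(conv_out_size(8, padding=2, dilation=2))  # 8: dilation 2 with padding 2 preserves size
```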
class ConResAtt(nn.Module):
def __init__(self, in_channels, in_planes, out_planes, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),
dilation=(1, 1, 1), bias=False, weight_std=False, first_layer=False):
super(ConResAtt, self).__init__()
self.weight_std = weight_std
self.stride = stride
self.in_planes = in_planes
self.out_planes = out_planes
self.first_layer = first_layer
self.relu = nn.ReLU(inplace=True)
self.gn_seg = nn.GroupNorm(8, in_planes)
self.conv_seg = conv3x3x3(in_planes, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),
stride=(stride[0], stride[1], stride[2]), padding=(padding[0], padding[1], padding[2]),
dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)
self.gn_res = nn.GroupNorm(8, out_planes)
self.conv_res = conv3x3x3(out_planes, out_planes, kernel_size=(1,1,1),
stride=(1, 1, 1), padding=(0,0,0),
dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)
self.gn_res1 = nn.GroupNorm(8, out_planes)
self.conv_res1 = conv3x3x3(out_planes, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),
stride=(1, 1, 1), padding=(padding[0], padding[1], padding[2]),
dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)
self.gn_res2 = nn.GroupNorm(8, out_planes)
self.conv_res2 = conv3x3x3(out_planes, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),
stride=(1, 1, 1), padding=(padding[0], padding[1], padding[2]),
dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)
self.gn_mp = nn.GroupNorm(8, in_planes)
self.conv_mp_first = conv3x3x3(in_channels, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),
stride=(stride[0], stride[1], stride[2]), padding=(padding[0], padding[1], padding[2]),
dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)
self.conv_mp = conv3x3x3(in_planes, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),
stride=(stride[0], stride[1], stride[2]), padding=(padding[0], padding[1], padding[2]),
dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)
def _res(self, x):  # x: (bs, channel, D, H, W); |difference| between consecutive depth slices
bs, channel, depth, height, width = x.shape
# x_copy = torch.zeros_like(x).cuda()
x_copy = torch.zeros_like(x)
x_copy[:, :, 1:, :, :] = x[:, :, 0: depth - 1, :, :]
res = x - x_copy
res[:, :, 0, :, :] = 0
res = torch.abs(res)
return res
def forward(self, input):
x1, x2 = input
if self.first_layer:
x1 = self.gn_seg(x1)
x1 = self.relu(x1)
x1 = self.conv_seg(x1)
res = torch.sigmoid(x1)
res = self._res(res)
res = self.conv_res(res)
x2 = self.conv_mp_first(x2)
x2 = x2 + res
else:
x1 = self.gn_seg(x1)
x1 = self.relu(x1)
x1 = self.conv_seg(x1)
res = torch.sigmoid(x1)
res = self._res(res)
res = self.conv_res(res)
if self.in_planes != self.out_planes:
x2 = self.gn_mp(x2)
x2 = self.relu(x2)
x2 = self.conv_mp(x2)
x2 = x2 + res
x2 = self.gn_res1(x2)
x2 = self.relu(x2)
x2 = self.conv_res1(x2)
x1 = x1*(1 + torch.sigmoid(x2))
return [x1, x2]
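`ConResAtt` does two numerical things. First, `_res` turns the sigmoid of the segmentation features into a context residual: the absolute difference between each depth slice and the previous one, with the first slice set to zero. Second, the residual branch re-weights the segmentation features as x1 * (1 + sigmoid(x2)), which can only amplify x1 by a factor between 1 and 2 and never suppresses it. The plain-Python sketch below shows both on a single channel stored as a list of depth slices. The function names are illustrative.

```python
import math


def depth_residual(volume):
    """Single-channel analogue of ConResAtt._res: |x[d] - x[d-1]| for d >= 1, zeros at d = 0.

    `volume` is a list of depth slices, each a flat list of voxel values.
    """
    res = [[0.0] * len(volume[0])]
    for prev, cur in zip(volume, volume[1:]):
        res.append([abs(c - p) for c, p in zip(cur, prev)])
    return res


def gate(x1, x2):
    """Scalar version of x1 * (1 + sigmoid(x2)): x1 is scaled by a factor in (1, 2)."""
    return x1 * (1.0 + 1.0 / (1.0 + math.exp(-x2)))


vol = [[0.0, 1.0], [0.5, 1.0], [0.5, 0.0]]
print(depth_residual(vol))  # [[0.0, 0.0], [0.5, 0.0], [0.0, 1.0]]
print(gate(2.0, 0.0))       # 3.0
```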
class NoBottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=(1, 1, 1), dilation=(1, 1, 1), downsample=None, fist_dilation=1,
multi_grid=1, weight_std=False):
super(NoBottleneck, self).__init__()
self.weight_std = weight_std
self.relu = nn.ReLU(inplace=True)
self.gn1 = nn.GroupNorm(8, inplanes)
# scale the dilation element-wise; `tuple * int` would repeat the tuple instead
eff_dilation = tuple(d * multi_grid for d in dilation)
self.conv1 = conv3x3x3(inplanes, planes, kernel_size=(3, 3, 3), stride=stride, padding=eff_dilation,
dilation=eff_dilation, bias=False, weight_std=self.weight_std)
self.gn2 = nn.GroupNorm(8, planes)
self.conv2 = conv3x3x3(planes, planes, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=eff_dilation,
dilation=eff_dilation, bias=False, weight_std=self.weight_std)
self.downsample = downsample
self.dilation = dilation
self.stride = stride
def forward(self, x):
skip = x
seg = self.gn1(x)
seg = self.relu(seg)
seg = self.conv1(seg)
seg = self.gn2(seg)
seg = self.relu(seg)
seg = self.conv2(seg)
if self.downsample is not None:
skip = self.downsample(x)
seg = seg + skip
return seg
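`NoBottleneck` combines a per-axis dilation tuple with an integer `multi_grid` factor. In Python, `*` on a tuple repeats it rather than scaling it, so `(2, 2, 2) * 2` yields a six-element tuple. That is not a valid per-axis setting for a 3-D convolution. It only goes unnoticed in this file because `multi_grid` always resolves to 1. The snippet contrasts the two operations.

```python
dilation = (2, 2, 2)
multi_grid = 2

repeated = dilation * multi_grid                  # tuple repetition, not scaling
scaled = tuple(d * multi_grid for d in dilation)  # element-wise scaling

print(repeated)  # (2, 2, 2, 2, 2, 2)
print(scaled)    # (4, 4, 4)
```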
class ConResNet(nn.Module):
def __init__(self, in_channels, num_classes, shape, block, layers, weight_std=False):
super(ConResNet, self).__init__()
self.shape = shape
self.weight_std = weight_std
self.conv_4_32 = nn.Sequential(
conv3x3x3(in_channels, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), weight_std=self.weight_std))
self.conv_32_64 = nn.Sequential(
nn.GroupNorm(8, 32),
nn.ReLU(inplace=True),
conv3x3x3(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), weight_std=self.weight_std))
self.conv_64_128 = nn.Sequential(
nn.GroupNorm(8, 64),
nn.ReLU(inplace=True),
conv3x3x3(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), weight_std=self.weight_std))
self.conv_128_256 = nn.Sequential(
nn.GroupNorm(8, 128),
nn.ReLU(inplace=True),
conv3x3x3(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), weight_std=self.weight_std))
self.layer0 = self._make_layer(block, 32, 32, layers[0], stride=(1, 1, 1))
self.layer1 = self._make_layer(block, 64, 64, layers[1], stride=(1, 1, 1))
self.layer2 = self._make_layer(block, 128, 128, layers[2], stride=(1, 1, 1))
self.layer3 = self._make_layer(block, 256, 256, layers[3], stride=(1, 1, 1))
self.layer4 = self._make_layer(block, 256, 256, layers[4], stride=(1, 1, 1), dilation=(2,2,2))
self.fusionConv = nn.Sequential(
nn.GroupNorm(8, 256),
nn.ReLU(inplace=True),
nn.Dropout3d(0.1),
conv3x3x3(256, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1), weight_std=self.weight_std)
)
self.seg_x4 = nn.Sequential(
ConResAtt(in_channels, 128, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1), weight_std=self.weight_std, first_layer=True))
self.seg_x2 = nn.Sequential(
ConResAtt(in_channels, 64, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1), weight_std=self.weight_std))
self.seg_x1 = nn.Sequential(
ConResAtt(in_channels, 32, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1), weight_std=self.weight_std))
self.seg_cls = nn.Sequential(
nn.Conv3d(32, num_classes, kernel_size=1)
)
self.res_cls = nn.Sequential(
nn.Conv3d(32, num_classes, kernel_size=1)
)
self.resx2_cls = nn.Sequential(
nn.Conv3d(32, num_classes, kernel_size=1)
)
self.resx4_cls = nn.Sequential(
nn.Conv3d(64, num_classes, kernel_size=1)
)
def _make_layer(self, block, inplanes, outplanes, blocks, stride=(1, 1, 1), dilation=(1, 1, 1), multi_grid=1):
downsample = None
if stride[0] != 1 or stride[1] != 1 or stride[2] != 1 or inplanes != outplanes:
downsample = nn.Sequential(
nn.GroupNorm(8, inplanes),
nn.ReLU(inplace=True),
conv3x3x3(inplanes, outplanes, kernel_size=(1, 1, 1), stride=stride, padding=(0, 0, 0),
weight_std=self.weight_std)
)
layers = []
generate_multi_grid = lambda index, grids: grids[index % len(grids)] if isinstance(grids, tuple) else 1
layers.append(block(inplanes, outplanes, stride, dilation=dilation, downsample=downsample,
multi_grid=generate_multi_grid(0, multi_grid), weight_std=self.weight_std))
# blocks after the first receive the first block's output, so their input width is outplanes
for i in range(1, blocks):
layers.append(
block(outplanes, outplanes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid),
weight_std=self.weight_std))
return nn.Sequential(*layers)
def forward(self, x, x_res):
## encoder
x = self.conv_4_32(x)
x = self.layer0(x)
skip1 = x
x = self.conv_32_64(x)
x = self.layer1(x)
skip2 = x
x = self.conv_64_128(x)
x = self.layer2(x)
skip3 = x
x = self.conv_128_256(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.fusionConv(x)
## decoder
res_x4 = F.interpolate(x_res, size=(int(self.shape[0] / 4), int(self.shape[1] / 4), int(self.shape[2] / 4)), mode='trilinear', align_corners=True)
seg_x4 = F.interpolate(x, size=(int(self.shape[0] / 4), int(self.shape[1] / 4), int(self.shape[2] / 4)), mode='trilinear', align_corners=True)
seg_x4 = seg_x4 + skip3
seg_x4, res_x4 = self.seg_x4([seg_x4, res_x4])
res_x2 = F.interpolate(res_x4, size=(int(self.shape[0] / 2), int(self.shape[1] / 2), int(self.shape[2] / 2)), mode='trilinear', align_corners=True)
seg_x2 = F.interpolate(seg_x4, size=(int(self.shape[0] / 2), int(self.shape[1] / 2), int(self.shape[2] / 2)), mode='trilinear', align_corners=True)
seg_x2 = seg_x2 + skip2
seg_x2, res_x2 = self.seg_x2([seg_x2, res_x2])
res_x1 = F.interpolate(res_x2, size=(int(self.shape[0] / 1), int(self.shape[1] / 1), int(self.shape[2] / 1)), mode='trilinear', align_corners=True)
seg_x1 = F.interpolate(seg_x2, size=(int(self.shape[0] / 1), int(self.shape[1] / 1), int(self.shape[2] / 1)), mode='trilinear', align_corners=True)
seg_x1 = seg_x1 + skip1
seg_x1, res_x1 = self.seg_x1([seg_x1, res_x1])
seg = self.seg_cls(seg_x1)
res = self.res_cls(res_x1)
resx2 = self.resx2_cls(res_x2)
resx4 = self.resx4_cls(res_x4)
resx2 = F.interpolate(resx2, size=(int(self.shape[0] / 1), int(self.shape[1] / 1), int(self.shape[2] / 1)), mode='trilinear', align_corners=True)
resx4 = F.interpolate(resx4, size=(int(self.shape[0] / 1), int(self.shape[1] / 1), int(self.shape[2] / 1)), mode='trilinear', align_corners=True)
return [seg, res, resx2, resx4]
def conresnet(in_channels, num_classes, **kwargs):
model = ConResNet(in_channels, num_classes, kwargs['img_shape'], NoBottleneck, [1, 2, 2, 2, 2])
init_weights(model, 'kaiming')
return model
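# Sketch (hypothetical helper, not part of ConResNet): the decoder above
# resizes features to img_shape / 4, / 2 and / 1, truncating with int().
# Assuming img_shape is the same tuple passed as kwargs['img_shape'], the
# target sizes can be checked before building the model; lengths that are
# not multiples of 4 are truncated at the x4 stage.
def _conresnet_decoder_sizes(img_shape):
    """Return the (D, H, W) targets used by the x4, x2 and x1 decoder stages."""
    return [tuple(int(s / f) for s in img_shape) for f in (4, 2, 1)]

# Example: _conresnet_decoder_sizes((64, 64, 64))
# gives [(16, 16, 16), (32, 32, 32), (64, 64, 64)]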
# if __name__ == '__main__':
#
# criterion = segmentation_loss('dice', False)
# mask = torch.ones(5, 64, 64, 64).long()
# model = conresnet(1, 10, img_shape=(64, 64, 64))
# model.train()
# output = model(torch.rand(5, 1, 64, 64, 64), torch.rand(5, 1, 64, 64, 64))
#
# loss_train_1 = criterion(output[0], mask)
# loss_train_2 = criterion(output[1], mask)
# loss_train_3 = criterion(output[2], mask)
# loss_train_4 = criterion(output[3], mask)
#
# loss_train_1.backward()
#
# print(output[0].data.cpu().numpy().shape)
# print(output[1].data.cpu().numpy().shape)
# print(output[2].data.cpu().numpy().shape)
# print(output[3].data.cpu().numpy().shape)
# print(loss_train_1)
# print(loss_train_2)
# print(loss_train_3)
# print(loss_train_4)
================================================
FILE: models/networks_3d/cotr.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.init import xavier_uniform_, constant_, normal_
import copy
import math
# from loss.loss_function import segmentation_loss
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the "Attention Is All You Need" paper, extended here to 3D feature maps (D, H, W).
"""
def __init__(self, num_pos_feats=[64, 64, 64], temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x):
bs, c, d, h, w = x.shape
mask = torch.zeros(bs, d, h, w, dtype=torch.bool, device=x.device)
assert mask is not None
not_mask = ~mask
d_embed = not_mask.cumsum(1, dtype=torch.float32)
y_embed = not_mask.cumsum(2, dtype=torch.float32)
x_embed = not_mask.cumsum(3, dtype=torch.float32)
if self.normalize:
eps = 1e-6
d_embed = (d_embed - 0.5) / (d_embed[:, -1:, :, :] + eps) * self.scale
y_embed = (y_embed - 0.5) / (y_embed[:, :, -1:, :] + eps) * self.scale
x_embed = (x_embed - 0.5) / (x_embed[:, :, :, -1:] + eps) * self.scale
dim_tx = torch.arange(self.num_pos_feats[0], dtype=torch.float32, device=x.device)
dim_tx = self.temperature ** (3 * (dim_tx // 3) / self.num_pos_feats[0])
dim_ty = torch.arange(self.num_pos_feats[1], dtype=torch.float32, device=x.device)
dim_ty = self.temperature ** (3 * (dim_ty // 3) / self.num_pos_feats[1])
dim_td = torch.arange(self.num_pos_feats[2], dtype=torch.float32, device=x.device)
dim_td = self.temperature ** (3 * (dim_td // 3) / self.num_pos_feats[2])
pos_x = x_embed[:, :, :, :, None] / dim_tx
pos_y = y_embed[:, :, :, :, None] / dim_ty
pos_d = d_embed[:, :, :, :, None] / dim_td
pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
pos_d = torch.stack((pos_d[:, :, :, :, 0::2].sin(), pos_d[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
pos = torch.cat((pos_d, pos_y, pos_x), dim=4).permute(0, 4, 1, 2, 3)
return pos
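import math

# Sketch (hypothetical helper): the per-axis encoding computed in
# PositionEmbeddingSine.forward, for a single scalar position using plain
# Python floats. Frequencies follow temperature ** (3 * (i // 3) / num_feats)
# as above; even feature slots get sin, odd slots get cos, and
# torch.stack(..., dim=5) followed by flatten(4) interleaves them as
# sin, cos, sin, cos, ...
def _sine_embed_1d(pos, num_feats, temperature=10000):
    if num_feats % 2:
        # torch.stack above needs equal-sized sin and cos halves
        raise ValueError("num_feats must be even")
    dim_t = [temperature ** (3 * (i // 3) / num_feats) for i in range(num_feats)]
    scaled = [pos / d for d in dim_t]
    out = []
    for s, c in zip(scaled[0::2], scaled[1::2]):
        out.extend([math.sin(s), math.cos(c)])
    return out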
def build_position_encoding(mode, hidden_dim):
N_steps = hidden_dim // 3
if (hidden_dim % 3) != 0:
N_steps = [N_steps, N_steps, N_steps + hidden_dim % 3]
else:
N_steps = [N_steps, N_steps, N_steps]
if mode in ('v2', 'sine'):
position_embedding = PositionEmbeddingSine(num_pos_feats=N_steps, normalize=True)
else:
raise ValueError(f"not supported {mode}")
return position_embedding
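# Sketch (hypothetical helper): how build_position_encoding splits hidden_dim
# into the three frequency tables num_pos_feats[0], [1], [2] (used for the
# W, H and D axes in PositionEmbeddingSine.forward); any remainder goes to the
# depth entry. The entries always sum to hidden_dim, so the concatenated
# encoding matches the transformer width. CoTr uses hidden_dim=384, which
# gives 128 features per axis. An odd entry (hidden_dim=385 gives
# [128, 128, 129]) would make the sin/cos halves unequal and torch.stack fail.
def _split_pos_feats(hidden_dim):
    n = hidden_dim // 3
    return [n, n, n + hidden_dim % 3]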
def ms_deform_attn_core_pytorch_3D(value, value_spatial_shapes, sampling_locations, attention_weights):
N_, S_, M_, D_ = value.shape
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
value_list = value.split([T_ * H_ * W_ for T_, H_, W_ in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
# sampling_grids = 3 * sampling_locations - 1
sampling_value_list = []
for lid_, (T_, H_, W_) in enumerate(value_spatial_shapes):
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, T_, H_, W_)
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)[:,None,:,:,:]
sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_.to(dtype=value_l_.dtype), mode='bilinear', padding_mode='zeros', align_corners=False)[:,:,0]  # for 5-D inputs, 'bilinear' performs trilinear interpolation
sampling_value_list.append(sampling_value_l_)
attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
return output.transpose(1, 2).contiguous()
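# Sketch (hypothetical helpers): the coordinate convention behind
# sampling_grids = 2 * sampling_locations - 1. Reference points lie in [0, 1]
# with voxel i centred at (i + 0.5) / size; mapping them to [-1, 1] lands
# exactly on voxel centres under grid_sample(..., align_corners=False), whose
# inverse mapping back to a voxel index is ((g + 1) * size - 1) / 2.
def _ref_to_grid(i, size):
    ref = (i + 0.5) / size   # as in get_reference_points with valid ratio 1
    return 2 * ref - 1       # same affine map as sampling_grids

def _grid_to_voxel(g, size):
    return ((g + 1) * size - 1) / 2  # align_corners=False un-normalisation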
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
"""
Multi-Scale Deformable Attention Module
:param d_model hidden dimension
:param n_levels number of feature levels
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
_d_per_head = d_model // n_heads
self.im2col_step = 64
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 3)
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
self.value_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()*thetas.cos(), thetas.sin()*thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 3).repeat(1, self.n_levels, self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
"""
:param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 3)
:param input_flatten (N, \sum_{l=0}^{L-1} D_l \cdot H_l \cdot W_l, C)
:param input_spatial_shapes (n_levels, 3), [(D_0, H_0, W_0), (D_1, H_1, W_1), ..., (D_{L-1}, H_{L-1}, W_{L-1})]
:param input_level_start_index (n_levels, ), [0, D_0*H_0*W_0, D_0*H_0*W_0+D_1*H_1*W_1, D_0*H_0*W_0+D_1*H_1*W_1+D_2*H_2*W_2, ..., D_0*H_0*W_0+D_1*H_1*W_1+...+D_{L-1}*H_{L-1}*W_{L-1}]
:param input_padding_mask (N, \sum_{l=0}^{L-1} D_l \cdot H_l \cdot W_l), True for padding elements, False for non-padding elements
:return output (N, Length_{query}, C)
"""
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] * input_spatial_shapes[:, 2]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
if reference_points.shape[-1] == 3:
offset_normalizer = torch.stack([input_spatial_shapes[..., 0], input_spatial_shapes[..., 2], input_spatial_shapes[..., 1]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
output = ms_deform_attn_core_pytorch_3D(value, input_spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
return output
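# Sketch (hypothetical helper): output widths of the MSDeformAttn projections.
# sampling_offsets predicts a 3-D offset per (head, level, point) and
# attention_weights one scalar per (head, level, point); the softmax in
# forward normalises over the n_levels * n_points entries of each head.
def _msdeform_proj_sizes(d_model, n_levels, n_heads, n_points):
    if d_model % n_heads:
        raise ValueError("d_model must be divisible by n_heads")
    return {
        "sampling_offsets": n_heads * n_levels * n_points * 3,
        "attention_weights": n_heads * n_levels * n_points,
        "head_dim": d_model // n_heads,
        "softmax_span": n_levels * n_points,
    }

# With the CoTr settings used below (d_model=384, num_feature_levels=2,
# nhead=6, enc_n_points=4): 144 offset outputs, 48 weights, head_dim 64.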
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
class DeformableTransformerEncoderLayer(nn.Module):
def __init__(self,
d_model=256, d_ffn=1024,
dropout=0.1, activation="relu",
n_levels=4, n_heads=8, n_points=4):
super().__init__()
# self attention
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
# self attention
src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
return src
class DeformableTransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (D_, H_, W_) in enumerate(spatial_shapes):
ref_d, ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, D_ - 0.5, D_, dtype=torch.float32, device=device),
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), indexing='ij')
ref_d = ref_d.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * D_)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 2] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * W_)
ref = torch.stack((ref_d, ref_x, ref_y), -1) # D W H
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
output = src
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
for _, layer in enumerate(self.layers):
output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
return output
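# Sketch (hypothetical helper): reference points produced by
# get_reference_points for one level, assuming all valid_ratios are 1 (the
# masks built in U_ResTran3D.posi_mask are all False). meshgrid uses 'ij'
# indexing, so the flattened order is d-major, then h, then w, matching
# src.flatten(2) in DeformableTransformer.forward; each point is stored as
# (d, x, y), as the "D W H" comment notes.
def _reference_points(D, H, W):
    pts = []
    for d in range(D):
        for h in range(H):
            for w in range(W):
                pts.append(((d + 0.5) / D, (w + 0.5) / W, (h + 0.5) / H))
    return pts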
class DeformableTransformer(nn.Module):
def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, dim_feedforward=1024, dropout=0.1, activation="relu", num_feature_levels=4, enc_n_points=4):
super().__init__()
self.d_model = d_model
self.nhead = nhead
encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points)
self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
normal_(self.level_embed)
def get_valid_ratio(self, mask):
_, D, H, W = mask.shape
valid_D = torch.sum(~mask[:, :, 0, 0], 1)
valid_H = torch.sum(~mask[:, 0, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, 0, :], 1)
valid_ratio_d = valid_D.float() / D
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_d, valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def forward(self, srcs, masks, pos_embeds):
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, d, h, w = src.shape
spatial_shape = (d, h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# encoder
memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
return memory
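# Sketch (hypothetical helper): level_start_index and the total token count
# as computed in DeformableTransformer.forward from the per-level (D, H, W)
# shapes. Each level contributes D*H*W tokens, and the start offsets are the
# exclusive cumulative sum of those counts.
def _level_start_index(shapes):
    sizes = [d * h * w for d, h, w in shapes]
    starts = [0]
    for s in sizes[:-1]:
        starts.append(starts[-1] + s)
    return starts, sum(sizes)

# Example from the comment in U_ResTran3D.forward:
# _level_start_index([(12, 24, 24), (6, 12, 12)]) -> ([0, 6912], 7776)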
class Conv3d_wd(nn.Conv3d):
def __init__(self, in_channels, out_channels, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), groups=1, bias=False):
super(Conv3d_wd, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
def forward(self, x):
weight = self.weight
weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True)
weight = weight - weight_mean
# std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1, 1) + 1e-5
std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1, 1)
weight = weight / std.expand_as(weight)
return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
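import statistics

# Sketch (hypothetical helper): the per-filter weight standardisation done in
# Conv3d_wd.forward, applied to one flattened filter given as a plain list.
# torch.var defaults to the unbiased (n - 1) estimator, which
# statistics.variance also uses; 1e-12 is the same stabiliser as above.
def _standardize_filter(weights, eps=1e-12):
    mean = statistics.fmean(weights)
    centred = [w - mean for w in weights]
    std = (statistics.variance(centred) + eps) ** 0.5
    return [c / std for c in centred]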
def conv3x3x3(in_planes, out_planes, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), groups=1,
bias=False, weight_std=False):
"3x3x3 convolution with padding"
if weight_std:
return Conv3d_wd(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
else:
return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
def Norm_layer(norm_cfg, inplanes):
if norm_cfg == 'BN':
out = nn.BatchNorm3d(inplanes)
elif norm_cfg == 'SyncBN':
out = nn.SyncBatchNorm(inplanes)
elif norm_cfg == 'GN':
out = nn.GroupNorm(16, inplanes)
elif norm_cfg == 'IN':
out = nn.InstanceNorm3d(inplanes, affine=True)
else:
raise ValueError('unsupported norm_cfg: {}'.format(norm_cfg))
return out
def Activation_layer(activation_cfg, inplace=True):
if activation_cfg == 'ReLU':
out = nn.ReLU(inplace=inplace)
elif activation_cfg == 'LeakyReLU':
out = nn.LeakyReLU(negative_slope=1e-2, inplace=inplace)
else:
raise ValueError('unsupported activation_cfg: {}'.format(activation_cfg))
return out
class ResBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, norm_cfg, activation_cfg, stride=(1, 1, 1), downsample=None, weight_std=False):
super(ResBlock, self).__init__()
self.conv1 = conv3x3x3(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False, weight_std=weight_std)
self.norm1 = Norm_layer(norm_cfg, planes)
self.nonlin = Activation_layer(activation_cfg, inplace=True)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.norm1(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.nonlin(out)
return out
class Backbone(nn.Module):
def __init__(self, depth, in_channels=1, norm_cfg='BN', activation_cfg='ReLU', weight_std=False):
super(Backbone, self).__init__()
self.arch_settings = {
9: (ResBlock, (3, 3, 2))
}
if depth not in self.arch_settings:
raise KeyError('invalid depth {} for resnet'.format(depth))
self.depth = depth
block, layers = self.arch_settings[depth]
self.inplanes = 64
self.conv1 = conv3x3x3(in_channels, 64, kernel_size=7, stride=(1, 2, 2), padding=3, bias=False, weight_std=weight_std)
self.norm1 = Norm_layer(norm_cfg, 64)
self.nonlin = Activation_layer(activation_cfg, inplace=True)
self.layer1 = self._make_layer(block, 192, layers[0], stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)
self.layer2 = self._make_layer(block, 384, layers[1], stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)
self.layer3 = self._make_layer(block, 384, layers[2], stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)
self.layers = []
for m in self.modules():
if isinstance(m, (nn.Conv3d, Conv3d_wd)):
m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm3d, nn.SyncBatchNorm)):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=(1, 1, 1), norm_cfg='BN', activation_cfg='ReLU', weight_std=False):
downsample = None
if stride not in (1, (1, 1, 1)) or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv3x3x3(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False, weight_std=weight_std), Norm_layer(norm_cfg, planes * block.expansion))
layers = []
layers.append(block(self.inplanes, planes, norm_cfg, activation_cfg, stride=stride, downsample=downsample, weight_std=weight_std))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, norm_cfg, activation_cfg, weight_std=weight_std))
return nn.Sequential(*layers)
def init_weights(self):
for m in self.modules():
if isinstance(m, (nn.Conv3d, Conv3d_wd)):
m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm3d, nn.SyncBatchNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
out = []
x = self.conv1(x)
x = self.norm1(x)
x = self.nonlin(x)
out.append(x)
x = self.layer1(x)
out.append(x)
x = self.layer2(x)
out.append(x)
x = self.layer3(x)
out.append(x)
return out
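# Sketch (hypothetical helpers): spatial shapes of the four Backbone outputs,
# using the standard convolution size formula (n + 2p - k) // s + 1 with the
# settings above. conv1 has kernel 7, padding 3 and stride (1, 2, 2); the first
# ResBlock of each layer has kernel 3, padding 1 and stride 2 (its 1x1,
# stride-2 downsample branch yields the same length).
def _conv_out(n, k, s, p):
    return (n + 2 * p - k) // s + 1

def _backbone_shapes(d, h, w):
    shape = (_conv_out(d, 7, 1, 3), _conv_out(h, 7, 2, 3), _conv_out(w, 7, 2, 3))
    shapes = [shape]
    for _ in range(3):  # layer1, layer2, layer3
        shape = tuple(_conv_out(n, 3, 2, 1) for n in shape)
        shapes.append(shape)
    return shapes

# _backbone_shapes(64, 64, 64) -> [(64, 32, 32), (32, 16, 16), (16, 8, 8), (8, 4, 4)]
# posi_mask keeps only the last two levels for the transformer.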
class Conv3dBlock(nn.Module):
def __init__(self, in_channels, out_channels, norm_cfg, activation_cfg, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), bias=False, weight_std=False):
super(Conv3dBlock, self).__init__()
self.conv = conv3x3x3(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, weight_std=weight_std)
self.norm = Norm_layer(norm_cfg, out_channels)
self.nonlin = Activation_layer(activation_cfg, inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
x = self.nonlin(x)
return x
class ResBlock_(nn.Module):
def __init__(self, inplanes, planes, norm_cfg, activation_cfg, weight_std=False):
super(ResBlock_, self).__init__()
self.resconv1 = Conv3dBlock(inplanes, planes, norm_cfg, activation_cfg, kernel_size=3, stride=1, padding=1, bias=False, weight_std=weight_std)
self.resconv2 = Conv3dBlock(planes, planes, norm_cfg, activation_cfg, kernel_size=3, stride=1, padding=1, bias=False, weight_std=weight_std)
def forward(self, x):
residual = x
out = self.resconv1(x)
out = self.resconv2(out)
out = out + residual
return out
class U_ResTran3D(nn.Module):
def __init__(self, in_channels, num_classes, norm_cfg='BN', activation_cfg='ReLU', weight_std=False):
super(U_ResTran3D, self).__init__()
self.MODEL_NUM_CLASSES = num_classes
self.upsamplex2 = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear', align_corners=True)
self.transposeconv_stage2 = nn.ConvTranspose3d(384, 384, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
self.transposeconv_stage1 = nn.ConvTranspose3d(384, 192, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
self.transposeconv_stage0 = nn.ConvTranspose3d(192, 64, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
self.stage2_de = ResBlock_(384, 384, norm_cfg, activation_cfg, weight_std=weight_std)
self.stage1_de = ResBlock_(192, 192, norm_cfg, activation_cfg, weight_std=weight_std)
self.stage0_de = ResBlock_(64, 64, norm_cfg, activation_cfg, weight_std=weight_std)
# self.ds2_cls_conv = nn.Conv3d(384, self.MODEL_NUM_CLASSES, kernel_size=1)
# self.ds1_cls_conv = nn.Conv3d(192, self.MODEL_NUM_CLASSES, kernel_size=1)
# self.ds0_cls_conv = nn.Conv3d(64, self.MODEL_NUM_CLASSES, kernel_size=1)
self.cls_conv = nn.Conv3d(64, self.MODEL_NUM_CLASSES, kernel_size=1)
for m in self.modules():
if isinstance(m, (nn.Conv3d, Conv3d_wd, nn.ConvTranspose3d)):
m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
elif isinstance(m, (nn.BatchNorm3d, nn.SyncBatchNorm, nn.InstanceNorm3d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
self.backbone = Backbone(depth=9, in_channels=in_channels, norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)
# total = sum([param.nelement() for param in self.backbone.parameters()])
# print(' + Number of Backbone Params: %.2f(e6)' % (total / 1e6))
self.position_embed = build_position_encoding(mode='v2', hidden_dim=384)
self.encoder_Detrans = DeformableTransformer(d_model=384, dim_feedforward=1536, dropout=0.1, activation='gelu',
num_feature_levels=2, nhead=6, num_encoder_layers=6,
enc_n_points=4)
# total = sum([param.nelement() for param in self.encoder_Detrans.parameters()])
# print(' + Number of Transformer Params: %.2f(e6)' % (total / 1e6))
def posi_mask(self, x):
x_fea = []
x_posemb = []
masks = []
for lvl, fea in enumerate(x):
if lvl > 1:
x_fea.append(fea)
x_posemb.append(self.position_embed(fea))
masks.append(torch.zeros((fea.shape[0], fea.shape[2], fea.shape[3], fea.shape[4]), dtype=torch.bool, device=fea.device))
return x_fea, masks, x_posemb
def forward(self, inputs):
# # %%%%%%%%%%%%% CoTr
x_convs = self.backbone(inputs)
x_fea, masks, x_posemb = self.posi_mask(x_convs)
x_trans = self.encoder_Detrans(x_fea, masks, x_posemb)
# # Single_scale
# # x = self.transposeconv_stage2(x_trans.transpose(-1, -2).view(x_convs[-1].shape))
# # skip2 = x_convs[-2]
# Multi-scale
x = self.transposeconv_stage2(x_trans[:, x_fea[0].shape[-3] * x_fea[0].shape[-2] * x_fea[0].shape[-1]::].transpose(-1, -2).view(x_convs[-1].shape)) # x_trans length: 12*24*24+6*12*12=7776
skip2 = x_trans[:, 0:x_fea[0].shape[-3] * x_fea[0].shape[-2] * x_fea[0].shape[-1]].transpose(-1, -2).view(x_convs[-2].shape)
x = x + skip2
x = self.stage2_de(x)
# ds2 = self.ds2_cls_conv(x)
x = self.transposeconv_stage1(x)
skip1 = x_convs[-3]
x = x + skip1
x = self.stage1_de(x)
# ds1 = self.ds1_cls_conv(x)
x = self.transposeconv_stage0(x)
skip0 = x_convs[-4]
x = x + skip0
x = self.stage0_de(x)
# ds0 = self.ds0_cls_conv(x)
result = self.upsamplex2(x)
result = self.cls_conv(result)
return result
def cotr(in_channels, num_classes):
model = U_ResTran3D(in_channels, num_classes)
return model
# if __name__ == '__main__':
#
# criterion = segmentation_loss('dice', False)
#
# mask = torch.ones(2, 64, 64, 64).long()
# model = cotr(1, 10)
# model.train()
# input = torch.rand(2, 1, 64, 64, 64)
# output = model(input)
# loss_train = criterion(output, mask)
# output = output.data.cpu().numpy()
# loss_train.backward()
# print(output.shape)
# print(loss_train)
================================================
FILE: models/networks_3d/dmfnet.py
================================================
import torch.nn as nn
import torch.nn.functional as F
import torch
# from loss.loss_function import segmentation_loss
def normalization(planes, norm='bn'):
if norm == 'bn':
m = nn.BatchNorm3d(planes)
elif norm == 'gn':
m = nn.GroupNorm(4, planes)
elif norm == 'in':
m = nn.InstanceNorm3d(planes)
else:
raise ValueError('normalization type {} is not supported'.format(norm))
return m
class Conv3d_Block(nn.Module):
def __init__(self,num_in,num_out,kernel_size=1,stride=1,g=1,padding=None,norm=None):
super(Conv3d_Block, self).__init__()
if padding is None:
padding = (kernel_size - 1) // 2
self.bn = normalization(num_in,norm=norm)
self.act_fn = nn.ReLU(inplace=True)
self.conv = nn.Conv3d(num_in, num_out, kernel_size=kernel_size, padding=padding,stride=stride, groups=g, bias=False)
def forward(self, x): # BN + Relu + Conv
h = self.act_fn(self.bn(x))
h = self.conv(h)
return h
class DilatedConv3DBlock(nn.Module):
def __init__(self, num_in, num_out, kernel_size=(1,1,1), stride=1, g=1, d=(1,1,1), norm=None):
super(DilatedConv3DBlock, self).__init__()
assert isinstance(kernel_size,tuple) and isinstance(d,tuple)
padding = tuple(
[(ks-1)//2 *dd for ks, dd in zip(kernel_size, d)]
)
self.bn = normalization(num_in, norm=norm)
self.act_fn = nn.ReLU(inplace=True)
self.conv = nn.Conv3d(num_in,num_out,kernel_size=kernel_size,padding=padding,stride=stride,groups=g,dilation=d,bias=False)
def forward(self, x):
h = self.act_fn(self.bn(x))
h = self.conv(h)
return h
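# Sketch (hypothetical helpers): DilatedConv3DBlock pads each axis by
# (k - 1) // 2 * d. For odd kernels this is "same" padding at stride 1,
# because the Conv3d output length (n + 2p - d*(k - 1) - 1) // s + 1 then
# equals n.
def _same_padding(kernel_size, dilation):
    return tuple((k - 1) // 2 * d for k, d in zip(kernel_size, dilation))

def _conv_len(n, k, d, p, s=1):
    return (n + 2 * p - d * (k - 1) - 1) // s + 1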
class MFunit(nn.Module):
def __init__(self, num_in, num_out, g=1, stride=1, d=(1,1),norm=None):
""" The second 3x3x1 group conv is replaced by 3x3x3.
:param num_in: number of input channels
:param num_out: number of output channels
:param g: groups of group conv.
:param stride: 1 or 2
:param d: tuple, d[0] is the dilation of the first 3x3x3 conv and d[1] that of the 3x3x1 conv
:param norm: normalization type passed to normalization(): 'bn', 'gn' or 'in'
"""
super(MFunit, self).__init__()
num_mid = num_in if num_in <= num_out else num_out
self.conv1x1x1_in1 = Conv3d_Block(num_in,num_in//4,kernel_size=1,stride=1,norm=norm)
self.conv1x1x1_in2 = Conv3d_Block(num_in//4,num_mid,kernel_size=1,stride=1,norm=norm)
self.conv3x3x3_m1 = DilatedConv3DBlock(num_mid,num_out,kernel_size=(3,3,3),stride=stride,g=g,d=(d[0],d[0],d[0]),norm=norm) # dilated
self.conv3x3x3_m2 = DilatedConv3DBlock(num_out,num_out,kernel_size=(3,3,1),stride=1,g=g,d=(d[1],d[1],1),norm=norm)
# self.conv3x3x3_m2 = DilatedConv3DBlock(num_out,num_out,kernel_size=(1,3,3),stride=1,g=g,d=(1,d[1],d[1]),norm=norm)
# skip connection
if num_in != num_out or stride != 1:
if stride == 1:
self.conv1x1x1_shortcut = Conv3d_Block(num_in, num_out, kernel_size=1, stride=1, padding=0,norm=norm)
if stride == 2:
# if MF block with stride=2, 2x2x2
self.conv2x2x2_shortcut = Conv3d_Block(num_in, num_out, kernel_size=2, stride=2,padding=0, norm=norm) # params
def forward(self, x):
x1 = self.conv1x1x1_in1(x)
x2 = self.conv1x1x1_in2(x1)
x3 = self.conv3x3x3_m1(x2)
x4 = self.conv3x3x3_m2(x3)
shortcut = x
if hasattr(self,'conv1x1x1_shortcut'):
shortcut = self.conv1x1x1_shortcut(shortcut)
if hasattr(self,'conv2x2x2_shortcut'):
shortcut = self.conv2x2x2_shortcut(shortcut)
return x4 + shortcut
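# Sketch (hypothetical helper): channel flow through MFunit and the group-conv
# divisibility it relies on. num_mid = min(num_in, num_out); the 3x3x3 and
# 3x3x1 convs use groups=g, so their input and output channel counts must be
# divisible by g (nn.Conv3d raises otherwise).
def _mfunit_channels(num_in, num_out, g):
    num_mid = num_in if num_in <= num_out else num_out
    chain = [num_in, num_in // 4, num_mid, num_out, num_out]
    for c in (num_mid, num_out):
        if c % g:
            raise ValueError(f"{c} channels not divisible by groups={g}")
    return chain

# e.g. MFNet's first MFunit(32, 128, g=16): 32 -> 8 -> 32 -> 128 -> 128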
class DMFUnit(nn.Module):
# weighted sum of three dilated branches
def __init__(self, num_in, num_out, g=1, stride=1,norm=None,dilation=None):
super(DMFUnit, self).__init__()
self.weight1 = nn.Parameter(torch.ones(1))
self.weight2 = nn.Parameter(torch.ones(1))
self.weight3 = nn.Parameter(torch.ones(1))
num_mid = num_in if num_in <= num_out else num_out
self.conv1x1x1_in1 = Conv3d_Block(num_in, num_in // 4, kernel_size=1, stride=1, norm=norm)
self.conv1x1x1_in2 = Conv3d_Block(num_in // 4,num_mid,kernel_size=1, stride=1, norm=norm)
self.conv3x3x3_m1 = nn.ModuleList()
if dilation is None:
dilation = [1,2,3]
for i in range(3):
self.conv3x3x3_m1.append(
DilatedConv3DBlock(num_mid,num_out, kernel_size=(3, 3, 3), stride=stride, g=g, d=(dilation[i],dilation[i], dilation[i]),norm=norm)
)
# no dilation in this conv
self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(3, 3, 1), stride=(1,1,1), g=g,d=(1,1,1), norm=norm)
# self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(1, 3, 3), stride=(1,1,1), g=g,d=(1,1,1), norm=norm)
# skip connection
if num_in != num_out or stride != 1:
if stride == 1:
self.conv1x1x1_shortcut = Conv3d_Block(num_in, num_out, kernel_size=1, stride=1, padding=0, norm=norm)
if stride == 2:
self.conv2x2x2_shortcut = Conv3d_Block(num_in, num_out, kernel_size=2, stride=2, padding=0, norm=norm)
def forward(self, x):
x1 = self.conv1x1x1_in1(x)
x2 = self.conv1x1x1_in2(x1)
x3 = self.weight1*self.conv3x3x3_m1[0](x2) + self.weight2*self.conv3x3x3_m1[1](x2) + self.weight3*self.conv3x3x3_m1[2](x2)
x4 = self.conv3x3x3_m2(x3)
shortcut = x
if hasattr(self, 'conv1x1x1_shortcut'):
shortcut = self.conv1x1x1_shortcut(shortcut)
if hasattr(self, 'conv2x2x2_shortcut'):
shortcut = self.conv2x2x2_shortcut(shortcut)
return x4 + shortcut
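# Sketch (hypothetical helpers): DMFUnit sums three 3x3x3 branches with
# dilations [1, 2, 3], each scaled by a learnable scalar initialised to 1. A
# kernel of size k with dilation d spans d * (k - 1) + 1 voxels per axis, so
# the branches see 3, 5 and 7 voxels while producing the same output shape.
def _kernel_extent(k, d):
    return d * (k - 1) + 1

def _weighted_branch_sum(branch_outputs, weights):
    # elementwise weight1*b1 + weight2*b2 + weight3*b3 on flat lists
    return [sum(w * b[i] for w, b in zip(weights, branch_outputs))
            for i in range(len(branch_outputs[0]))]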
class MFNet(nn.Module):
# [96] Flops: 13.361G & Params: 1.81M
# [112] Flops: 16.759G & Params: 2.46M
# [128] Flops: 20.611G & Params: 3.19M
def __init__(self,in_channels, num_classes, n=32, channels=128, groups=16, norm='bn'):
super(MFNet, self).__init__()
# Entry flow
self.encoder_block1 = nn.Conv3d(in_channels, n, kernel_size=3, padding=1, stride=2, bias=False)# H//2
self.encoder_block2 = nn.Sequential(
MFunit(n, channels, g=groups, stride=2, norm=norm),# H//4 down
MFunit(channels, channels, g=groups, stride=1, norm=norm),
MFunit(channels, channels, g=groups, stride=1, norm=norm)
)
#
self.encoder_block3 = nn.Sequential(
MFunit(channels, channels*2, g=groups, stride=2, norm=norm), # H//8
MFunit(channels * 2, channels * 2, g=groups, stride=1, norm=norm),
MFunit(channels * 2, channels * 2, g=groups, stride=1, norm=norm)
)
self.encoder_block4 = nn.Sequential(# H//16, outputs channels*2
MFunit(channels*2, channels*3, g=groups, stride=2, norm=norm), # H//16
MFunit(channels*3, channels*3, g=groups, stride=1, norm=norm),
MFunit(channels*3, channels*2, g=groups, stride=1, norm=norm),
)
self.upsample1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) # H//8
self.decoder_block1 = MFunit(channels*2+channels*2, channels*2, g=groups, stride=1, norm=norm)
self.upsample2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) # H//4
self.decoder_block2 = MFunit(channels*2 + channels, channels, g=groups, stride=1, norm=norm)
self.upsample3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) # H//2
self.decoder_block3 = MFunit(channels + n, n, g=groups, stride=1, norm=norm)
self.upsample4 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) # H
self.seg = nn.Conv3d(n, num_classes, kernel_size=1, padding=0,stride=1,bias=False)
# Initialization
for m in self.modules():
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
# Encoder
x1 = self.encoder_block1(x)# H//2 down
x2 = self.encoder_block2(x1)# H//4 down
x3 = self.encoder_block3(x2)# H//8 down
x4 = self.encoder_block4(x3) # H//16
# Decoder
y1 = self.upsample1(x4)# H//8
y1 = torch.cat([x3,y1],dim=1)
y1 = self.decoder_block1(y1)
y2 = self.upsample2(y1)# H//4
y2 = torch.cat([x2,y2],dim=1)
y2 = self.decoder_block2(y2)
y3 = self.upsample3(y2)# H//2
y3 = torch.cat([x1,y3],dim=1)
y3 = self.decoder_block3(y3)
y4 = self.upsample4(y3)
y4 = self.seg(y4)
return y4
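# Sketch (hypothetical helper, not part of the original MFNet code): trace one
# spatial axis through the network above. It assumes every stride-2 stage
# (encoder_block1 and the first MFunit of blocks 2-4; MFunit is defined elsewhere)
# maps L to (L + 2*1 - 3) // 2 + 1, like a 3x3x3 conv with stride=2 and padding=1.
# Each decoder step is nn.Upsample(scale_factor=2). The skip concatenations
# (torch.cat([x3, y1]), [x2, y2], [x1, y3]) only line up when L is divisible by 16.
def _mfnet_axis_trace(length):
    """Return (encoder sizes [L, L/2, L/4, L/8, L/16], skips_match, output size).

    The output size is only meaningful when skips_match is True.
    """
    sizes = [length]
    for _ in range(4):  # encoder_block1 .. encoder_block4
        sizes.append((sizes[-1] + 2 * 1 - 3) // 2 + 1)
    # y1 = up(x4) vs x3, y2 = up(y1) vs x2, y3 = up(y2) vs x1
    skips_match = all(2 * sizes[i + 1] == sizes[i] for i in (1, 2, 3))
    return sizes, skips_match, 2 * sizes[1]

# e.g. _mfnet_axis_trace(64) -> ([64, 32, 16, 8, 4], True, 64)
#      _mfnet_axis_trace(72)[:2] -> ([72, 36, 18, 9, 5], False)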
class DMFNet(MFNet): # softmax
# [128] Flops: 27.045G & Params: 3.88M
def __init__(self,in_channels, num_classes, n=32,channels=128, groups=16,norm='bn'):
super(DMFNet, self).__init__(in_channels, num_classes, n, channels, groups, norm)
self.encoder_block2 = nn.Sequential(
DMFUnit(n, channels, g=groups, stride=2, norm=norm,dilation=[1,2,3]),# H//4 down
DMFUnit(channels, channels, g=groups, stride=1, norm=norm,dilation=[1,2,3]), # Dilated Conv 3
DMFUnit(channels, channels, g=groups, stride=1, norm=norm,dilation=[1,2,3])
)
self.encoder_block3 = nn.Sequential(
DMFUnit(channels, channels*2, g=groups, stride=2, norm=norm,dilation=[1,2,3]), # H//8
DMFUnit(channels * 2, channels * 2, g=groups, stride=1, norm=norm,dilation=[1,2,3]),# Dilated Conv 3
DMFUnit(channels * 2, channels * 2, g=groups, stride=1, norm=norm,dilation=[1,2,3])
)
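# Sketch (hypothetical helper): DMFNet replaces the MFunits of encoder_block2/3
# with DMFUnits that use dilation=[1, 2, 3]. This assumes those dilations are
# applied to 3x3x3 kernels (DMFUnit is defined elsewhere in this file). A kernel
# of size k with dilation d then spans d * (k - 1) + 1 voxels per axis. The three
# branches therefore see 3-, 5- and 7-voxel neighbourhoods, and each has the
# weight count of a plain 3x3x3 conv.
def _dilated_extent(kernel_size=3, dilations=(1, 2, 3)):
    return [d * (kernel_size - 1) + 1 for d in dilations]

# e.g. _dilated_extent() -> [3, 5, 7]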
def dmfnet(in_channels, num_classes):
model = DMFNet(in_channels, num_classes)
return model
# if __name__ == '__main__':
#
# criterion = segmentation_loss('dice', False)
# mask = torch.ones(2, 64, 64, 64).long()
#
# model = dmfnet(1, 10)
# model.train()
# input = torch.rand(2, 1, 64, 64, 64)
# output = model(input)
#
# loss_train = criterion(output, mask)
# loss_train.backward()
#
# output = output.data.cpu().numpy()
# print(output.shape)
# print(loss_train)
================================================
FILE: models/networks_3d/espnet3d.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# from loss.loss_function import segmentation_loss
from warnings import simplefilter
simplefilter(action='ignore', category=UserWarning)
class CBR(nn.Module):
def __init__(self, nIn, nOut, kSize, stride=1):
super().__init__()
padding = int((kSize - 1) / 2)
self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False)
self.bn = nn.BatchNorm3d(nOut, momentum=0.95, eps=1e-03)
self.act = nn.ReLU(inplace=True)
def forward(self, input):
output = self.conv(input)
output = self.bn(output)
output = self.act(output)
return output
class CB(nn.Module):
def __init__(self, nIn, nOut, kSize, stride=1):
super().__init__()
padding = int((kSize - 1) / 2)
self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False)
self.bn = nn.BatchNorm3d(nOut, momentum=0.95, eps=1e-03)
def forward(self, input):
output = self.conv(input)
output = self.bn(output)
return output
class C(nn.Module):
def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
super().__init__()
padding = int((kSize - 1) / 2)
self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, groups=groups)
def forward(self, input):
output = self.conv(input)
return output
class DownSamplerA(nn.Module):
def __init__(self, nIn, nOut):
super().__init__()
self.conv = CBR(nIn, nOut, 3, 2)
def forward(self, input):
output = self.conv(input)
return output
class DownSamplerB(nn.Module):
def __init__(self, nIn, nOut):
super().__init__()
k = 4
n = int(nOut/k)
n1 = nOut - (k-1)*n
self.c1 = nn.Sequential(CBR(nIn, n, 1, 1), C(n, n, 3, 2))
self.d1 = CDilated(n, n1, 3, 1, 1)
self.d2 = CDilated(n, n, 3, 1, 2)
self.d4 = CDilated(n, n, 3, 1, 3)
self.d8 = CDilated(n, n, 3, 1, 4)
self.bn = BR(nOut)
def forward(self, input):
output1 = self.c1(input)
d1 = self.d1(output1)
d2 = self.d2(output1)
d4 = self.d4(output1)
d8 = self.d8(output1)
add1 = d2
add2 = add1 + d4
add3 = add2 + d8
combine = torch.cat([d1, add1, add2, add3],1)
if input.size() == combine.size():
combine = input + combine
output = self.bn(combine)
return output
class BR(nn.Module):
def __init__(self, nOut):
super().__init__()
self.bn = nn.BatchNorm3d(nOut, momentum=0.95, eps=1e-03)
self.act = nn.ReLU(inplace=True) # nn.PReLU(nOut)
def forward(self, input):
output = self.bn(input)
output = self.act(output)
return output
class CDilated(nn.Module):
def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):
super().__init__()
padding = int((kSize - 1) / 2) * d
self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, dilation=d, groups=groups)
#self.bn = nn.BatchNorm3d(nOut, momentum=0.95, eps=1e-03)
def forward(self, input):
return self.conv(input)
#return self.bn(output)
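# Sketch (hypothetical helper): CBR, CB, C and CDilated all choose
# padding = (kSize - 1) // 2 * d. Plug that into the standard Conv3d output-size
# formula, out = (L + 2*pad - d*(k - 1) - 1) // stride + 1. For any odd k and any
# dilation d, stride-1 layers keep L unchanged. Stride-2 layers, e.g.
# level0 = CBR(in_channels, 16, 7, 2), map L to ceil(L / 2).
def _conv_out_len(length, k, stride=1, d=1):
    pad = (k - 1) // 2 * d  # same padding rule as CDilated
    return (length + 2 * pad - d * (k - 1) - 1) // stride + 1

# e.g. _conv_out_len(128, 7, stride=2) -> 64; _conv_out_len(32, 3, d=4) -> 32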
class InputProjectionA(nn.Module):
'''
This class projects the input volume to the same spatial dimensions as a feature map by repeated average pooling.
For example, with samplingTimes=2 an input of size 4 x 128 x 128 x 128 becomes 4 x 32 x 32 x 32:
each AvgPool3d(3, stride=2, padding=1) halves every spatial dimension (rounding up) and keeps the channel count.
'''
def __init__(self, samplingTimes):
'''
:param samplingTimes: The rate at which you want to down-sample the image
'''
super().__init__()
self.pool = nn.ModuleList()
for i in range(0, samplingTimes):
# pyramid-based approach for down-sampling
self.pool.append(nn.AvgPool3d(3, stride=2, padding=1))
def forward(self, input):
'''
:param input: input volume of shape (N, C, D, H, W)
:return: down-sampled volume (pyramid-based approach)
'''
for pool in self.pool:
input = pool(input)
return input
class DilatedParllelResidualBlockB1(nn.Module): # with k=4
def __init__(self, nIn, nOut, stride=1):
super().__init__()
k = 4
n = int(nOut / k)
n1 = nOut - (k - 1) * n
self.c1 = CBR(nIn, n, 1, 1)
self.d1 = CDilated(n, n1, 3, stride, 1)
self.d2 = CDilated(n, n, 3, stride, 1)
self.d4 = CDilated(n, n, 3, stride, 2)
self.d8 = CDilated(n, n, 3, stride, 2)
self.bn = nn.BatchNorm3d(nOut)
def forward(self, input):
output1 = self.c1(input)
d1 = self.d1(output1)
d2 = self.d2(output1)
d4 = self.d4(output1)
d8 = self.d8(output1)
add1 = d2
add2 = add1 + d4
add3 = add2 + d8
combine = self.bn(torch.cat([d1, add1, add2, add3], 1))
if input.size() == combine.size():
combine = input + combine
output = F.relu(combine, inplace=True)
return output
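# Sketch (hypothetical helpers): DownSamplerB and DilatedParllelResidualBlockB1
# split nOut channels over k=4 parallel branches. Each branch gets
# n = int(nOut / k) channels, and d1 takes n1 = nOut - (k - 1) * n so that the
# concatenation has exactly nOut channels. They then apply hierarchical feature
# fusion: d1 is kept as is and the other branches are summed cumulatively
# (add1 = d2, add2 = add1 + d4, add3 = add2 + d8). Plain numbers stand in for
# the feature tensors here.
from itertools import accumulate

def _esp_branch_channels(n_out, k=4):
    n = int(n_out / k)
    n1 = n_out - (k - 1) * n
    return [n1] + [n] * (k - 1)

def _hierarchical_fusion(branch_outputs):
    d1, *rest = branch_outputs
    return [d1] + list(accumulate(rest))

# e.g. _esp_branch_channels(30) -> [9, 7, 7, 7]; _hierarchical_fusion([1, 2, 3, 4]) -> [1, 2, 5, 9]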
class ASPBlock(nn.Module): # with k=4
def __init__(self, nIn, nOut, stride=1):
super().__init__()
self.d1 = CB(nIn, nOut, 3, 1)
self.d2 = CB(nIn, nOut, 5, 1)
self.d4 = CB(nIn, nOut, 7, 1)
self.d8 = CB(nIn, nOut, 9, 1)
self.act = nn.ReLU(inplace=True)
def forward(self, input):
d1 = self.d1(input)
d2 = self.d2(input)
d3 = self.d4(input)
d4 = self.d8(input)
combine = d1 + d2 + d3 + d4
if input.size() == combine.size():
combine = input + combine
output = self.act(combine)
return output
class UpSampler(nn.Module):
'''
Up-sample the feature maps by 2
'''
def __init__(self, nIn, nOut):
super().__init__()
self.up = CBR(nIn, nOut, 3, 1)
def forward(self, inp):
return F.interpolate(self.up(inp), mode='trilinear', scale_factor=2, align_corners=True)
class PSPDec(nn.Module):
'''
Adapted from the pyramid pooling module of the Pyramid Scene Parsing Network (PSPNet) paper
'''
def __init__(self, nIn, nOut, downSize):
super().__init__()
self.scale = downSize
self.features = CBR(nIn, nOut, 3, 1)
def forward(self, x):
assert x.dim() == 5
inp_size = x.size()
out_dim1, out_dim2, out_dim3 = int(inp_size[2] * self.scale), int(inp_size[3] * self.scale), int(inp_size[4] * self.scale)
x_down = F.adaptive_avg_pool3d(x, output_size=(out_dim1, out_dim2, out_dim3))
return F.interpolate(self.features(x_down), size=(inp_size[2], inp_size[3], inp_size[4]), mode='trilinear', align_corners=True)
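# Sketch (hypothetical helpers): PSPDec pools its input to int(size * scale) per
# axis, refines the result with a CBR and upsamples it back. ESPNet then
# concatenates the original map with one PSPDec output per scale. With scales
# [0.2, 0.4, 0.6, 0.8] the classifier therefore receives (4 + 1) * num_classes
# channels. A 64^3 map is pooled to 12^3, 25^3, 38^3 and 51^3.
def _psp_pooled_sizes(spatial, scales=(0.2, 0.4, 0.6, 0.8)):
    return [tuple(int(s * sc) for s in spatial) for sc in scales]

def _psp_concat_channels(num_classes, scales=(0.2, 0.4, 0.6, 0.8)):
    return (len(scales) + 1) * num_classes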
class ESPNet(nn.Module):
def __init__(self, in_channels, num_classes):
super().__init__()
self.input1 = InputProjectionA(1)
self.input2 = InputProjectionA(1)
initial = 16 # feature maps at level 1
config = [32, 128, 256, 256] # feature maps at level 2 and onwards
reps = [2, 2, 3]
### ENCODER
# all dimensions are listed with respect to an input of size 4 x 128 x 128 x 128
self.level0 = CBR(in_channels, initial, 7, 2) # initial x 64 x 64 x64
self.level1 = nn.ModuleList()
for i in range(reps[0]):
if i==0:
self.level1.append(DilatedParllelResidualBlockB1(initial, config[0])) # config[0] x 64 x 64 x64
else:
self.level1.append(DilatedParllelResidualBlockB1(config[0], config[0])) # config[0] x 64 x 64 x64
# downsample the feature maps
self.level2 = DilatedParllelResidualBlockB1(config[0], config[1], stride=2) # config[1] x 32 x 32 x 32
self.level_2 = nn.ModuleList()
for i in range(0, reps[1]):
self.level_2.append(DilatedParllelResidualBlockB1(config[1], config[1])) # config[1] x 32 x 32 x 32
# downsample the feature maps
self.level3_0 = DilatedParllelResidualBlockB1(config[1], config[2], stride=2) # config[2] x 16 x 16 x 16
self.level_3 = nn.ModuleList()
for i in range(0, reps[2]):
self.level_3.append(DilatedParllelResidualBlockB1(config[2], config[2])) # config[2] x 16 x 16 x 16
### DECODER
# upsample the feature maps
self.up_l3_l2 = UpSampler(config[2], config[1]) # config[1] x 32 x 32 x 32
# Note the 2 in below line. You need this because you are concatenating feature maps from encoder
# with upsampled feature maps
self.merge_l2 = DilatedParllelResidualBlockB1(2 * config[1], config[1]) # config[1] x 32 x 32 x 32
self.dec_l2 = nn.ModuleList()
for i in range(0, reps[0]):
self.dec_l2.append(DilatedParllelResidualBlockB1(config[1], config[1])) # config[1] x 32 x 32 x 32
self.up_l2_l1 = UpSampler(config[1], config[0]) # config[0] x 64 x 64 x 64
# Note the 2 in below line. You need this because you are concatenating feature maps from encoder
# with upsampled feature maps
self.merge_l1 = DilatedParllelResidualBlockB1(2*config[0], config[0]) # config[0] x 64 x 64 x 64
self.dec_l1 = nn.ModuleList()
for i in range(0, reps[0]):
self.dec_l1.append(DilatedParllelResidualBlockB1(config[0], config[0])) # config[0] x 64 x 64 x 64
self.dec_l1.append(CBR(config[0], num_classes, 3, 1)) # classes x 64 x 64 x 64
# We use the ESP block without the reduction step because the number of input feature maps is very small (i.e. 4 in
# our case)
self.dec_l1.append(ASPBlock(num_classes, num_classes))
# Using PSP module to learn the representations at different scales
self.pspModules = nn.ModuleList()
scales = [0.2, 0.4, 0.6, 0.8]
for sc in scales:
self.pspModules.append(PSPDec(num_classes, num_classes, sc))
# Classifier
self.classifier = nn.Sequential(
CBR((len(scales) + 1) * num_classes, num_classes, 3, 1),
ASPBlock(num_classes, num_classes), # classes x 64 x 64 x 64
nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True), # classes x 128 x 128 x 128
CBR(num_classes, num_classes, 7, 1), # classes x 128 x 128 x 128
C(num_classes, num_classes, 1, 1) # classes x 128 x 128 x 128
)
#
for m in self.modules():
if isinstance(m, nn.Conv3d):
n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if isinstance(m, nn.ConvTranspose3d):
n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, input1, inp_res=(128, 128, 128), inpSt2=False):
dim0 = input1.size(2)
dim1 = input1.size(3)
dim2 = input1.size(4)
if self.training or inp_res is None:
# input resolution should be divisible by 8
inp_res = (math.ceil(dim0 / 8) * 8, math.ceil(dim1 / 8) * 8,
math.ceil(dim2 / 8) * 8)
if inp_res:
input1 = F.adaptive_avg_pool3d(input1, output_size=inp_res)
out_l0 = self.level0(input1)
for i, layer in enumerate(self.level1): #64
if i == 0:
out_l1 = layer(out_l0)
else:
out_l1 = layer(out_l1)
out_l2_down = self.level2(out_l1) #32
for i, layer in enumerate(self.level_2):
if i == 0:
out_l2 = layer(out_l2_down)
else:
out_l2 = layer(out_l2)
del out_l2_down
out_l3_down = self.level3_0(out_l2) #16
for i, layer in enumerate(self.level_3):
if i == 0:
out_l3 = layer(out_l3_down)
else:
out_l3 = layer(out_l3)
del out_l3_down
dec_l3_l2 = self.up_l3_l2(out_l3)
merge_l2 = self.merge_l2(torch.cat([dec_l3_l2, out_l2], 1))
for i, layer in enumerate(self.dec_l2):
if i == 0:
dec_l2 = layer(merge_l2)
else:
dec_l2 = layer(dec_l2)
dec_l2_l1 = self.up_l2_l1(dec_l2)
merge_l1 = self.merge_l1(torch.cat([dec_l2_l1, out_l1], 1))
for i, layer in enumerate(self.dec_l1):
if i == 0:
dec_l1 = layer(merge_l1)
else:
dec_l1 = layer(dec_l1)
psp_outs = dec_l1.clone()
for layer in self.pspModules:
out_psp = layer(dec_l1)
psp_outs = torch.cat([psp_outs, out_psp], 1)
decoded = self.classifier(psp_outs)
return F.interpolate(decoded, size=(dim0, dim1, dim2), mode='trilinear', align_corners=True)
def espnet3d(in_channels, num_classes):
model = ESPNet(in_channels, num_classes)
return model
# if __name__ == '__main__':
#
# criterion = segmentation_loss('dice', False)
#
# mask = torch.ones(2, 96, 48, 96).long()
# model = espnet3d(1, 10)
# model.train()
# input = torch.rand(2, 1, 96, 48, 96)
# output = model(input)
# loss_train = criterion(output, mask)
# output = output.data.cpu().numpy()
# loss_train.backward()
# print(output.shape)
# print(loss_train)
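# Sketch (hypothetical helper mirroring the start of ESPNet.forward): in training
# mode, or when inp_res is None, the volume is resampled to the next multiple of 8
# per axis, because level0, level2 and level3_0 each halve the resolution. In eval
# mode the default inp_res=(128, 128, 128) is used as is. Every test volume is
# therefore resampled to 128^3, and the prediction is interpolated back to the
# original size at the end of forward.
import math

def _espnet_working_res(shape, training, inp_res=(128, 128, 128)):
    if training or inp_res is None:
        inp_res = tuple(math.ceil(s / 8) * 8 for s in shape)
    return inp_res

# e.g. _espnet_working_res((100, 50, 30), training=True) -> (104, 56, 32)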
================================================
FILE: models/networks_3d/res_unet3d.py
================================================
import torch
import torch.nn as nn
import os
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
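# Sketch (hypothetical helper): with init_type='kaiming', init_weights calls
# init.kaiming_normal_(w, a=0, mode='fan_in'). Its standard deviation is
# gain / sqrt(fan_in), with gain = sqrt(2 / (1 + a^2)) and
# fan_in = in_channels * kD * kH * kW. ESPNet (espnet3d.py) instead draws from
# N(0, sqrt(2 / n)) with n = kD * kH * kW * out_channels, a fan_out variant.
import math

def _kaiming_std(fan, a=0.0):
    gain = math.sqrt(2.0 / (1 + a ** 2))
    return gain / math.sqrt(fan)

# e.g. Conv3d(8, 16, kernel_size=3): fan_in = 8 * 27 = 216, fan_out = 16 * 27 = 432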
class UNet(nn.Module):
"""
Implementations based on the Unet3D paper: https://arxiv.org/pdf/1706.00120.pdf
"""
def __init__(self, in_channels, n_classes, base_n_filter=8):
super(UNet, self).__init__()
self.in_channels = in_channels
self.n_classes = n_classes
self.base_n_filter = base_n_filter
self.lrelu = nn.LeakyReLU()
self.dropout3d = nn.Dropout3d(p=0.6)
self.upsacle = nn.Upsample(scale_factor=2, mode='nearest')
self.conv3d_c1_1 = nn.Conv3d(self.in_channels, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False)
self.conv3d_c1_2 = nn.Conv3d(self.base_n_filter, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False)
self.lrelu_conv_c1 = self.lrelu_conv(self.base_n_filter, self.base_n_filter)
self.inorm3d_c1 = nn.InstanceNorm3d(self.base_n_filter)
self.conv3d_c2 = nn.Conv3d(self.base_n_filter, self.base_n_filter * 2, kernel_size=3, stride=2, padding=1, bias=False)
self.norm_lrelu_conv_c2 = self.norm_lrelu_conv(self.base_n_filter * 2, self.base_n_filter * 2)
self.inorm3d_c2 = nn.InstanceNorm3d(self.base_n_filter * 2)
self.conv3d_c3 = nn.Conv3d(self.base_n_filter * 2, self.base_n_filter * 4, kernel_size=3, stride=2, padding=1, bias=False)
self.norm_lrelu_conv_c3 = self.norm_lrelu_conv(self.base_n_filter * 4, self.base_n_filter * 4)
self.inorm3d_c3 = nn.InstanceNorm3d(self.base_n_filter * 4)
self.conv3d_c4 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 8, kernel_size=3, stride=2, padding=1, bias=False)
self.norm_lrelu_conv_c4 = self.norm_lrelu_conv(self.base_n_filter * 8, self.base_n_filter * 8)
self.inorm3d_c4 = nn.InstanceNorm3d(self.base_n_filter * 8)
self.conv3d_c5 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 16, kernel_size=3, stride=2, padding=1, bias=False)
self.norm_lrelu_conv_c5 = self.norm_lrelu_conv(self.base_n_filter * 16, self.base_n_filter * 16)
self.norm_lrelu_upscale_conv_norm_lrelu_l0 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 16, self.base_n_filter * 8)
self.conv3d_l0 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 8, kernel_size=1, stride=1, padding=0, bias=False)
self.inorm3d_l0 = nn.InstanceNorm3d(self.base_n_filter * 8)
self.conv_norm_lrelu_l1 = self.conv_norm_lrelu(self.base_n_filter * 16, self.base_n_filter * 16)
self.conv3d_l1 = nn.Conv3d(self.base_n_filter * 16, self.base_n_filter * 8, kernel_size=1, stride=1, padding=0, bias=False)
self.norm_lrelu_upscale_conv_norm_lrelu_l1 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 8, self.base_n_filter * 4)
self.conv_norm_lrelu_l2 = self.conv_norm_lrelu(self.base_n_filter * 8, self.base_n_filter * 8)
self.conv3d_l2 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 4, kernel_size=1, stride=1, padding=0,
bias=False)
self.norm_lrelu_upscale_conv_norm_lrelu_l2 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 4, self.base_n_filter * 2)
self.conv_norm_lrelu_l3 = self.conv_norm_lrelu(self.base_n_filter * 4, self.base_n_filter * 4)
self.conv3d_l3 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 2, kernel_size=1, stride=1, padding=0, bias=False)
self.norm_lrelu_upscale_conv_norm_lrelu_l3 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 2, self.base_n_filter)
self.conv_norm_lrelu_l4 = self.conv_norm_lrelu(self.base_n_filter * 2, self.base_n_filter * 2)
self.conv3d_l4 = nn.Conv3d(self.base_n_filter * 2, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)
self.ds2_1x1_conv3d = nn.Conv3d(self.base_n_filter * 8, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)
self.ds3_1x1_conv3d = nn.Conv3d(self.base_n_filter * 4, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)
def conv_norm_lrelu(self, feat_in, feat_out):
return nn.Sequential(
nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
nn.InstanceNorm3d(feat_out),
nn.LeakyReLU())
def norm_lrelu_conv(self, feat_in, feat_out):
return nn.Sequential(
nn.InstanceNorm3d(feat_in),
nn.LeakyReLU(),
nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))
def lrelu_conv(self, feat_in, feat_out):
return nn.Sequential(
nn.LeakyReLU(),
nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))
def norm_lrelu_upscale_conv_norm_lrelu(self, feat_in, feat_out):
return nn.Sequential(
nn.InstanceNorm3d(feat_in),
nn.LeakyReLU(),
nn.Upsample(scale_factor=2, mode='nearest'),
# should be feat_in*2 or feat_in
nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
nn.InstanceNorm3d(feat_out),
nn.LeakyReLU())
def forward(self, x):
# Level 1 context pathway
out = self.conv3d_c1_1(x)
residual_1 = out
out = self.lrelu(out)
out = self.conv3d_c1_2(out)
out = self.dropout3d(out)
out = self.lrelu_conv_c1(out)
# Element Wise Summation
out += residual_1
context_1 = self.lrelu(out)
out = self.inorm3d_c1(out)
out = self.lrelu(out)
# Level 2 context pathway
out = self.conv3d_c2(out)
residual_2 = out
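# note: the same norm_lrelu_conv_c2 module is applied before and after dropout, so both passes share one set of weights (levels 3-5 below do the same)
out = self.norm_lrelu_conv_c2(out)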
out = self.dropout3d(out)
out = self.norm_lrelu_conv_c2(out)
out += residual_2
out = self.inorm3d_c2(out)
out = self.lrelu(out)
context_2 = out
# Level 3 context pathway
out = self.conv3d_c3(out)
residual_3 = out
out = self.norm_lrelu_conv_c3(out)
out = self.dropout3d(out)
out = self.norm_lrelu_conv_c3(out)
out += residual_3
out = self.inorm3d_c3(out)
out = self.lrelu(out)
context_3 = out
# Level 4 context pathway
out = self.conv3d_c4(out)
residual_4 = out
out = self.norm_lrelu_conv_c4(out)
out = self.dropout3d(out)
out = self.norm_lrelu_conv_c4(out)
out += residual_4
out = self.inorm3d_c4(out)
out = self.lrelu(out)
context_4 = out
# Level 5
out = self.conv3d_c5(out)
residual_5 = out
out = self.norm_lrelu_conv_c5(out)
out = self.dropout3d(out)
out = self.norm_lrelu_conv_c5(out)
out += residual_5
out = self.norm_lrelu_upscale_conv_norm_lrelu_l0(out)
out = self.conv3d_l0(out)
out = self.inorm3d_l0(out)
out = self.lrelu(out)
# Level 1 localization pathway
out = torch.cat([out, context_4], dim=1)
out = self.conv_norm_lrelu_l1(out)
out = self.conv3d_l1(out)
out = self.norm_lrelu_upscale_conv_norm_lrelu_l1(out)
# Level 2 localization pathway
out = torch.cat([out, context_3], dim=1)
out = self.conv_norm_lrelu_l2(out)
ds2 = out
out = self.conv3d_l2(out)
out = self.norm_lrelu_upscale_conv_norm_lrelu_l2(out)
# Level 3 localization pathway
out = torch.cat([out, context_2], dim=1)
out = self.conv_norm_lrelu_l3(out)
ds3 = out
out = self.conv3d_l3(out)
out = self.norm_lrelu_upscale_conv_norm_lrelu_l3(out)
# Level 4 localization pathway
out = torch.cat([out, context_1], dim=1)
out = self.conv_norm_lrelu_l4(out)
out_pred = self.conv3d_l4(out)
ds2_1x1_conv = self.ds2_1x1_conv3d(ds2)
ds1_ds2_sum_upscale = self.upsacle(ds2_1x1_conv)
ds3_1x1_conv = self.ds3_1x1_conv3d(ds3)
ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv
ds1_ds2_sum_upscale_ds3_sum_upscale = self.upsacle(ds1_ds2_sum_upscale_ds3_sum)
out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale
seg_layer = out
return seg_layer
def res_unet3d(in_channels, num_classes):
model = UNet(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = res_unet3d(1,10)
# model.eval()
# input = torch.rand(2, 1, 128, 128, 128)
# output = model(input)
# output = output.data.cpu().numpy()
# # print(output)
# print(output.shape)
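# Sketch (hypothetical helpers): the deep-supervision head at the end of
# UNet.forward maps ds2 (1/4 resolution) and ds3 (1/2 resolution) to n_classes
# with 1x1x1 convs. It then computes out_pred + up(up(ds2) + ds3), where up is
# nn.Upsample(scale_factor=2, mode='nearest'). The four stride-2 convs mean the
# input size must be divisible by 16 for these sums to line up. Below is the same
# arithmetic on 1-D lists, with numbers standing in for the class-score maps.
def _nearest_up(seq):
    return [v for v in seq for _ in range(2)]

def _deep_supervision_sum(ds2, ds3, pred):
    assert len(ds3) == 2 * len(ds2) and len(pred) == 2 * len(ds3)
    mid = [a + b for a, b in zip(_nearest_up(ds2), ds3)]
    return [a + b for a, b in zip(_nearest_up(mid), pred)]

# e.g. _deep_supervision_sum([1, 2], [0, 1, 0, 1], [0] * 8) -> [1, 1, 2, 2, 2, 2, 3, 3]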
================================================
FILE: models/networks_3d/transbts.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from loss.loss_function import segmentation_loss
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
def normalization(planes, norm='gn'):
if norm == 'bn':
m = nn.BatchNorm3d(planes)
elif norm == 'gn':
m = nn.GroupNorm(8, planes)
elif norm == 'in':
m = nn.InstanceNorm3d(planes)
else:
raise ValueError('normalization type {} is not supported'.format(norm))
return m
class InitConv(nn.Module):
def __init__(self, in_channels=4, out_channels=16, dropout=0.2):
super(InitConv, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
self.dropout = dropout
def forward(self, x):
y = self.conv(x)
y = F.dropout3d(y, self.dropout, training=self.training)  # only active in training mode
return y
class EnBlock(nn.Module):
def __init__(self, in_channels, norm='gn'):
super(EnBlock, self).__init__()
self.bn1 = normalization(in_channels, norm=norm)
self.relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
self.bn2 = normalization(in_channels, norm=norm)
self.relu2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
def forward(self, x):
x1 = self.bn1(x)
x1 = self.relu1(x1)
x1 = self.conv1(x1)
y = self.bn2(x1)
y = self.relu2(y)
y = self.conv2(y)
y = y + x
return y
class EnDown(nn.Module):
def __init__(self, in_channels, out_channels):
super(EnDown, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
def forward(self, x):
y = self.conv(x)
return y
class Unet(nn.Module):
def __init__(self, in_channels=4, base_channels=16):
super(Unet, self).__init__()
self.InitConv = InitConv(in_channels=in_channels, out_channels=base_channels, dropout=0.2)
self.EnBlock1 = EnBlock(in_channels=base_channels)
self.EnDown1 = EnDown(in_channels=base_channels, out_channels=base_channels*2)
self.EnBlock2_1 = EnBlock(in_channels=base_channels*2)
self.EnBlock2_2 = EnBlock(in_channels=base_channels*2)
self.EnDown2 = EnDown(in_channels=base_channels*2, out_channels=base_channels*4)
self.EnBlock3_1 = EnBlock(in_channels=base_channels * 4)
self.EnBlock3_2 = EnBlock(in_channels=base_channels * 4)
self.EnDown3 = EnDown(in_channels=base_channels*4, out_channels=base_channels*8)
self.EnBlock4_1 = EnBlock(in_channels=base_channels * 8)
self.EnBlock4_2 = EnBlock(in_channels=base_channels * 8)
self.EnBlock4_3 = EnBlock(in_channels=base_channels * 8)
self.EnBlock4_4 = EnBlock(in_channels=base_channels * 8)
def forward(self, x):
x = self.InitConv(x) # (1, 16, 128, 128, 128)
x1_1 = self.EnBlock1(x)
x1_2 = self.EnDown1(x1_1) # (1, 32, 64, 64, 64)
x2_1 = self.EnBlock2_1(x1_2)
x2_1 = self.EnBlock2_2(x2_1)
x2_2 = self.EnDown2(x2_1) # (1, 64, 32, 32, 32)
x3_1 = self.EnBlock3_1(x2_2)
x3_1 = self.EnBlock3_2(x3_1)
x3_2 = self.EnDown3(x3_1) # (1, 128, 16, 16, 16)
x4_1 = self.EnBlock4_1(x3_2)
x4_2 = self.EnBlock4_2(x4_1)
x4_3 = self.EnBlock4_3(x4_2)
output = self.EnBlock4_4(x4_3) # (1, 128, 16, 16, 16)
return x1_1, x2_1, x3_1, output
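# Sketch (hypothetical helpers): the Unet encoder above halves the volume three
# times (EnDown1..3, each a 3x3x3 conv with stride 2 and padding 1). With
# conv_patch_representation=True, TransformerBTS therefore turns a D x H x W input
# into (D/8) * (H/8) * (W/8) tokens. _reshape_output later views that sequence as
# img_dim / patch_dim per axis, which is consistent only when patch_dim == 8.
def _bts_grid(img_dim, n_down=3):
    grid = list(img_dim)
    for _ in range(n_down):
        grid = [(s + 2 * 1 - 3) // 2 + 1 for s in grid]
    return tuple(grid)

def _bts_num_patches(img_dim, patch_dim=8):
    return (img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim)

# e.g. _bts_grid((128, 128, 128)) -> (16, 16, 16), i.e. 4096 tokens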
class FixedPositionalEncoding(nn.Module):
def __init__(self, embedding_dim, max_length=512):
super(FixedPositionalEncoding, self).__init__()
pe = torch.zeros(max_length, embedding_dim)
position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / embedding_dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[: x.size(0), :]
return x
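# Sketch (hypothetical helper): the buffer above holds the usual sinusoidal
# encoding, pe[pos, 2i] = sin(pos / 10000^(2i/D)) and
# pe[pos, 2i+1] = cos(pos / 10000^(2i/D)). It is stored as (max_length, 1, D) and
# indexed with x.size(0), so it expects sequence-first inputs (seq, batch, D).
# TransformerBTS passes batch-first (batch, seq, D) tokens. With
# positional_encoding_type="fixed", each sample would therefore receive one
# constant vector selected by its batch index rather than per-token positions.
# One row, computed with the standard library:
import math

def _sinusoid_row(pos, dim):
    row = []
    for i in range(0, dim, 2):
        angle = pos / (10000 ** (i / dim))
        row.extend([math.sin(angle), math.cos(angle)])
    return row[:dim]

# e.g. _sinusoid_row(0, 4) -> [0.0, 1.0, 0.0, 1.0]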
class LearnedPositionalEncoding(nn.Module):
def __init__(self, max_position_embeddings, embedding_dim):
super(LearnedPositionalEncoding, self).__init__()
self.position_embeddings = nn.Parameter(torch.zeros(1, max_position_embeddings, embedding_dim)) #8x
def forward(self, x):
position_embeddings = self.position_embeddings
return x + position_embeddings
class IntermediateSequential(nn.Sequential):
def __init__(self, *args, return_intermediate=True):
super().__init__(*args)
self.return_intermediate = return_intermediate
def forward(self, input):
if not self.return_intermediate:
return super().forward(input)
intermediate_outputs = {}
output = input
for name, module in self.named_children():
output = intermediate_outputs[name] = module(output)
return output, intermediate_outputs
class SelfAttention(nn.Module):
def __init__(
self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
):
super().__init__()
self.num_heads = heads
head_dim = dim // heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(dropout_rate)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(dropout_rate)
def forward(self, x):
B, N, C = x.shape
qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4))
q, k, v = (qkv[0], qkv[1], qkv[2]) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
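# Sketch (hypothetical helper): a single-head version of what SelfAttention
# computes per head, softmax(q @ k^T * scale) @ v with scale = head_dim ** -0.5.
# It is written with plain lists and omits dropout and the qkv/proj linear layers.
import math

def _attention(q, k, v):
    scale = len(q[0]) ** -0.5
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out

# identical keys give uniform weights, so every output row is the mean of v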
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class PreNormDrop(nn.Module):
def __init__(self, dim, dropout_rate, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(p=dropout_rate)
self.fn = fn
def forward(self, x):
return self.dropout(self.fn(self.norm(x)))
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout_rate):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(p=dropout_rate),
nn.Linear(hidden_dim, dim),
nn.Dropout(p=dropout_rate),
)
def forward(self, x):
return self.net(x)
class TransformerModel(nn.Module):
def __init__(self,dim,depth,heads,mlp_dim,dropout_rate=0.1,attn_dropout_rate=0.1):
super().__init__()
layers = []
for _ in range(depth):
layers.extend([
Residual(PreNormDrop(dim,dropout_rate,SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)))])
# dim = dim / 2
self.net = IntermediateSequential(*layers)
def forward(self, x):
return self.net(x)
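# Sketch (hypothetical helper): TransformerModel stacks, per layer, an attention
# Residual followed by a feed-forward Residual. IntermediateSequential stores each
# child's output under its index '0', '1', .... Key str(2*i - 1), used by
# TransformerBTS.forward and BTS.decode below, is therefore the output of the
# i-th full transformer layer.
def _bts_intermediate_keys(depth, intmd_layers=(1, 2, 3, 4)):
    kinds = ['attention', 'feedforward'] * depth  # kind of child '0', '1', '2', ...
    return {'Z' + str(i): (str(2 * i - 1), kinds[2 * i - 1]) for i in intmd_layers}

# e.g. _bts_intermediate_keys(4)['Z4'] -> ('7', 'feedforward')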
class TransformerBTS(nn.Module):
def __init__(
self,
img_dim,
patch_dim,
num_channels,
embedding_dim,
num_heads,
num_layers,
hidden_dim,
dropout_rate=0.0,
attn_dropout_rate=0.0,
conv_patch_representation=True,
positional_encoding_type="learned",
):
super(TransformerBTS, self).__init__()
assert embedding_dim % num_heads == 0
assert img_dim[0] % patch_dim == 0
assert img_dim[1] % patch_dim == 0
assert img_dim[2] % patch_dim == 0
self.img_dim = img_dim
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.patch_dim = patch_dim
self.num_channels = num_channels
self.dropout_rate = dropout_rate
self.attn_dropout_rate = attn_dropout_rate
self.conv_patch_representation = conv_patch_representation
self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))
self.seq_length = self.num_patches
self.flatten_dim = 128 * num_channels
self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
if positional_encoding_type == "learned":
self.position_encoding = LearnedPositionalEncoding(self.seq_length, self.embedding_dim)
elif positional_encoding_type == "fixed":
self.position_encoding = FixedPositionalEncoding(self.embedding_dim)
self.pe_dropout = nn.Dropout(p=self.dropout_rate)
self.transformer = TransformerModel(embedding_dim,num_layers,num_heads,hidden_dim,self.dropout_rate,self.attn_dropout_rate)
self.pre_head_ln = nn.LayerNorm(embedding_dim)
if self.conv_patch_representation:
self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)
self.Unet = Unet(in_channels=num_channels, base_channels=16)
self.bn = nn.BatchNorm3d(128)
self.relu = nn.ReLU(inplace=True)
def encode(self, x):
if self.conv_patch_representation:
# combine embedding with conv patch distribution
x1_1, x2_1, x3_1, x = self.Unet(x)
x = self.bn(x)
x = self.relu(x)
x = self.conv_x(x)
x = x.permute(0, 2, 3, 4, 1).contiguous()
x = x.view(x.size(0), -1, self.embedding_dim)
else:
x1_1, x2_1, x3_1, x = self.Unet(x)
x = self.bn(x)
x = self.relu(x)
x = (
x.unfold(2, 2, 2)
.unfold(3, 2, 2)
.unfold(4, 2, 2)
.contiguous()
)
x = x.view(x.size(0), x.size(1), -1, 8)
x = x.permute(0, 2, 3, 1).contiguous()
x = x.view(x.size(0), -1, self.flatten_dim)
x = self.linear_encoding(x)
x = self.position_encoding(x)
x = self.pe_dropout(x)
# apply transformer
x, intmd_x = self.transformer(x)
x = self.pre_head_ln(x)
return x1_1, x2_1, x3_1, x, intmd_x
def forward(self, x, auxillary_output_layers=[1, 2, 3, 4]):
x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs = self.encode(x)
decoder_output = self.decode(x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs, auxillary_output_layers)
# auxillary_output_layers only selects the intermediate transformer outputs used by decode();
# the auxiliary outputs themselves are not returned
return decoder_output
# def _get_padding(self, padding_type, kernel_size):
# assert padding_type in ['SAME', 'VALID']
# if padding_type == 'SAME':
# _list = [(k - 1) // 2 for k in kernel_size]
# return tuple(_list)
# return tuple(0 for _ in kernel_size)
def _reshape_output(self, x):
x = x.view(
x.size(0),
int(self.img_dim[0] / self.patch_dim),
int(self.img_dim[1] / self.patch_dim),
int(self.img_dim[2] / self.patch_dim),
self.embedding_dim,
)
x = x.permute(0, 4, 1, 2, 3).contiguous()
return x
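# Sketch (hypothetical helper): encode() flattens the conv features with
# permute(0, 2, 3, 4, 1) + view, i.e. in row-major (d, h, w) order.
# _reshape_output above inverts that with view(B, D', H', W', C) + permute.
# Token n therefore sits at voxel (d, h, w) of the D' x H' x W' grid, with
# n = (d * H' + h) * W' + w.
def _token_to_voxel(n, grid):
    _, gh, gw = grid
    d, rem = divmod(n, gh * gw)
    h, w = divmod(rem, gw)
    return d, h, w

# e.g. _token_to_voxel(17, (16, 16, 16)) -> (0, 1, 1)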
class BTS(TransformerBTS):
def __init__(self,
in_channels,
num_classes,
img_shape=(128, 128, 128),
patch_dim=8,
embedding_dim=512,
num_heads=8,
num_layers=4,
hidden_dim=4096,
dropout_rate=0.1,
attn_dropout_rate=0.1,
conv_patch_representation=True,
positional_encoding_type="learned"):
super(BTS, self).__init__(
img_dim=img_shape,
patch_dim=patch_dim,
num_channels=in_channels,
embedding_dim=embedding_dim,
num_heads=num_heads,
num_layers=num_layers,
hidden_dim=hidden_dim,
dropout_rate=dropout_rate,
attn_dropout_rate=attn_dropout_rate,
conv_patch_representation=conv_patch_representation,
positional_encoding_type=positional_encoding_type,
)
self.Enblock8_1 = EnBlock1(in_channels=self.embedding_dim)
self.Enblock8_2 = EnBlock2(in_channels=self.embedding_dim // 4)
self.DeUp4 = DeUp_Cat(in_channels=self.embedding_dim//4, out_channels=self.embedding_dim//8)
self.DeBlock4 = DeBlock(in_channels=self.embedding_dim//8)
self.DeUp3 = DeUp_Cat(in_channels=self.embedding_dim//8, out_channels=self.embedding_dim//16)
self.DeBlock3 = DeBlock(in_channels=self.embedding_dim//16)
self.DeUp2 = DeUp_Cat(in_channels=self.embedding_dim//16, out_channels=self.embedding_dim//32)
self.DeBlock2 = DeBlock(in_channels=self.embedding_dim//32)
self.endconv = nn.Conv3d(self.embedding_dim // 32, num_classes, kernel_size=1)
def decode(self, x1_1, x2_1, x3_1, x, intmd_x, intmd_layers=[1, 2, 3, 4]):
assert intmd_layers is not None, "pass the intermediate layers for MLA"
encoder_outputs = {}
all_keys = []
for i in intmd_layers:
val = str(2 * i - 1)
_key = 'Z' + str(i)
all_keys.append(_key)
encoder_outputs[_key] = intmd_x[val]
all_keys.reverse()
x8 = encoder_outputs[all_keys[0]]
x8 = self._reshape_output(x8)
x8 = self.Enblock8_1(x8)
x8 = self.Enblock8_2(x8)
y4 = self.DeUp4(x8, x3_1) # (1, 64, 32, 32, 32)
y4 = self.DeBlock4(y4)
y3 = self.DeUp3(y4, x2_1) # (1, 32, 64, 64, 64)
y3 = self.DeBlock3(y3)
y2 = self.DeUp2(y3, x1_1) # (1, 16, 128, 128, 128)
y2 = self.DeBlock2(y2)
y = self.endconv(y2) # (1, num_classes, 128, 128, 128)
return y
class EnBlock1(nn.Module):
def __init__(self, in_channels):
super(EnBlock1, self).__init__()
self.bn1 = nn.BatchNorm3d(in_channels // 4)
self.relu1 = nn.ReLU(inplace=True)
self.bn2 = nn.BatchNorm3d(in_channels // 4)
self.relu2 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv3d(in_channels, in_channels // 4, kernel_size=3, padding=1)
self.conv2 = nn.Conv3d(in_channels // 4, in_channels // 4, kernel_size=3, padding=1)
def forward(self, x):
x1 = self.conv1(x)
x1 = self.bn1(x1)
x1 = self.relu1(x1)
x1 = self.conv2(x1)
x1 = self.bn2(x1)
x1 = self.relu2(x1)
return x1
class EnBlock2(nn.Module):
def __init__(self, in_channels):
super(EnBlock2, self).__init__()
self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm3d(in_channels)
self.relu1 = nn.ReLU(inplace=True)
self.bn2 = nn.BatchNorm3d(in_channels)
self.relu2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
def forward(self, x):
x1 = self.conv1(x)
x1 = self.bn1(x1)
x1 = self.relu1(x1)
x1 = self.conv2(x1)
x1 = self.bn2(x1)
x1 = self.relu2(x1)
x1 = x1 + x
return x1
class DeUp_Cat(nn.Module):
def __init__(self, in_channels, out_channels):
super(DeUp_Cat, self).__init__()
self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)
self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2)
self.conv3 = nn.Conv3d(out_channels*2, out_channels, kernel_size=1)
def forward(self, x, prev):
x1 = self.conv1(x)
y = self.conv2(x1)
# y = y + prev
y = torch.cat((prev, y), dim=1)
y = self.conv3(y)
return y
class DeBlock(nn.Module):
def __init__(self, in_channels):
super(DeBlock, self).__init__()
self.bn1 = nn.BatchNorm3d(in_channels)
self.relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm3d(in_channels)
self.relu2 = nn.ReLU(inplace=True)
def forward(self, x):
x1 = self.conv1(x)
x1 = self.bn1(x1)
x1 = self.relu1(x1)
x1 = self.conv2(x1)
x1 = self.bn2(x1)
x1 = self.relu2(x1)
x1 = x1 + x
return x1
def transbts(in_channels, num_classes, **kwargs):
model = BTS(in_channels, num_classes, img_shape=kwargs.get('img_shape', (128, 128, 128)))
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
#
# criterion = segmentation_loss('dice', False)
# mask = torch.ones(2, 64, 96, 64).long()
# model = transbts(1, 10, img_shape=(64, 96, 64))
# model.train()
# input = torch.rand(2, 1, 64, 96, 64)
# output = model(input)
# loss_train = criterion(output, mask)
# loss_train.backward()
# output = output.data.cpu().numpy()
# print(output.shape)
# print(loss_train)
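TransformerBTS moves between a 5D feature map and a token sequence twice: `encode` permutes `(B, E, D', H', W')` to `(B, D', H', W', E)` and flattens it to `(B, N, E)`, and `_reshape_output` undoes this using `img_dim / patch_dim` per axis. The sketch below isolates that round trip as two standalone functions so the shape bookkeeping can be checked without building the model. `volume_to_tokens` and `tokens_to_volume` are hypothetical helper names, not part of the repository, and the sketch assumes every spatial size is divisible by `patch_dim`, as the `BTS` defaults (128³ input, `patch_dim=8`) are.

```python
import torch


def volume_to_tokens(x):
    """(B, E, D', H', W') feature map -> (B, N, E) token sequence, N = D' * H' * W'.

    Mirrors the permute/view pair in TransformerBTS.encode (conv branch).
    """
    b, e = x.shape[:2]
    return x.permute(0, 2, 3, 4, 1).contiguous().view(b, -1, e)


def tokens_to_volume(tokens, img_dim, patch_dim):
    """(B, N, E) token sequence -> (B, E, D', H', W') feature map.

    Mirrors TransformerBTS._reshape_output; assumes each entry of img_dim
    is divisible by patch_dim.
    """
    b, n, e = tokens.shape
    d, h, w = (s // patch_dim for s in img_dim)
    if n != d * h * w:
        raise ValueError(f"expected {d * h * w} tokens, got {n}")
    x = tokens.view(b, d, h, w, e)
    return x.permute(0, 4, 1, 2, 3).contiguous()


# A 64 x 96 x 64 input with patch_dim=8 gives an 8 x 12 x 8 grid of tokens.
feat = torch.randn(2, 512, 8, 12, 8)
tokens = volume_to_tokens(feat)
print(tokens.shape)  # torch.Size([2, 768, 512])
restored = tokens_to_volume(tokens, (64, 96, 64), patch_dim=8)
print(torch.equal(restored, feat))  # True
```

Token `i` corresponds to voxel `(d, h, w)` in row-major order with `w` varying fastest, which is why the permute must put the embedding axis last before flattening.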
================================================
FILE: models/networks_3d/unet3d.py
================================================
import numpy as np
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
class UNet3D(nn.Module):
def __init__(self, in_channels=1, out_channels=3, init_features=64):
"""
Implementation based on the 3D U-Net paper: https://arxiv.org/abs/1606.06650
"""
super(UNet3D, self).__init__()
features = init_features
self.encoder1 = UNet3D._block(in_channels, features, name="enc1")
self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder2 = UNet3D._block(features, features * 2, name="enc2")
self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder3 = UNet3D._block(features * 2, features * 4, name="enc3")
self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder4 = UNet3D._block(features * 4, features * 8, name="enc4")
self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
self.bottleneck = UNet3D._block(features * 8, features * 16, name="bottleneck")
self.upconv4 = nn.ConvTranspose3d(
features * 16, features * 8, kernel_size=2, stride=2
)
self.decoder4 = UNet3D._block((features * 8) * 2 , features * 8, name="dec4")
self.upconv3 = nn.ConvTranspose3d(
features * 8, features * 4, kernel_size=2, stride=2
)
self.decoder3 = UNet3D._block((features * 4) * 2, features * 4, name="dec3")
self.upconv2 = nn.ConvTranspose3d(
features * 4, features * 2, kernel_size=2, stride=2
)
self.decoder2 = UNet3D._block((features * 2) * 2, features * 2, name="dec2")
self.upconv1 = nn.ConvTranspose3d(
features * 2, features, kernel_size=2, stride=2
)
self.decoder1 = UNet3D._block(features * 2, features, name="dec1")
self.conv = nn.Conv3d(
in_channels=features, out_channels=out_channels, kernel_size=1
)
def forward(self, x):
enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool1(enc1))
enc3 = self.encoder3(self.pool2(enc2))
enc4 = self.encoder4(self.pool3(enc3))
bottleneck = self.bottleneck(self.pool4(enc4))
dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, enc4), dim=1)
dec4 = self.decoder4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, enc3), dim=1)
dec3 = self.decoder3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.decoder2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.decoder1(dec1)
outputs = self.conv(dec1)
return outputs
@staticmethod
def _block(in_channels, features, name):
return nn.Sequential(
OrderedDict(
[
(
name + "conv1",
nn.Conv3d(
in_channels=in_channels,
out_channels=features,
kernel_size=3,
padding=1,
bias=True,
),
),
(name + "norm1", nn.BatchNorm3d(num_features=features)),
(name + "relu1", nn.ReLU(inplace=True)),
(
name + "conv2",
nn.Conv3d(
in_channels=features,
out_channels=features,
kernel_size=3,
padding=1,
bias=True,
),
),
(name + "norm2", nn.BatchNorm3d(num_features=features)),
(name + "relu2", nn.ReLU(inplace=True)),
]
)
)
class UNet3D_min(nn.Module):
def __init__(self, in_channels=1, out_channels=3, init_features=32):
"""
Implementation based on the 3D U-Net paper: https://arxiv.org/abs/1606.06650
"""
super(UNet3D_min, self).__init__()
features = init_features
self.encoder1 = UNet3D._block(in_channels, features, name="enc1")
self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder2 = UNet3D._block(features, features * 2, name="enc2")
self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder3 = UNet3D._block(features * 2, features * 4, name="enc3")
self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder4 = UNet3D._block(features * 4, features * 8, name="enc4")
self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
self.bottleneck = UNet3D._block(features * 8, features * 16, name="bottleneck")
self.upconv4 = nn.ConvTranspose3d(
features * 16, features * 8, kernel_size=2, stride=2
)
self.decoder4 = UNet3D._block((features * 8) * 2 , features * 8, name="dec4")
self.upconv3 = nn.ConvTranspose3d(
features * 8, features * 4, kernel_size=2, stride=2
)
self.decoder3 = UNet3D._block((features * 4) * 2, features * 4, name="dec3")
self.upconv2 = nn.ConvTranspose3d(
features * 4, features * 2, kernel_size=2, stride=2
)
self.decoder2 = UNet3D._block((features * 2) * 2, features * 2, name="dec2")
self.upconv1 = nn.ConvTranspose3d(
features * 2, features, kernel_size=2, stride=2
)
self.decoder1 = UNet3D._block(features * 2, features, name="dec1")
self.conv = nn.Conv3d(
in_channels=features, out_channels=out_channels, kernel_size=1
)
def forward(self, x):
enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool1(enc1))
enc3 = self.encoder3(self.pool2(enc2))
enc4 = self.encoder4(self.pool3(enc3))
bottleneck = self.bottleneck(self.pool4(enc4))
dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, enc4), dim=1)
dec4 = self.decoder4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, enc3), dim=1)
dec3 = self.decoder3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.decoder2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.decoder1(dec1)
outputs = self.conv(dec1)
return outputs
@staticmethod
def _block(in_channels, features, name):
return nn.Sequential(
OrderedDict(
[
(
name + "conv1",
nn.Conv3d(
in_channels=in_channels,
out_channels=features,
kernel_size=3,
padding=1,
bias=True,
),
),
(name + "norm1", nn.BatchNorm3d(num_features=features)),
(name + "relu1", nn.ReLU(inplace=True)),
(
name + "conv2",
nn.Conv3d(
in_channels=features,
out_channels=features,
kernel_size=3,
padding=1,
bias=True,
),
),
(name + "norm2", nn.BatchNorm3d(num_features=features)),
(name + "relu2", nn.ReLU(inplace=True)),
]
)
)
def unet3d(in_channels, num_classes):
model = UNet3D(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
def unet3d_min(in_channels, num_classes):
model = UNet3D_min(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = unet3d(1,10)
# model.eval()
# input = torch.rand(2, 1, 128, 128, 128)
# output = model(input)
# output = output.data.cpu().numpy()
# # print(output)
# print(output.shape)
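`UNet3D` (and `UNet3D_min`) halves each spatial dimension four times with `MaxPool3d(2)`, doubles it back with stride-2 `ConvTranspose3d`, and concatenates each upsampled map with the matching encoder feature. `torch.cat` therefore only succeeds when D, H and W are all divisible by 2^4 = 16: a size of 70 is floored to 35, 17, 8, 4 on the way down, and the decoder's 16-voxel map no longer matches the 17-voxel `enc3` skip. A minimal sketch of a workaround, assuming zero padding at the far end of each axis is acceptable for the data; `pad_to_multiple` and `crop_to` are hypothetical helpers, not repository functions.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x, multiple=16):
    """Zero-pad the trailing D, H, W dims of a (B, C, D, H, W) tensor up to a multiple.

    Padding is added only at the far end of each axis. Returns the padded tensor and
    the original (D, H, W) so the prediction can be cropped back afterwards.
    """
    d, h, w = x.shape[-3:]
    pad_d, pad_h, pad_w = ((-s) % multiple for s in (d, h, w))
    # F.pad lists pads from the last dim backwards: (W_lo, W_hi, H_lo, H_hi, D_lo, D_hi)
    padded = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))
    return padded, (d, h, w)


def crop_to(x, size):
    """Undo pad_to_multiple by slicing the trailing three dims back to `size`."""
    d, h, w = size
    return x[..., :d, :h, :w]


volume = torch.rand(1, 1, 70, 96, 33)
padded, size = pad_to_multiple(volume)
print(padded.shape)  # torch.Size([1, 1, 80, 96, 48])
# logits = unet3d(1, num_classes)(padded)  # the model now sees 16-divisible sizes
restored = crop_to(padded, size)
print(torch.equal(restored, volume))  # True
```

Cropping the logits with the same `size` returns a prediction aligned with the unpadded input; sliding-window inference with 16-divisible patch sizes is the other common way to avoid the mismatch.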
================================================
FILE: models/networks_3d/unet3d_cct.py
================================================
import numpy as np
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn import init
from torch.distributions.uniform import Uniform
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
class FeatureNoise(nn.Module):
def __init__(self, uniform_range=0.3):
super(FeatureNoise, self).__init__()
self.uni_dist = Uniform(-uniform_range, uniform_range)
def feature_based_noise(self, x):
noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0)
x_noise = x.mul(noise_vector) + x
return x_noise
def forward(self, x):
x = self.feature_based_noise(x)
return x
def Dropout(x, p=0.3):
x = torch.nn.functional.dropout(x, p)
return x
def FeatureDropout(x):
attention = torch.mean(x, dim=1, keepdim=True)
max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True)
threshold = max_val * np.random.uniform(0.7, 0.9)
threshold = threshold.view(x.size(0), 1, 1, 1, 1).expand_as(attention)
drop_mask = (attention < threshold).float()
x = x.mul(drop_mask)
return x
class Decoder(nn.Module):
def __init__(self, features, out_channels):
super(Decoder, self).__init__()
self.upconv4 = nn.ConvTranspose3d(features * 16, features * 8, kernel_size=2, stride=2)
self.decoder4 = Decoder._block((features * 8) * 2, features * 8, name="dec4")
self.upconv3 = nn.ConvTranspose3d(features * 8, features * 4, kernel_size=2, stride=2)
self.decoder3 = Decoder._block((features * 4) * 2, features * 4, name="dec3")
self.upconv2 = nn.ConvTranspose3d(features * 4, features * 2, kernel_size=2, stride=2)
self.decoder2 = Decoder._block((features * 2) * 2, features * 2, name="dec2")
self.upconv1 = nn.ConvTranspose3d(features * 2, features, kernel_size=2, stride=2)
self.decoder1 = Decoder._block(features * 2, features, name="dec1")
self.conv = nn.Conv3d(in_channels=features, out_channels=out_channels, kernel_size=1)
def forward(self, x5, x4, x3, x2, x1):
dec4 = self.upconv4(x5)
dec4 = torch.cat((dec4, x4), dim=1)
dec4 = self.decoder4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, x3), dim=1)
dec3 = self.decoder3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, x2), dim=1)
dec2 = self.decoder2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, x1), dim=1)
dec1 = self.decoder1(dec1)
outputs = self.conv(dec1)
return outputs
@staticmethod
def _block(in_channels, features, name):
return nn.Sequential(
OrderedDict(
[
(
name + "conv1",
nn.Conv3d(
in_channels=in_channels,
out_channels=features,
kernel_size=3,
padding=1,
bias=True,
),
),
(name + "norm1", nn.BatchNorm3d(num_features=features)),
(name + "relu1", nn.ReLU(inplace=True)),
(
name + "conv2",
nn.Conv3d(
in_channels=features,
out_channels=features,
kernel_size=3,
padding=1,
bias=True,
),
),
(name + "norm2", nn.BatchNorm3d(num_features=features)),
(name + "relu2", nn.ReLU(inplace=True)),
]
)
)
class UNet3D_CCT(nn.Module):
def __init__(self, in_channels=1, out_channels=3, init_features=64):
"""
Implementation based on the 3D U-Net paper: https://arxiv.org/abs/1606.06650
"""
super(UNet3D_CCT, self).__init__()
features = init_features
self.encoder1 = UNet3D_CCT._block(in_channels, features, name="enc1")
self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder2 = UNet3D_CCT._block(features, features * 2, name="enc2")
self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder3 = UNet3D_CCT._block(features * 2, features * 4, name="enc3")
self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder4 = UNet3D_CCT._block(features * 4, features * 8, name="enc4")
self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
self.bottleneck = UNet3D_CCT._block(features * 8, features * 16, name="bottleneck")
self.main_decoder = Decoder(features, out_channels)
self.aux_decoder1 = Decoder(features, out_channels)
self.aux_decoder2 = Decoder(features, out_channels)
self.aux_decoder3 = Decoder(features, out_channels)
def forward(self, x):
enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool1(enc1))
enc3 = self.encoder3(self.pool2(enc2))
enc4 = self.encoder4(self.pool3(enc3))
bottleneck = self.bottleneck(self.pool4(enc4))
main_seg = self.main_decoder(bottleneck, enc4, enc3, enc2, enc1)
aux_seg1 = self.aux_decoder1(FeatureNoise()(bottleneck), FeatureNoise()(enc4), FeatureNoise()(enc3), FeatureNoise()(enc2), FeatureNoise()(enc1))
aux_seg2 = self.aux_decoder2(Dropout(bottleneck), Dropout(enc4), Dropout(enc3), Dropout(enc2), Dropout(enc1))
aux_seg3 = self.aux_decoder3(FeatureDropout(bottleneck), FeatureDropout(enc4), FeatureDropout(enc3), FeatureDropout(enc2), FeatureDropout(enc1))
return main_seg, aux_seg1, aux_seg2, aux_seg3
@staticmethod
def _block(in_channels, features, name):
return nn.Sequential(
OrderedDict(
[
(
name + "conv1",
nn.Conv3d(
in_channels=in_channels,
out_channels=features,
kernel_size=3,
padding=1,
bias=True,
),
),
(name + "norm1", nn.BatchNorm3d(num_features=features)),
(name + "relu1", nn.ReLU(inplace=True)),
(
name + "conv2",
nn.Conv3d(
in_channels=features,
out_channels=features,
kernel_size=3,
padding=1,
bias=True,
),
),
(name + "norm2", nn.BatchNorm3d(num_features=features)),
(name + "relu2", nn.ReLU(inplace=True)),
]
)
)
class UNet3D_CCT_min(nn.Module):
def __init__(self, in_channels=1, out_channels=3, init_features=32):
"""
Implementation based on the 3D U-Net paper: https://arxiv.org/abs/1606.06650
"""
super(UNet3D_CCT_min, self).__init__()
features = init_features
self.encoder1 = UNet3D_CCT._block(in_channels, features, name="enc1")
self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder2 = UNet3D_CCT._block(features, features * 2, name="enc2")
self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder3 = UNet3D_CCT._block(features * 2, features * 4, name="enc3")
self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder4 = UNet3D_CCT._block(features * 4, features * 8, name="enc4")
self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
self.bottleneck = UNet3D_CCT._block(features * 8, features * 16, name="bottleneck")
self.main_decoder = Decoder(features, out_channels)
self.aux_decoder1 = Decoder(features, out_channels)
self.aux_decoder2 = Decoder(features, out_channels)
self.aux_decoder3 = Decoder(features, out_channels)
def forward(self, x):
enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool1(enc1))
enc3 = self.encoder3(self.pool2(enc2))
enc4 = self.encoder4(self.pool3(enc3))
bottleneck = self.bottleneck(self.pool4(enc4))
main_seg = self.main_decoder(bottleneck, enc4, enc3, enc2, enc1)
aux_seg1 = self.aux_decoder1(FeatureNoise()(bottleneck), FeatureNoise()(enc4), FeatureNoise()(enc3), FeatureNoise()(enc2), FeatureNoise()(enc1))
aux_seg2 = self.aux_decoder2(Dropout(bottleneck), Dropout(enc4), Dropout(enc3), Dropout(enc2), Dropout(enc1))
aux_seg3 = self.aux_decoder3(FeatureDropout(bottleneck), FeatureDropout(enc4), FeatureDropout(enc3), FeatureDropout(enc2), FeatureDropout(enc1))
return main_seg, aux_seg1, aux_seg2, aux_seg3
@staticmethod
def _block(in_channels, features, name):
return nn.Sequential(
OrderedDict(
[
(
name + "conv1",
nn.Conv3d(
in_channels=in_channels,
out_channels=features,
kernel_size=3,
padding=1,
bias=True,
),
),
(name + "norm1", nn.BatchNorm3d(num_features=features)),
(name + "relu1", nn.ReLU(inplace=True)),
(
name + "conv2",
nn.Conv3d(
in_channels=features,
out_channels=features,
kernel_size=3,
padding=1,
bias=True,
),
),
(name + "norm2", nn.BatchNorm3d(num_features=features)),
(name + "relu2", nn.ReLU(inplace=True)),
]
)
)
def unet3d_cct(in_channels, num_classes):
model = UNet3D_CCT(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
def unet3d_cct_min(in_channels, num_classes):
model = UNet3D_CCT_min(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = unet3d_cct(1,10)
# model.eval()
# input = torch.rand(2, 1, 128, 128, 128)
# output, aux_output1, aux_output2, aux_output3 = model(input)
# output = output.data.cpu().numpy()
# # print(output)
# print(output.shape)
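`UNet3D_CCT` returns the main prediction plus three predictions decoded from perturbed copies of the shared encoder features (uniform multiplicative noise, dropout, and attention-guided `FeatureDropout`). The training loop is not part of this file; in cross-consistency training (CCT) the auxiliary outputs are usually pulled toward the main output on unlabeled data. The sketch below assumes the MSE-on-softmax variant with the main prediction detached as the target; `cct_consistency_loss` is a hypothetical name, not a repository function.

```python
import torch
import torch.nn.functional as F


def cct_consistency_loss(main_logits, aux_logits_list):
    """MSE between each auxiliary decoder's softmax and the main decoder's softmax.

    The main prediction is detached, so gradients only flow through the auxiliary
    branches (and, in the full model, into the shared encoder through them).
    """
    target = torch.softmax(main_logits, dim=1).detach()
    losses = [F.mse_loss(torch.softmax(a, dim=1), target) for a in aux_logits_list]
    return torch.stack(losses).mean()


# Stand-ins for (main_seg, aux_seg1, aux_seg2, aux_seg3) on an unlabeled batch.
main = torch.randn(2, 3, 8, 8, 8, requires_grad=True)
auxes = [torch.randn(2, 3, 8, 8, 8, requires_grad=True) for _ in range(3)]
loss = cct_consistency_loss(main, auxes)
loss.backward()
print(main.grad is None, all(a.grad is not None for a in auxes))  # True True
```

Detaching the target is a design choice: it keeps the main decoder driven only by the supervised loss, while the auxiliary branches push the encoder toward features that stay stable under perturbation.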
================================================
FILE: models/networks_3d/unet3d_dtc.py
================================================
import numpy as np
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn import init
# from loss.loss_function import segmentation_loss
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
class UNet3D_DTC(nn.Module):
def __init__(self, in_channels=1, out_channels=3, init_features=64):
"""
Implementation based on the 3D U-Net paper: https://arxiv.org/abs/1606.06650
"""
super(UNet3D_DTC, self).__init__()
features = init_features
self.encoder1 = UNet3D_DTC._block(in_channels, features, name="enc1")
self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder2 = UNet3D_DTC._block(features, features * 2, name="enc2")
self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder3 = UNet3D_DTC._block(features * 2, features * 4, name="enc3")
self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
self.encoder4 = UNet3D_DTC._block(features * 4, features * 8, name="enc4")
self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
self.bottleneck = UNet3D_DTC._block(features * 8, features * 16, name="bottleneck")
self.upconv4 = nn.ConvTranspose3d(features * 16, features * 8, kernel_size=2, stride=2)
self.decoder4 = UNet3D_DTC._block((features * 8) * 2 , features * 8, name="dec4")
self.upconv3 = nn.ConvTranspose3d(features * 8, features * 4, kernel_size=2, stride=2)
self.decoder3 = UNet3D_DTC._block((features * 4) * 2, features * 4, name="dec3")
self.upconv2 = nn.ConvTranspose3d(features * 4, features * 2, kernel_size=2, stride=2)
self.decoder2 = UNet3D_DTC._block((features * 2) * 2, features * 2, name="dec2")
self.upconv1 = nn.ConvTranspose3d(features * 2, features, kernel_size=2, stride=2)
self.decoder1 = UNet3D_DTC._block(features * 2, features, name="dec1")
self.out_sdf = nn.Sequential(
nn.Conv3d(in_channels=features, out_channels=out_channels, kernel_size=1),
nn.Tanh()
)
self.out_seg = nn.Conv3d(in_channels=features, out_channels=out_channels, kernel_size=1)
def forward(self, x):
enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool1(enc1))
enc3 = self.encoder3(self.pool2(enc2))
enc4 = self.encoder4(self.pool3(enc3))
bottleneck = self.bottleneck(self.pool4(enc4))
dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, enc4), dim=1)
dec4 = self.decoder4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, enc3), dim=1)
dec3 = self.decoder3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.decoder2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.decoder1(dec1)
out_sdf = self.out_sdf(dec1)
out_seg = self.out_seg(dec1)
return out_sdf, out_seg
@staticmethod
def _block(in_channels, features, name):
return nn.Sequential(
OrderedDict(
[
(
name + "conv1",
nn.Conv3d(
in_channels=in_channels,
out_channels=features,
kernel_size=3,
padding=1,
bias=True,
),
),
(name + "norm1", nn.BatchNorm3d(num_features=features)),
(name + "relu1", nn.ReLU(inplace=True)),
(
name + "conv2",
nn.Conv3d(
in_channels=features,
out_channels=features,
kernel_size=3,
padding=1,
bias=True,
),
),
(name + "norm2", nn.BatchNorm3d(num_features=features)),
(name + "relu2", nn.ReLU(inplace=True)),
]
)
)
def unet3d_dtc(in_channels, num_classes):
model = UNet3D_DTC(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
#
# criterion = segmentation_loss('dice', False)
# mask = torch.ones(2, 96, 96, 96).long()
# model = unet3d_dtc(1, 10)
# model.train()
# input1 = torch.rand(2,1,96,96,96)
# out_sdf, out_seg = model(input1)
# loss_train = criterion(out_sdf, mask)
# loss_train.backward()
# # print(output)
# print(out_sdf.data.cpu().numpy().shape)
# print(out_seg.data.cpu().numpy().shape)
# print(loss_train)
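`UNet3D_DTC` puts two heads on the same decoder: `out_sdf` regresses a signed distance map bounded to [-1, 1] by `Tanh`, and `out_seg` produces segmentation logits. Dual-task consistency ties them together by converting the SDF back into a soft mask and penalizing disagreement with the segmentation head. This is a sketch written for the binary (sigmoid) case, assuming the DTC sign convention (SDF negative inside the object) and a steep transform `sigmoid(-k * sdf)` with `k = 1500`; check both against `tools/mask2sdf.py` and the training script before relying on them. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def sdf_to_prob(sdf, k=1500.0):
    """Map a Tanh-bounded signed distance map to a soft foreground probability.

    Assumes the SDF is negative inside the object and positive outside, so a
    steep sigmoid of -k * sdf is close to 1 inside and close to 0 outside.
    """
    return torch.sigmoid(-k * sdf)


def dual_task_consistency(out_sdf, out_seg, k=1500.0):
    """MSE between the mask implied by the SDF head and the sigmoid of the seg head."""
    return F.mse_loss(sdf_to_prob(out_sdf, k), torch.sigmoid(out_seg))


# One voxel inside the object (sdf < 0) and one outside (sdf > 0).
sdf = torch.tensor([-0.2, 0.2]).view(1, 1, 2, 1, 1)
seg_logits = torch.tensor([8.0, -8.0]).view(1, 1, 2, 1, 1)
print(sdf_to_prob(sdf).flatten())
print(dual_task_consistency(sdf, seg_logits).item())  # close to 0: both heads agree
```

Because the consistency term needs no labels, it can be applied to unlabeled volumes, while the labeled ones additionally supervise `out_sdf` against the ground-truth SDF and `out_seg` against the mask.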
================================================
FILE: models/networks_3d/unet3d_urpc.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
class UnetConv3(nn.Module):
def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)):
super(UnetConv3, self).__init__()
if is_batchnorm:  # note: despite the flag name, InstanceNorm3d is used below
self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size),
nn.InstanceNorm3d(out_size),
nn.ReLU(inplace=True),)
self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),
nn.InstanceNorm3d(out_size),
nn.ReLU(inplace=True),)
else:
self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size),
nn.ReLU(inplace=True),)
self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),
nn.ReLU(inplace=True),)
# initialise the blocks
# for m in self.children():
# init_weights(m, init_type='kaiming')
def forward(self, inputs):
outputs = self.conv1(inputs)
outputs = self.conv2(outputs)
return outputs
class UnetUp3(nn.Module):
def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True):
super(UnetUp3, self).__init__()
if is_deconv:
self.conv = UnetConv3(in_size, out_size, is_batchnorm)
self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0))
else:
self.conv = UnetConv3(in_size+out_size, out_size, is_batchnorm)
self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear', align_corners=True)
# initialise the blocks
# for m in self.children():
# if m.__class__.__name__.find('UnetConv3') != -1: continue
# init_weights(m, init_type='kaiming')
def forward(self, inputs1, inputs2):
outputs2 = self.up(inputs2)
offset = outputs2.size()[2] - inputs1.size()[2]
padding = 2 * [offset // 2, offset // 2, 0]
outputs1 = F.pad(inputs1, padding)
return self.conv(torch.cat([outputs1, outputs2], 1))
class UnetUp3_CT(nn.Module):
def __init__(self, in_size, out_size, is_batchnorm=True):
super(UnetUp3_CT, self).__init__()
self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear', align_corners=True)
# initialise the blocks
# for m in self.children():
# if m.__class__.__name__.find('UnetConv3') != -1: continue
# init_weights(m, init_type='kaiming')
def forward(self, inputs1, inputs2):
outputs2 = self.up(inputs2)
offset = outputs2.size()[2] - inputs1.size()[2]
padding = 2 * [offset // 2, offset // 2, 0]
outputs1 = F.pad(inputs1, padding)
return self.conv(torch.cat([outputs1, outputs2], 1))
class UnetDsv3(nn.Module):
def __init__(self, in_size, out_size, scale_factor):
super(UnetDsv3, self).__init__()
self.dsv = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size=1, stride=1, padding=0),
nn.Upsample(scale_factor=scale_factor, mode='trilinear', align_corners=True), )
def forward(self, input):
return self.dsv(input)
class unet_3D_dv_semi(nn.Module):
def __init__(self, in_channels=3, n_classes=21, feature_scale=4, is_deconv=True, is_batchnorm=True):
super(unet_3D_dv_semi, self).__init__()
self.is_deconv = is_deconv
self.in_channels = in_channels
self.is_batchnorm = is_batchnorm
self.feature_scale = feature_scale
filters = [64, 128, 256, 512, 1024]
filters = [int(x / self.feature_scale) for x in filters]
# downsampling
self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(
3, 3, 3), padding_size=(1, 1, 1))
self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))
self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(
3, 3, 3), padding_size=(1, 1, 1))
self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))
self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(
3, 3, 3), padding_size=(1, 1, 1))
self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))
self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(
3, 3, 3), padding_size=(1, 1, 1))
self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))
self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(
3, 3, 3), padding_size=(1, 1, 1))
# upsampling
self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)
# deep supervision
self.dsv4 = UnetDsv3(
in_size=filters[3], out_size=n_classes, scale_factor=8)
self.dsv3 = UnetDsv3(
in_size=filters[2], out_size=n_classes, scale_factor=4)
self.dsv2 = UnetDsv3(
in_size=filters[1], out_size=n_classes, scale_factor=2)
self.dsv1 = nn.Conv3d(
in_channels=filters[0], out_channels=n_classes, kernel_size=1)
self.dropout1 = nn.Dropout3d(p=0.5)
self.dropout2 = nn.Dropout3d(p=0.3)
self.dropout3 = nn.Dropout3d(p=0.2)
self.dropout4 = nn.Dropout3d(p=0.1)
# initialise weights
# for m in self.modules():
# if isinstance(m, nn.Conv3d):
# init_weights(m, init_type='kaiming')
# elif isinstance(m, nn.BatchNorm3d):
# init_weights(m, init_type='kaiming')
def forward(self, inputs):
conv1 = self.conv1(inputs)
maxpool1 = self.maxpool1(conv1)
conv2 = self.conv2(maxpool1)
maxpool2 = self.maxpool2(conv2)
conv3 = self.conv3(maxpool2)
maxpool3 = self.maxpool3(conv3)
conv4 = self.conv4(maxpool3)
maxpool4 = self.maxpool4(conv4)
center = self.center(maxpool4)
up4 = self.up_concat4(conv4, center)
up4 = self.dropout1(up4)
up3 = self.up_concat3(conv3, up4)
up3 = self.dropout2(up3)
up2 = self.up_concat2(conv2, up3)
up2 = self.dropout3(up2)
up1 = self.up_concat1(conv1, up2)
up1 = self.dropout4(up1)
# Deep Supervision
dsv4 = self.dsv4(up4)
dsv3 = self.dsv3(up3)
dsv2 = self.dsv2(up2)
dsv1 = self.dsv1(up1)
return dsv1, dsv2, dsv3, dsv4
@staticmethod
def apply_argmax_softmax(pred):
# despite the name, this returns softmax probabilities (no argmax, no log)
probs = F.softmax(pred, dim=1)
return probs
def unet3d_urpc(in_channels, num_classes):
model = unet_3D_dv_semi(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = unet3d_urpc(1,10)
# model.eval()
# input = torch.rand(2, 1, 128, 128, 128)
# output, output2, output3, output4 = model(input)
# output = output.data.cpu().numpy()
# # print(output)
# print(output.shape)
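# Resolution bookkeeping for the deep-supervision heads of unet_3D_dv_semi
# (illustrative sketch, not part of the original repo). Assuming each of the
# four max-pools halves every spatial axis (only maxpool4 is visible here) and
# each UnetUp3_CT doubles it, `center` sits at 1/16 scale and up4, up3, up2, up1
# sit at 1/8, 1/4, 1/2 and 1/1 scale. That matches scale_factor=8/4/2 for
# dsv4/dsv3/dsv2 and no upsampling for dsv1. `_demo_dsv_scales` is a
# hypothetical helper that only reproduces this arithmetic; it assumes axes
# divisible by 16, as in the commented 128^3 example above.
def _demo_dsv_scales(spatial):
    for n in spatial:
        if n % 16 != 0:
            raise ValueError(f"this sketch assumes axes divisible by 16, got {n}")
    levels = {}
    for name, factor in (("up4", 8), ("up3", 4), ("up2", 2), ("up1", 1)):
        # (decoder feature size, upsampling factor its dsv head must apply)
        levels[name] = (tuple(n // factor for n in spatial), factor)
    return levels

# e.g. _demo_dsv_scales((128, 128, 128))["up4"] -> ((16, 16, 16), 8)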
================================================
FILE: models/networks_3d/unetr.py
================================================
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm') != -1:  # match BatchNorm3d too; 'BatchNorm2d' never matches in this 3D model
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
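# Worked number for init_weights(model, 'kaiming') (illustrative sketch, not
# part of the original repo). With a=0 and mode='fan_in', init.kaiming_normal_
# samples weights from N(0, std^2) with std = sqrt(2) / sqrt(fan_in), where
# fan_in = in_channels * prod(kernel_size) for a Conv3d weight (in_features for
# nn.Linear). The `gain` argument of init_weights is not used on this path.
# `_demo_kaiming_std` is a hypothetical helper that reproduces the formula.
import math


def _demo_kaiming_std(in_channels, kernel_size):
    fan_in = in_channels * math.prod(kernel_size)
    gain = math.sqrt(2.0)  # leaky_relu gain with negative slope a=0, i.e. ReLU
    return gain / math.sqrt(fan_in)

# e.g. Conv3d(16, 32, kernel_size=5): fan_in = 16 * 125 = 2000, std ~= 0.0316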
class SingleDeconv3DBlock(nn.Module):
def __init__(self, in_planes, out_planes):
super().__init__()
self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)
def forward(self, x):
return self.block(x)
class SingleConv3DBlock(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size):
super().__init__()
self.block = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
padding=((kernel_size - 1) // 2))
def forward(self, x):
return self.block(x)
class Conv3DBlock(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size=3):
super().__init__()
self.block = nn.Sequential(
SingleConv3DBlock(in_planes, out_planes, kernel_size),
nn.BatchNorm3d(out_planes),
nn.ReLU(True)
)
def forward(self, x):
return self.block(x)
class Deconv3DBlock(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size=3):
super().__init__()
self.block = nn.Sequential(
SingleDeconv3DBlock(in_planes, out_planes),
SingleConv3DBlock(out_planes, out_planes, kernel_size),
nn.BatchNorm3d(out_planes),
nn.ReLU(True)
)
def forward(self, x):
return self.block(x)
class SelfAttention(nn.Module):
def __init__(self, num_heads, embed_dim, dropout):
super().__init__()
self.num_attention_heads = num_heads
self.attention_head_size = int(embed_dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(embed_dim, self.all_head_size)
self.key = nn.Linear(embed_dim, self.all_head_size)
self.value = nn.Linear(embed_dim, self.all_head_size)
self.out = nn.Linear(embed_dim, embed_dim)
self.attn_dropout = nn.Dropout(dropout)
self.proj_dropout = nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1)
self.vis = False
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
# weights = attention_probs if self.vis else None
attention_probs = self.attn_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
attention_output = self.out(context_layer)
attention_output = self.proj_dropout(attention_output)
# return attention_output, weights
return attention_output
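# A NumPy mirror of SelfAttention.forward above (illustrative sketch, not part
# of the original repo). Dropout, biases and the output projection are omitted,
# and the query/key/value weights are plain (D, D) matrices applied as x @ w.
# It shows what transpose_for_scores does, (B, N, D) -> (B, heads, N, D/heads),
# then softmax(Q K^T / sqrt(d_head)) V per head and the inverse permute back
# to (B, N, D). `_demo_multi_head_attention` and its arguments are hypothetical.
import numpy as np


def _demo_multi_head_attention(x, wq, wk, wv, num_heads):
    b, n, d = x.shape
    d_head = d // num_heads

    def split_heads(t):
        # (B, N, D) -> (B, N, heads, d_head) -> (B, heads, N, d_head)
        return t.reshape(b, n, num_heads, d_head).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)    # (B, heads, N, N)
    scores = scores - scores.max(axis=-1, keepdims=True)      # numerical stability
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)         # softmax over keys
    context = probs @ v                                        # (B, heads, N, d_head)
    context = context.transpose(0, 2, 1, 3).reshape(b, n, d)  # merge heads back
    return context, probs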
# class Mlp(nn.Module):
# def __init__(self, in_features, act_layer=nn.GELU, drop=0.):
# super().__init__()
# self.fc1 = nn.Linear(in_features, in_features)
# self.act = act_layer()
# self.drop = nn.Dropout(drop)
#
# def forward(self, x):
# x = self.fc1(x)
# x = self.act(x)
# x = self.drop(x)
# return x
class PositionwiseFeedForward(nn.Module):
def __init__(self, d_model=768, d_ff=2048, dropout=0.1):
super().__init__()
# nn.Linear layers include a bias term by default.
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class Embeddings(nn.Module):
def __init__(self, input_dim, embed_dim, cube_size, patch_size, dropout):
super().__init__()
self.n_patches = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size))
self.patch_size = patch_size
self.embed_dim = embed_dim
self.patch_embeddings = nn.Conv3d(in_channels=input_dim, out_channels=embed_dim,
kernel_size=patch_size, stride=patch_size)
self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, embed_dim))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.patch_embeddings(x)
x = x.flatten(2)
x = x.transpose(-1, -2)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
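# Token bookkeeping for Embeddings (illustrative sketch, not part of the
# original repo). A Conv3d with kernel_size = stride = patch_size cuts the
# volume into non-overlapping patch_size^3 cubes, and flatten(2).transpose(-1, -2)
# yields (B, n_patches, embed_dim), which must match position_embeddings of
# shape (1, n_patches, embed_dim). If an axis is not divisible by patch_size,
# the conv floors while int(prod / patch_size**3) does not, and the addition
# fails. `_demo_patch_grid` is a hypothetical helper.
def _demo_patch_grid(cube_size, patch_size):
    for n in cube_size:
        if n % patch_size != 0:
            raise ValueError(f"axis of size {n} is not divisible by patch_size={patch_size}")
    grid = tuple(n // patch_size for n in cube_size)
    n_patches = grid[0] * grid[1] * grid[2]
    return grid, n_patches

# e.g. (96, 96, 96) with patch_size=16 -> grid (6, 6, 6), 216 tokens;
#      (128, 128, 128) with patch_size=16 -> grid (8, 8, 8), 512 tokens.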
class TransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, dropout, cube_size, patch_size):
super().__init__()
self.attention_norm = nn.LayerNorm(embed_dim, eps=1e-6)
self.mlp_norm = nn.LayerNorm(embed_dim, eps=1e-6)
self.mlp_dim = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size))  # number of patches; unused, the MLP below has a fixed hidden size of 2048
self.mlp = PositionwiseFeedForward(embed_dim, 2048)
self.attn = SelfAttention(num_heads, embed_dim, dropout)
def forward(self, x):
h = x
x = self.attention_norm(x)
# x, weights = self.attn(x)
x = self.attn(x)
x = x + h
h = x
x = self.mlp_norm(x)
x = self.mlp(x)
x = x + h
# return x, weights
return x
class Transformer(nn.Module):
def __init__(self, input_dim, embed_dim, cube_size, patch_size, num_heads, num_layers, dropout, extract_layers):
super().__init__()
self.embeddings = Embeddings(input_dim, embed_dim, cube_size, patch_size, dropout)
self.layer = nn.ModuleList()
self.encoder_norm = nn.LayerNorm(embed_dim, eps=1e-6)
self.extract_layers = extract_layers
for _ in range(num_layers):
layer = TransformerBlock(embed_dim, num_heads, dropout, cube_size, patch_size)
self.layer.append(copy.deepcopy(layer))
def forward(self, x):
extract_layers = []
hidden_states = self.embeddings(x)
for depth, layer_block in enumerate(self.layer):
# hidden_states, _ = layer_block(hidden_states)
hidden_states = layer_block(hidden_states)
if depth + 1 in self.extract_layers:
extract_layers.append(hidden_states)
return extract_layers
class UNETR(nn.Module):
def __init__(self, input_dim=4, output_dim=3, img_shape=(128, 128, 128), embed_dim=768, patch_size=16, num_heads=12,
dropout=0.1):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.embed_dim = embed_dim
self.img_shape = img_shape
self.patch_size = patch_size
self.num_heads = num_heads
self.dropout = dropout
self.num_layers = 12
self.ext_layers = [3, 6, 9, 12]
self.patch_dim = [int(x / patch_size) for x in img_shape]
# Transformer Encoder
self.transformer = \
Transformer(
input_dim,
embed_dim,
img_shape,
patch_size,
num_heads,
self.num_layers,
dropout,
self.ext_layers
)
# U-Net Decoder
self.decoder0 = \
nn.Sequential(
Conv3DBlock(input_dim, 32, 3),
Conv3DBlock(32, 64, 3)
)
self.decoder3 = \
nn.Sequential(
Deconv3DBlock(embed_dim, 512),
Deconv3DBlock(512, 256),
Deconv3DBlock(256, 128)
)
self.decoder6 = \
nn.Sequential(
Deconv3DBlock(embed_dim, 512),
Deconv3DBlock(512, 256),
)
self.decoder9 = \
Deconv3DBlock(embed_dim, 512)
self.decoder12_upsampler = \
SingleDeconv3DBlock(embed_dim, 512)
self.decoder9_upsampler = \
nn.Sequential(
Conv3DBlock(1024, 512),
Conv3DBlock(512, 512),
Conv3DBlock(512, 512),
SingleDeconv3DBlock(512, 256)
)
self.decoder6_upsampler = \
nn.Sequential(
Conv3DBlock(512, 256),
Conv3DBlock(256, 256),
SingleDeconv3DBlock(256, 128)
)
self.decoder3_upsampler = \
nn.Sequential(
Conv3DBlock(256, 128),
Conv3DBlock(128, 128),
SingleDeconv3DBlock(128, 64)
)
self.decoder0_header = \
nn.Sequential(
Conv3DBlock(128, 64),
Conv3DBlock(64, 64),
SingleConv3DBlock(64, output_dim, 1)
)
def forward(self, x):
z = self.transformer(x)
z0, z3, z6, z9, z12 = x, *z
z3 = z3.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
z6 = z6.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
z9 = z9.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
z12 = z12.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
z12 = self.decoder12_upsampler(z12)
z9 = self.decoder9(z9)
z9 = self.decoder9_upsampler(torch.cat([z9, z12], dim=1))
z6 = self.decoder6(z6)
z6 = self.decoder6_upsampler(torch.cat([z6, z9], dim=1))
z3 = self.decoder3(z3)
z3 = self.decoder3_upsampler(torch.cat([z3, z6], dim=1))
z0 = self.decoder0(z0)
output = self.decoder0_header(torch.cat([z0, z3], dim=1))
return output
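# (sic) "unertr": factory name kept as-is because callers may import it; img_shape must be passed as a keyword argument
def unertr(in_channels, num_classes, **kwargs):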
model = UNETR(in_channels, num_classes, img_shape=kwargs['img_shape'])
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = unertr(1,10, img_shape=(96, 96, 96))
# model.eval()
# input = torch.rand(2, 1, 96, 96, 96)
# output = model(input)
# output = output.data.cpu().numpy()
# # print(output)
# print(output.shape)
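# Resolution walk-through of UNETR.forward (illustrative sketch, not part of
# the original repo). Every Deconv3DBlock / SingleDeconv3DBlock doubles each
# spatial axis (ConvTranspose3d, kernel_size=2, stride=2). Starting from the
# token grid img_shape / patch_size, z3 passes 3 + 1 deconvs, z6 2 + 2, z9 1 + 3
# and z12 1 + 3 before meeting z0, i.e. four doublings on every path. The
# decoder therefore lands back on img_shape only when patch_size == 16.
# `_demo_unetr_skip_sizes` is a hypothetical helper that checks this.
def _demo_unetr_skip_sizes(img_shape, patch_size=16):
    grid = [n // patch_size for n in img_shape]
    # spatial size at which each torch.cat in forward() happens
    sizes = {
        "cat(z9, z12)": [g * 2 for g in grid],
        "cat(z6, z9)": [g * 4 for g in grid],
        "cat(z3, z6)": [g * 8 for g in grid],
        "cat(z0, z3)": [g * 16 for g in grid],
    }
    if sizes["cat(z0, z3)"] != list(img_shape):
        raise ValueError("decoder output would not match img_shape; "
                         "this decoder assumes patch_size=16 and axes divisible by 16")
    return sizes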
================================================
FILE: models/networks_3d/vnet.py
================================================
import torch
import torch.nn as nn
import os
import numpy as np
from collections import OrderedDict
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm') != -1:  # match BatchNorm3d too; 'BatchNorm2d' never matches in this 3D model
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
def passthrough(x, **kwargs):
return x
def ELUCons(elu, nchan):
if elu:
return nn.ELU(inplace=True)
else:
return nn.PReLU(nchan)
class LUConv(nn.Module):
def __init__(self, nchan, elu):
super(LUConv, self).__init__()
self.relu1 = ELUCons(elu, nchan)
self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
self.bn1 = torch.nn.BatchNorm3d(nchan)
def forward(self, x):
out = self.relu1(self.bn1(self.conv1(x)))
return out
def _make_nConv(nchan, depth, elu):
layers = []
for _ in range(depth):
layers.append(LUConv(nchan, elu))
return nn.Sequential(*layers)
class InputTransition(nn.Module):
def __init__(self, in_channels, elu):
super(InputTransition, self).__init__()
self.num_features = 16
self.in_channels = in_channels
self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2)
self.bn1 = torch.nn.BatchNorm3d(self.num_features)
self.relu1 = ELUCons(elu, self.num_features)
def forward(self, x):
out = self.conv1(x)
repeat_rate = int(self.num_features / self.in_channels)
out = self.bn1(out)
x16 = x.repeat(1, repeat_rate, 1, 1, 1)
return self.relu1(torch.add(out, x16))
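# Why InputTransition only works for some channel counts (illustrative sketch,
# not part of the original repo). The residual adds the 16-channel conv output
# to x.repeat(1, int(16 / in_channels), 1, 1, 1), so the repeated input has
# exactly 16 channels only when in_channels divides 16 (1, 2, 4, 8 or 16).
# For in_channels=3, int(16 / 3) = 5 gives 15 channels and torch.add fails
# with a shape mismatch. `_demo_vnet_residual_channels` is a hypothetical helper.
def _demo_vnet_residual_channels(in_channels, num_features=16):
    repeat_rate = int(num_features / in_channels)
    return repeat_rate * in_channels  # channels of the repeated skip tensor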
class DownTransition(nn.Module):
def __init__(self, inChans, nConvs, elu, dropout=False):
super(DownTransition, self).__init__()
outChans = 2 * inChans
self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)
self.bn1 = torch.nn.BatchNorm3d(outChans)
self.do1 = passthrough
self.relu1 = ELUCons(elu, outChans)
self.relu2 = ELUCons(elu, outChans)
if dropout:
self.do1 = nn.Dropout3d()
self.ops = _make_nConv(outChans, nConvs, elu)
def forward(self, x):
down = self.relu1(self.bn1(self.down_conv(x)))
out = self.do1(down)
out = self.ops(out)
out = self.relu2(torch.add(out, down))
return out
class UpTransition(nn.Module):
def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
super(UpTransition, self).__init__()
self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)
self.bn1 = torch.nn.BatchNorm3d(outChans // 2)
self.do1 = passthrough
self.do2 = nn.Dropout3d()
self.relu1 = ELUCons(elu, outChans // 2)
self.relu2 = ELUCons(elu, outChans)
if dropout:
self.do1 = nn.Dropout3d()
self.ops = _make_nConv(outChans, nConvs, elu)
def forward(self, x, skipx):
out = self.do1(x)
skipxdo = self.do2(skipx)
out = self.relu1(self.bn1(self.up_conv(out)))
xcat = torch.cat((out, skipxdo), 1)
out = self.ops(xcat)
out = self.relu2(torch.add(out, xcat))
return out
class OutputTransition(nn.Module):
def __init__(self, in_channels, classes, elu):
super(OutputTransition, self).__init__()
self.classes = classes
self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2)
self.bn1 = torch.nn.BatchNorm3d(classes)
self.conv2 = nn.Conv3d(classes, classes, kernel_size=1)
self.relu1 = ELUCons(elu, classes)
def forward(self, x):
# reduce the input channels (32 here) to `classes` output channels
out = self.relu1(self.bn1(self.conv1(x)))
out = self.conv2(out)
return out
class VNet(nn.Module):
"""
Implementation based on the V-Net paper: https://arxiv.org/abs/1606.04797
"""
def __init__(self, in_channels=1, classes=1, elu=True):
super(VNet, self).__init__()
self.classes = classes
self.in_channels = in_channels
self.in_tr = InputTransition(in_channels, elu=elu)
self.down_tr32 = DownTransition(16, 1, elu)
self.down_tr64 = DownTransition(32, 2, elu)
self.down_tr128 = DownTransition(64, 3, elu, dropout=False)
self.down_tr256 = DownTransition(128, 2, elu, dropout=False)
self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)
self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)
self.up_tr64 = UpTransition(128, 64, 1, elu)
self.up_tr32 = UpTransition(64, 32, 1, elu)
self.out_tr = OutputTransition(32, classes, elu)
def forward(self, x):
out16 = self.in_tr(x)
out32 = self.down_tr32(out16)
out64 = self.down_tr64(out32)
out128 = self.down_tr128(out64)
out256 = self.down_tr256(out128)
out = self.up_tr256(out256, out128)
out = self.up_tr128(out, out64)
out = self.up_tr64(out, out32)
out = self.up_tr32(out, out16)
out = self.out_tr(out)
return out
def vnet(in_channels, num_classes):
model = VNet(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
# model = vnet(1,10)
# model.eval()
# input = torch.rand(2, 1, 128, 128, 128)
# output = model(input)
# output = output.data.cpu().numpy()
# # print(output)
# print(output.shape)
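# Input-size constraint for VNet (illustrative sketch, not part of the original
# repo). Each DownTransition halves every axis with a kernel_size=2, stride=2
# Conv3d (floor(n / 2)), and each UpTransition doubles it with ConvTranspose3d
# before concatenating with the matching encoder feature. If an axis is not
# divisible by 2**4 = 16, flooring on the way down leaves an upsampled map one
# voxel short of its skip connection and torch.cat fails.
# `_demo_vnet_encoder_sizes` is a hypothetical helper that traces one axis.
def _demo_vnet_encoder_sizes(n, levels=4):
    sizes = [n]
    for _ in range(levels):
        sizes.append(sizes[-1] // 2)  # Conv3d(kernel_size=2, stride=2) floors
    # the decoder rebuilds 2 * size at each level and must match the skip size
    ok = all(2 * sizes[i + 1] == sizes[i] for i in range(levels))
    return sizes, ok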
================================================
FILE: models/networks_3d/vnet_cct.py
================================================
import torch
import torch.nn as nn
import os
import numpy as np
from collections import OrderedDict
from torch.nn import init
from torch.distributions.uniform import Uniform
# from loss.loss_function import segmentation_loss
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm') != -1:  # match BatchNorm3d too; 'BatchNorm2d' never matches in this 3D model
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
class FeatureNoise(nn.Module):
def __init__(self, uniform_range=0.3):
super(FeatureNoise, self).__init__()
self.uni_dist = Uniform(-uniform_range, uniform_range)
def feature_based_noise(self, x):
noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0)
x_noise = x.mul(noise_vector) + x
return x_noise
def forward(self, x):
x = self.feature_based_noise(x)
return x
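# A NumPy mirror of FeatureNoise.feature_based_noise (illustrative sketch, not
# part of the original repo). One noise tensor of shape x.shape[1:] is drawn
# from U(-r, r) and shared across the batch, and the output is x * noise + x,
# so each element is perturbed by at most r times its own magnitude and zeros
# stay zero. `_demo_feature_noise` is a hypothetical name.
import numpy as np


def _demo_feature_noise(x, uniform_range=0.3, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    # same noise for every sample in the batch, as with unsqueeze(0) above
    noise = rng.uniform(-uniform_range, uniform_range, size=x.shape[1:])[None]
    return x * noise + x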
def Dropout(x, p=0.3):
x = torch.nn.functional.dropout(x, p)
return x
def FeatureDropout(x):
attention = torch.mean(x, dim=1, keepdim=True)
max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True)
threshold = max_val * np.random.uniform(0.7, 0.9)
threshold = threshold.view(x.size(0), 1, 1, 1, 1).expand_as(attention)
drop_mask = (attention < threshold).float()
x = x.mul(drop_mask)
return x
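# A NumPy mirror of FeatureDropout (illustrative sketch, not part of the
# original repo). The channel-mean map acts as an attention score; voxels whose
# score reaches gamma * max (gamma ~ U(0.7, 0.9) per call above) are zeroed in
# every channel, so the decoder must predict without the most salient regions.
# Here gamma is a fixed argument for reproducibility; `_demo_feature_dropout`
# is a hypothetical name.
import numpy as np


def _demo_feature_dropout(x, gamma=0.8):
    attention = x.mean(axis=1, keepdims=True)                # (B, 1, D, H, W)
    max_val = attention.reshape(x.shape[0], -1).max(axis=1)  # (B,)
    threshold = (max_val * gamma).reshape(-1, 1, 1, 1, 1)
    drop_mask = (attention < threshold).astype(x.dtype)      # 0 where salient
    return x * drop_mask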
def passthrough(x, **kwargs):
return x
def ELUCons(elu, nchan):
if elu:
return nn.ELU(inplace=True)
else:
return nn.PReLU(nchan)
class LUConv(nn.Module):
def __init__(self, nchan, elu):
super(LUConv, self).__init__()
self.relu1 = ELUCons(elu, nchan)
self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
self.bn1 = torch.nn.BatchNorm3d(nchan)
def forward(self, x):
out = self.relu1(self.bn1(self.conv1(x)))
return out
def _make_nConv(nchan, depth, elu):
layers = []
for _ in range(depth):
layers.append(LUConv(nchan, elu))
return nn.Sequential(*layers)
class InputTransition(nn.Module):
def __init__(self, in_channels, elu):
super(InputTransition, self).__init__()
self.num_features = 16
self.in_channels = in_channels
self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2)
self.bn1 = torch.nn.BatchNorm3d(self.num_features)
self.relu1 = ELUCons(elu, self.num_features)
def forward(self, x):
out = self.conv1(x)
repeat_rate = int(self.num_features / self.in_channels)
out = self.bn1(out)
x16 = x.repeat(1, repeat_rate, 1, 1, 1)
return self.relu1(torch.add(out, x16))
class DownTransition(nn.Module):
def __init__(self, inChans, nConvs, elu, dropout=False):
super(DownTransition, self).__init__()
outChans = 2 * inChans
self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)
self.bn1 = torch.nn.BatchNorm3d(outChans)
self.do1 = passthrough
self.relu1 = ELUCons(elu, outChans)
self.relu2 = ELUCons(elu, outChans)
if dropout:
self.do1 = nn.Dropout3d()
self.ops = _make_nConv(outChans, nConvs, elu)
def forward(self, x):
down = self.relu1(self.bn1(self.down_conv(x)))
out = self.do1(down)
out = self.ops(out)
out = self.relu2(torch.add(out, down))
return out
class UpTransition(nn.Module):
def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
super(UpTransition, self).__init__()
self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)
self.bn1 = torch.nn.BatchNorm3d(outChans // 2)
self.do1 = passthrough
self.do2 = nn.Dropout3d()
self.relu1 = ELUCons(elu, outChans // 2)
self.relu2 = ELUCons(elu, outChans)
if dropout:
self.do1 = nn.Dropout3d()
self.ops = _make_nConv(outChans, nConvs, elu)
def forward(self, x, skipx):
out = self.do1(x)
skipxdo = self.do2(skipx)
out = self.relu1(self.bn1(self.up_conv(out)))
xcat = torch.cat((out, skipxdo), 1)
out = self.ops(xcat)
out = self.relu2(torch.add(out, xcat))
return out
class OutputTransition(nn.Module):
def __init__(self, in_channels, classes, elu):
super(OutputTransition, self).__init__()
self.classes = classes
self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2)
self.bn1 = torch.nn.BatchNorm3d(classes)
self.conv2 = nn.Conv3d(classes, classes, kernel_size=1)
self.relu1 = ELUCons(elu, classes)
def forward(self, x):
# reduce the input channels (32 here) to `classes` output channels
out = self.relu1(self.bn1(self.conv1(x)))
out = self.conv2(out)
return out
class Decoder(nn.Module):
def __init__(self, out_channels, elu):
super(Decoder, self).__init__()
self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)
self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)
self.up_tr64 = UpTransition(128, 64, 1, elu)
self.up_tr32 = UpTransition(64, 32, 1, elu)
self.out_tr = OutputTransition(32, out_channels, elu)
def forward(self, out256, out128, out64, out32, out16):
out = self.up_tr256(out256, out128)
out = self.up_tr128(out, out64)
out = self.up_tr64(out, out32)
out = self.up_tr32(out, out16)
out = self.out_tr(out)
return out
class VNet_CCT(nn.Module):
"""
Implementation based on the V-Net paper: https://arxiv.org/abs/1606.04797
"""
def __init__(self, in_channels=1, classes=1, elu=True):
super(VNet_CCT, self).__init__()
self.classes = classes
self.in_channels = in_channels
self.in_tr = InputTransition(in_channels, elu=elu)
self.down_tr32 = DownTransition(16, 1, elu)
self.down_tr64 = DownTransition(32, 2, elu)
self.down_tr128 = DownTransition(64, 3, elu, dropout=False)
self.down_tr256 = DownTransition(128, 2, elu, dropout=False)
self.main_decoder = Decoder(classes, elu)
self.aux_decoder1 = Decoder(classes, elu)
self.aux_decoder2 = Decoder(classes, elu)
self.aux_decoder3 = Decoder(classes, elu)
def forward(self, x):
out16 = self.in_tr(x)
out32 = self.down_tr32(out16)
out64 = self.down_tr64(out32)
out128 = self.down_tr128(out64)
out256 = self.down_tr256(out128)
main_seg = self.main_decoder(out256, out128, out64, out32, out16)
aux_seg1 = self.aux_decoder1(FeatureNoise()(out256), FeatureNoise()(out128), FeatureNoise()(out64), FeatureNoise()(out32), FeatureNoise()(out16))
aux_seg2 = self.aux_decoder2(Dropout(out256), Dropout(out128), Dropout(out64), Dropout(out32), Dropout(out16))
aux_seg3 = self.aux_decoder3(FeatureDropout(out256), FeatureDropout(out128), FeatureDropout(out64), FeatureDropout(out32), FeatureDropout(out16))
return main_seg, aux_seg1, aux_seg2, aux_seg3
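# One common way to use the four outputs of VNet_CCT on unlabeled volumes
# (illustrative sketch under an assumption: the training script that consumes
# these outputs is not part of this file). In cross-consistency training, each
# auxiliary prediction is pulled towards the main decoder's softmax, treated as
# a fixed target, with a mean-squared error. `_demo_softmax` and
# `_demo_cct_consistency` are hypothetical helpers on NumPy logits of shape
# (B, C, D, H, W).
import numpy as np


def _demo_softmax(logits, axis=1):
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def _demo_cct_consistency(main_logits, aux_logits_list):
    target = _demo_softmax(main_logits)  # in PyTorch this would be .detach()-ed
    losses = [np.mean((_demo_softmax(a) - target) ** 2) for a in aux_logits_list]
    return sum(losses) / len(losses)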
def vnet_cct(in_channels, num_classes):
model = VNet_CCT(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
#
# criterion = segmentation_loss('dice', False)
# mask = torch.ones(2, 64, 96, 64).long()
# model = vnet_cct(1, 10)
# model.train()
# input = torch.rand(2, 1, 64, 96, 64)
# output, output1, output2, output3 = model(input)
# loss_train = criterion(output, mask)
# loss_train.backward()
# output = output.data.cpu().numpy()
# print(output.shape)
# print(loss_train)
================================================
FILE: models/networks_3d/vnet_dtc.py
================================================
import torch
import torch.nn as nn
import os
import numpy as np
from collections import OrderedDict
from torch.nn import init
# from loss.loss_function import segmentation_loss
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm') != -1:  # match BatchNorm3d too; 'BatchNorm2d' never matches in this 3D model
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
def passthrough(x, **kwargs):
return x
def ELUCons(elu, nchan):
if elu:
return nn.ELU(inplace=True)
else:
return nn.PReLU(nchan)
class LUConv(nn.Module):
def __init__(self, nchan, elu):
super(LUConv, self).__init__()
self.relu1 = ELUCons(elu, nchan)
self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
self.bn1 = torch.nn.BatchNorm3d(nchan)
def forward(self, x):
out = self.relu1(self.bn1(self.conv1(x)))
return out
def _make_nConv(nchan, depth, elu):
layers = []
for _ in range(depth):
layers.append(LUConv(nchan, elu))
return nn.Sequential(*layers)
class InputTransition(nn.Module):
def __init__(self, in_channels, elu):
super(InputTransition, self).__init__()
self.num_features = 16
self.in_channels = in_channels
self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2)
self.bn1 = torch.nn.BatchNorm3d(self.num_features)
self.relu1 = ELUCons(elu, self.num_features)
def forward(self, x):
out = self.conv1(x)
repeat_rate = int(self.num_features / self.in_channels)
out = self.bn1(out)
x16 = x.repeat(1, repeat_rate, 1, 1, 1)
return self.relu1(torch.add(out, x16))
class DownTransition(nn.Module):
def __init__(self, inChans, nConvs, elu, dropout=False):
super(DownTransition, self).__init__()
outChans = 2 * inChans
self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)
self.bn1 = torch.nn.BatchNorm3d(outChans)
self.do1 = passthrough
self.relu1 = ELUCons(elu, outChans)
self.relu2 = ELUCons(elu, outChans)
if dropout:
self.do1 = nn.Dropout3d()
self.ops = _make_nConv(outChans, nConvs, elu)
def forward(self, x):
down = self.relu1(self.bn1(self.down_conv(x)))
out = self.do1(down)
out = self.ops(out)
out = self.relu2(torch.add(out, down))
return out
class UpTransition(nn.Module):
def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
super(UpTransition, self).__init__()
self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)
self.bn1 = torch.nn.BatchNorm3d(outChans // 2)
self.do1 = passthrough
self.do2 = nn.Dropout3d()
self.relu1 = ELUCons(elu, outChans // 2)
self.relu2 = ELUCons(elu, outChans)
if dropout:
self.do1 = nn.Dropout3d()
self.ops = _make_nConv(outChans, nConvs, elu)
def forward(self, x, skipx):
out = self.do1(x)
skipxdo = self.do2(skipx)
out = self.relu1(self.bn1(self.up_conv(out)))
xcat = torch.cat((out, skipxdo), 1)
out = self.ops(xcat)
out = self.relu2(torch.add(out, xcat))
return out
class OutputTransition(nn.Module):
def __init__(self, in_channels, classes, elu):
super(OutputTransition, self).__init__()
self.classes = classes
self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2)
self.bn1 = torch.nn.BatchNorm3d(classes)
self.conv2 = nn.Conv3d(classes, classes, kernel_size=1)
self.relu1 = ELUCons(elu, classes)
def forward(self, x):
# reduce the input channels (32 here) to `classes` output channels
out = self.relu1(self.bn1(self.conv1(x)))
out = self.conv2(out)
return out
class VNet_DTC(nn.Module):
"""
Implementation based on the V-Net paper: https://arxiv.org/abs/1606.04797
"""
def __init__(self, in_channels=1, classes=1, elu=True):
super(VNet_DTC, self).__init__()
self.classes = classes
self.in_channels = in_channels
self.in_tr = InputTransition(in_channels, elu=elu)
self.down_tr32 = DownTransition(16, 1, elu)
self.down_tr64 = DownTransition(32, 2, elu)
self.down_tr128 = DownTransition(64, 3, elu, dropout=False)
self.down_tr256 = DownTransition(128, 2, elu, dropout=False)
self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)
self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)
self.up_tr64 = UpTransition(128, 64, 1, elu)
self.up_tr32 = UpTransition(64, 32, 1, elu)
self.out_tr = OutputTransition(32, 16, elu)
self.out_sdf = nn.Sequential(
nn.Conv3d(16, classes, 1, padding=0),
nn.Tanh()
)
self.out_seg = nn.Conv3d(16, classes, 1, padding=0)
def forward(self, x):
out16 = self.in_tr(x)
out32 = self.down_tr32(out16)
out64 = self.down_tr64(out32)
out128 = self.down_tr128(out64)
out256 = self.down_tr256(out128)
out = self.up_tr256(out256, out128)
out = self.up_tr128(out, out64)
out = self.up_tr64(out, out32)
out = self.up_tr32(out, out16)
out = self.out_tr(out)
out_sdf = self.out_sdf(out)
out_seg = self.out_seg(out)
return out_sdf, out_seg
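# How the two DTC heads can be tied together (illustrative sketch under an
# assumption: the transform below follows our reading of the DTC reference
# training code and is not defined in this file). out_sdf is squashed to
# (-1, 1) by Tanh; a steep sigmoid T(z) = 1 / (1 + exp(k * z)) with a large k
# (1500 in the DTC reference code) turns it into a soft mask that can be
# compared with the probabilities from out_seg in a consistency loss. The sign
# convention assumed here is negative inside the object and positive outside;
# check tools/mask2sdf.py before relying on it. `_demo_sdf_to_mask` is a
# hypothetical helper.
import numpy as np


def _demo_sdf_to_mask(sdf, k=1500.0):
    # sigmoid(-k * sdf); clipping the exponent avoids overflow in np.exp
    z = np.clip(k * sdf, -50.0, 50.0)
    return 1.0 / (1.0 + np.exp(z))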
def vnet_dtc(in_channels, num_classes):
model = VNet_DTC(in_channels, num_classes)
init_weights(model, 'kaiming')
return model
# if __name__ == '__main__':
#
# criterion = segmentation_loss('dice', False)
# mask = torch.ones(2, 96, 96, 96).long()
# model = vnet_dtc(1, 10)
# model.train()
# input1 = torch.rand(2,1,96,96,96)
# out_sdf, out_seg = model(input1)
# loss_train = criterion(out_seg, mask)
# loss_train.backward()
# # print(output)
# print(out_sdf.data.cpu().numpy().shape)
# print(out_seg.data.cpu().numpy().shape)
# print(loss_train)
================================================
FILE: models/networks_3d/xnet3d.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import functools
from torch.distributions.uniform import Uniform
import numpy as np
# from loss.loss_function import segmentation_loss
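# Normalization: the alias keeps the name BatchNorm3d, but it is bound to nn.InstanceNorm3d
# (affine=False and track_running_stats=False by default, so the momentum argument passed below is unused)
BatchNorm3d = nn.InstanceNorm3d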
BN_MOMENTUM = 0.1
# BN_MOMENTUM = 0.01
# AF
relu_inplace = True
ActivationFunction = nn.ReLU
def conv1x1(in_planes, out_planes, stride=1):
"""1x1x1 convolution"""
return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3x3 convolution with padding"""
return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
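# Output-size arithmetic behind conv3x3, down_conv and same_conv (illustrative
# sketch, not part of the original repo). For each spatial axis, Conv3d gives
# out = floor((n + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1.
# With kernel_size=3 and padding=dilation, stride 1 preserves n for any
# dilation, while down_conv's stride 2, padding 1 gives ceil(n / 2).
# `_demo_conv_out` is a hypothetical helper.
def _demo_conv_out(n, kernel_size=3, stride=1, padding=1, dilation=1):
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1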
class up_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(up_conv, self).__init__()
self.up = nn.Sequential(
nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True),
nn.Conv3d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
BatchNorm3d(ch_out, momentum=BN_MOMENTUM),
ActivationFunction(inplace=relu_inplace)
)
def forward(self, x):
x = self.up(x)
return x
class down_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(down_conv, self).__init__()
self.down = nn.Sequential(
nn.Conv3d(ch_in, ch_out, kernel_size=3, stride=2, padding=1, bias=False),
BatchNorm3d(ch_out, momentum=BN_MOMENTUM),
ActivationFunction(inplace=relu_inplace)
)
def forward(self, x):
x = self.down(x)
return x
class same_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(same_conv, self).__init__()
self.same = nn.Sequential(
nn.Conv3d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False),
BatchNorm3d(ch_out, momentum=BN_MOMENTUM),
ActivationFunction(inplace=relu_inplace)
)
def forward(self, x):
x = self.same(x)
return x
class transition_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(transition_conv, self).__init__()
self.transition = nn.Sequential(
nn.Conv3d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=False),
BatchNorm3d(ch_out, momentum=BN_MOMENTUM),
ActivationFunction(inplace=relu_inplace)
)
def forward(self, x):
x = self.transition(x)
return x
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = BatchNorm3d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes, momentum=BN_MOMENTUM)
self.relu = ActivationFunction(inplace=relu_inplace)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes, momentum=BN_MOMENTUM)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out = out + identity
out = self.relu(out)
return out
class DoubleBasicBlock(nn.Module):
def __init__(self, inplanes, planes, downsample=None):
super(DoubleBasicBlock, self).__init__()
self.DBB = nn.Sequential(
BasicBlock(inplanes=inplanes, planes=planes, downsample=downsample),
BasicBlock(inplanes=planes, planes=planes)
)
def forward(self, x):
out = self.DBB(x)
return out
class XNet3D(nn.Module):
def __init__(self, in_channels, num_classes):
super(XNet3D, self).__init__()
# l1c, l2c, l3c, l4c = 64, 128, 256, 512
# l1c, l2c, l3c, l4c, l5c = 8, 16, 32, 64, 128
# l1c, l2c, l3c, l4c, l5c = 16, 32, 64, 128, 256
l1c, l2c, l3c, l4c, l5c = 32, 64, 128, 256, 512
# branch1
# branch1_layer1
self.b1_1_1 = nn.Sequential(
conv3x3(in_channels, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b1_1_2_down = down_conv(l1c, l2c)
self.b1_1_3 = BasicBlock(l1c+l1c, l1c, downsample=nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm3d(l1c, momentum=BN_MOMENTUM)))
self.b1_1_4 = nn.Conv3d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch1_layer2
self.b1_2_1 = BasicBlock(l2c, l2c)
self.b1_2_2_down = down_conv(l2c, l3c)
self.b1_2_3 = BasicBlock(l2c+l2c, l2c, downsample=nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm3d(l2c, momentum=BN_MOMENTUM)))
self.b1_2_4_up = up_conv(l2c, l1c)
# branch1_layer3
self.b1_3_1 = BasicBlock(l3c, l3c)
self.b1_3_2_down = down_conv(l3c, l4c)
self.b1_3_3 = BasicBlock(l3c+l3c, l3c, downsample=nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm3d(l3c, momentum=BN_MOMENTUM)))
self.b1_3_4_up = up_conv(l3c, l2c)
# branch1_layer4
self.b1_4_1 = BasicBlock(l4c, l4c)
self.b1_4_2_down = down_conv(l4c, l5c)
self.b1_4_2 = BasicBlock(l4c, l4c)
self.b1_4_3_down = down_conv(l4c, l4c)
self.b1_4_3_same = same_conv(l4c, l4c)
self.b1_4_4_transition = transition_conv(l4c+l5c+l4c, l4c)
self.b1_4_5 = BasicBlock(l4c, l4c)
self.b1_4_6 = BasicBlock(l4c+l4c, l4c, downsample=nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm3d(l4c, momentum=BN_MOMENTUM)))
self.b1_4_7_up = up_conv(l4c, l3c)
# branch1_layer5
self.b1_5_1 = BasicBlock(l5c, l5c)
self.b1_5_2_up = up_conv(l5c, l5c)
self.b1_5_2_same = same_conv(l5c, l5c)
self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)
self.b1_5_4 = BasicBlock(l5c, l5c)
self.b1_5_5_up = up_conv(l5c, l4c)
# branch2
# branch2_layer1
self.b2_1_1 = nn.Sequential(
conv3x3(1, l1c),
conv3x3(l1c, l1c),
BasicBlock(l1c, l1c)
)
self.b2_1_2_down = down_conv(l1c, l2c)
self.b2_1_3 = BasicBlock(l1c+l1c, l1c, downsample=nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm3d(l1c, momentum=BN_MOMENTUM)))
self.b2_1_4 = nn.Conv3d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
# branch2_layer2
self.b2_2_1 = BasicBlock(l2c, l2c)
self.b2_2_2_down = down_conv(l2c, l3c)
self.b2_2_3 = BasicBlock(l2c+l2c, l2c, downsample=nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm3d(l2c, momentum=BN_MOMENTUM)))
self.b2_2_4_up = up_conv(l2c, l1c)
# branch2_layer3
self.b2_3_1 = BasicBlock(l3c, l3c)
self.b2_3_2_down = down_conv(l3c, l4c)
self.b2_3_3 = BasicBlock(l3c+l3c, l3c, downsample=nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm3d(l3c, momentum=BN_MOMENTUM)))
self.b2_3_4_up = up_conv(l3c, l2c)
# branch2_layer4
self.b2_4_1 = BasicBlock(l4c, l4c)
self.b2_4_2_down = down_conv(l4c, l5c)
self.b2_4_2 = BasicBlock(l4c, l4c)
self.b2_4_3_down = down_conv(l4c, l4c)
self.b2_4_3_same = same_conv(l4c, l4c)
self.b2_4_4_transition = transition_conv(l4c+l5c+l4c, l4c)
self.b2_4_5 = BasicBlock(l4c, l4c)
self.b2_4_6 = BasicBlock(l4c+l4c, l4c, downsample=nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm3d(l4c, momentum=BN_MOMENTUM)))
self.b2_4_7_up = up_conv(l4c, l3c)
# branch2_layer5
self.b2_5_1 = BasicBlock(l5c, l5c)
self.b2_5_2_up = up_conv(l5c, l5c)
self.b2_5_2_same = same_conv(l5c, l5c)
self.b2_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)
self.b2_5_4 = BasicBlock(l5c, l5c)
self.b2_5_5_up = up_conv(l5c, l4c)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
# elif isinstance(m, nn.BatchNorm3d):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABNSync):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, InPlaceABN):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, input1, input2):
# encode
# branch1
x1_1 = self.b1_1_1(input1)
x1_2 = self.b1_1_2_down(x1_1)
x1_2 = self.b1_2_1(x1_2)
x1_3 = self.b1_2_2_down(x1_2)
x1_3 = self.b1_3_1(x1_3)
x1_4_1 = self.b1_3_2_down(x1_3)
x1_4_1 = self.b1_4_1(x1_4_1)
x1_4_2 = self.b1_4_2(x1_4_1)
x1_4_3_down = self.b1_4_3_down(x1_4_2)
x1_4_3_same = self.b1_4_3_same(x1_4_2)
x1_5_1 = self.b1_4_2_down(x1_4_1)
x1_5_1 = self.b1_5_1(x1_5_1)
x1_5_2_up = self.b1_5_2_up(x1_5_1)
x1_5_2_same = self.b1_5_2_same(x1_5_1)
# branch2
x2_1 = self.b2_1_1(input2)
x2_2 = self.b2_1_2_down(x2_1)
x2_2 = self.b2_2_1(x2_2)
x2_3 = self.b2_2_2_down(x2_2)
x2_3 = self.b2_3_1(x2_3)
x2_4_1 = self.b2_3_2_down(x2_3)
x2_4_1 = self.b2_4_1(x2_4_1)
x2_4_2 = self.b2_4_2(x2_4_1)
x2_4_3_down = self.b2_4_3_down(x2_4_2)
x2_4_3_same = self.b2_4_3_same(x2_4_2)
x2_5_1 = self.b2_4_2_down(x2_4_1)
x2_5_1 = self.b2_5_1(x2_5_1)
x2_5_2_up = self.b2_5_2_up(x2_5_1)
x2_5_2_same = self.b2_5_2_same(x2_5_1)
# merge
# branch1
x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same, x2_4_3_down), dim=1)
x1_5_3 = self.b1_5_3_transition(x1_5_3)
x1_5_3 = self.b1_5_4(x1_5_3)
x1_5_3 = self.b1_5_5_up(x1_5_3)
x1_4_4 = torch.cat((x1_4_3_same, x2_4_3_same, x2_5_2_up), dim=1)
x1_4_4 = self.b1_4_4_transition(x1_4_4)
x1_4_4 = self.b1_4_5(x1_4_4)
x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)
x1_4_4 = self.b1_4_6(x1_4_4)
x1_4_4 = self.b1_4_7_up(x1_4_4)
# branch2
x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same, x1_4_3_down), dim=1)
x2_5_3 = self.b2_5_3_transition(x2_5_3)
x2_5_3 = self.b2_5_4(x2_5_3)
x2_5_3 = self.b2_5_5_up(x2_5_3)
x2_4_4 = torch.cat((x2_4_3_same, x1_4_3_same, x1_5_2_up), dim=1)
x2_4_4 = self.b2_4_4_transition(x2_4_4)
x2_4_4 = self.b2_4_5(x2_4_4)
x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)
x2_4_4 = self.b2_4_6(x2_4_4)
x2_4_4 = self.b2_4_7_up(x2_4_4)
# decode
# branch1
x1_3 = torch.cat((x1_3, x1_4_4), dim=1)
x1_3 = self.b1_3_3(x1_3)
x1_3 = self.b1_3_4_up(x1_3)
x1_2 = torch.cat((x1_2, x1_3), dim=1)
x1_2 = self.b1_2_3(x1_2)
x1_2 = self.b1_2_4_up(x1_2)
x1_1 = torch.cat((x1_1, x1_2), dim=1)
x1_1 = self.b1_1_3(x1_1)
x1_1 = self.b1_1_4(x1_1)
# branch2
x2_3 = torch.cat((x2_3, x2_4_4), dim=1)
x2_3 = self.b2_3_3(x2_3)
x2_3 = self.b2_3_4_up(x2_3)
x2_2 = torch.cat((x2_2, x2_3), dim=1)
x2_2 = self.b2_2_3(x2_2)
x2_2 = self.b2_2_4_up(x2_2)
x2_1 = torch.cat((x2_1, x2_2), dim=1)
x2_1 = self.b2_1_3(x2_1)
x2_1 = self.b2_1_4(x2_1)
return x1_1, x2_1
def xnet3d(in_channels, num_classes):
model = XNet3D(in_channels, num_classes)
return model
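XNet3D downsamples four times. In the fusion stage and in the decoder, it concatenates up-sampled features with same-level features along `dim=1`. Every spatial size must therefore survive four halvings followed by doublings.

The hypothetical helper below checks this ahead of time. It assumes `up_conv` doubles each spatial dimension; that definition lies outside this excerpt. With the patch size (112, 112, 32) used by the 3D test scripts, level 5 comes out as (7, 7, 2).

```python
def xnet3d_level_shapes(spatial, num_down=4):
    """Per-level spatial shapes for XNet3D's five resolution levels.

    Assumes down_conv maps n -> ceil(n / 2) and up_conv maps n -> 2 * n.
    Raises if an up-sampled feature map would not match the same-level
    feature map it is concatenated with.
    """
    shapes = [tuple(spatial)]
    for _ in range(num_down):
        shapes.append(tuple(-(-n // 2) for n in shapes[-1]))  # ceil(n / 2)
    for level in range(num_down, 0, -1):
        # shapes[level] is level (level + 1); shapes[level - 1] is the level above it.
        upsampled = tuple(2 * n for n in shapes[level])
        if upsampled != shapes[level - 1]:
            raise ValueError(
                f"level {level + 1} upsamples to {upsampled}, "
                f"but level {level} is {shapes[level - 1]}; "
                f"use sizes divisible by {2 ** num_down}")
    return shapes

print(xnet3d_level_shapes((112, 112, 32)))
```

A size such as (100, 100, 100) fails the check: 100 -> 50 -> 25 -> 13 -> 7, and 7 doubled is 14, not 13. This is why sizes divisible by 16 are the safe choice.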
# if __name__ == '__main__':
#
# criterion = segmentation_loss('dice', False)
# mask = torch.ones(2, 96, 96, 96).long()
# model = XNet3D(1, 10)
# model.train()
# input1 = torch.rand(2,1,96,96,96)
# input2 = torch.rand(2,1,96,96,96)
# x1_1_main, x2_1_main = model(input1, input2)
# loss_train = criterion(x1_1_main, mask)
# loss_train.backward()
# # print(output)
# print(x1_1_main.data.cpu().numpy().shape)
# print(x2_1_main.data.cpu().numpy().shape)
# print(loss_train)
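The initialization loop in `XNet3D.__init__` calls `nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')`. This draws weights from a zero-mean normal distribution with std = sqrt(2) / sqrt(fan_out). For a Conv3d weight, fan_out = out_channels * kD * kH * kW.

The hypothetical helper below only computes that std, so the resulting scales can be inspected. For example, the 1x1x1 output convolution with `num_classes=2` gets std 1.0.

```python
import math

def kaiming_fan_out_std(out_channels, kernel_size):
    """Std used by nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')."""
    fan_out = out_channels * math.prod(kernel_size)  # out_channels * receptive field
    gain = math.sqrt(2.0)  # recommended gain for ReLU
    return gain / math.sqrt(fan_out)

print(kaiming_fan_out_std(64, (3, 3, 3)))  # a 3x3x3 conv with 64 output channels
print(kaiming_fan_out_std(2, (1, 1, 1)))   # the 1x1x1 head with num_classes=2
```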
================================================
FILE: requirements.txt
================================================
albumentations==0.5.2
einops==0.4.1
MedPy==0.4.0
numpy==1.20.2
opencv_python==4.2.0.34
opencv_python_headless==4.5.1.48
Pillow==8.0.0
PyWavelets==1.1.1
scikit_image==0.18.1
scikit_learn==1.0.1
scipy==1.4.1
SimpleITK==2.1.0
timm==0.6.7
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==1.8.0+cu111
torchio==0.18.53
torchvision==0.9.0+cu111
tqdm==4.65.0
visdom==0.1.8.9
================================================
FILE: test.py
================================================
from torchvision import transforms, datasets
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_itn
from config.train_test_config.train_test_config import print_test_eval, save_test_2d
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
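Flags such as `--if_mask` and `--deep_supervision` are declared without a `type`. A command-line value therefore arrives as a string, and `--if_mask False` still evaluates as true.

A common workaround converts the string explicitly and passes the converter as `type=`. It is sketched here with a hypothetical `str2bool` helper that is not part of this repository.

```python
import argparse

def str2bool(value):
    """Parse common true/false spellings; hypothetical helper for argparse `type=`."""
    if isinstance(value, bool):  # non-string defaults are not passed through `type`
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--if_mask", type=str2bool, default=True)
print(parser.parse_args(["--if_mask", "False"]).if_mask)  # False
```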
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')
parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/sup/GlaS/best_kiunet_Jc_0.7779.pth')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test')
parser.add_argument('--dataset_name', default='GlaS', help='CREMI')
parser.add_argument('--input1', default='image')
parser.add_argument('--if_mask', default=True)
parser.add_argument('--threshold', default=0.5400, type=float, help='e.g. 0.5600, 0.5400')
parser.add_argument('-ds', '--deep_supervision', default=False)
parser.add_argument('-b', '--batch_size', default=4, type=int)
parser.add_argument('-n', '--network', default='kiunet')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = dist.get_world_size()  # number of processes taking part in all_gather
init_seeds(rank + 1)
# Config
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7
print_num_minus = print_num - 2
# Results Save
if not os.path.exists(args.path_seg_results) and rank == args.rank_index:
os.mkdir(args.path_seg_results)
path_seg_results = args.path_seg_results + '/' + str(dataset_name)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
# print(path_seg_results)
if args.input1 == 'image':
input1_mean = 'MEAN'
input1_std = 'STD'
else:
input1_mean = 'MEAN_' + args.input1
input1_std = 'STD_' + args.input1
# Dataset
data_transforms = data_transform_2d()
data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
dataset_val = imagefloder_itn(
data_dir=args.path_dataset + '/val',
input1=args.input1,
data_transform_1=data_transforms['val'],
data_normalize_1=data_normalize,
sup=True,
num_images=None
)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler)
num_batches = {'val': len(dataloaders['val'])}
# Model
model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model = model.cuda()
# if rank == args.rank_index:
# state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank))
# model.load_state_dict(state_dict=state_dict)
# model = DistributedDataParallel(model, device_ids=[args.local_rank])
model = DistributedDataParallel(model, device_ids=[args.local_rank])
state_dict = torch.load(args.path_model, map_location='cuda:{}'.format(args.local_rank))
model.load_state_dict(state_dict=state_dict)
dist.barrier()
# Test
since = time.time()
with torch.no_grad():
model.eval()
for i, data in enumerate(dataloaders['val']):
inputs_test = data['image']
inputs_test = Variable(inputs_test.cuda(non_blocking=True))
name_test = data['ID']
if args.if_mask:
mask_test = data['mask']
mask_test = Variable(mask_test.cuda(non_blocking=True))
outputs_test = model(inputs_test)
if args.deep_supervision:
outputs_test = outputs_test[0]
if args.if_mask:
if i == 0:
score_list_test = outputs_test
name_list_test = name_test
mask_list_test = mask_test
else:
# elif 0 < i <= num_batches['val'] / 16:
score_list_test = torch.cat((score_list_test, outputs_test), dim=0)
name_list_test = np.append(name_list_test, name_test, axis=0)
mask_list_test = torch.cat((mask_list_test, mask_test), dim=0)
torch.cuda.empty_cache()
else:
save_test_2d(cfg['NUM_CLASSES'], outputs_test, name_test, args.threshold, path_seg_results, cfg['PALETTE'])
torch.cuda.empty_cache()
if args.if_mask:
score_gather_list_test = [torch.zeros_like(score_list_test) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_test, score_list_test)
score_list_test = torch.cat(score_gather_list_test, dim=0)
mask_gather_list_test = [torch.zeros_like(mask_list_test) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_test, mask_list_test)
mask_list_test = torch.cat(mask_gather_list_test, dim=0)
name_gather_list_test = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_test, name_list_test)
name_list_test = np.concatenate(name_gather_list_test, axis=0)
if args.if_mask and rank == args.rank_index:
print('=' * print_num)
test_eval_list = print_test_eval(cfg['NUM_CLASSES'], score_list_test, mask_list_test, print_num_minus)
save_test_2d(cfg['NUM_CLASSES'], score_list_test, name_list_test, test_eval_list[0], path_seg_results, cfg['PALETTE'])
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('-' * print_num)
print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('=' * print_num)
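With `shuffle=False`, `torch.utils.data.distributed.DistributedSampler` first pads the index list by wrapping around until its length divides evenly by the number of processes. It then gives rank `r` every `world_size`-th index, starting at `r`.

After `all_gather` and `torch.cat`, the evaluated scores are therefore interleaved by rank and may include a few duplicated images. The names are gathered the same way, so they stay aligned with the scores.

The plain-Python sketch below reproduces that index assignment for illustration. It mirrors the sampler's default behaviour (`shuffle=False`, `drop_last=False`) but is not the library code.

```python
import math

def sampler_indices(num_samples, world_size, rank):
    """Indices one rank receives from DistributedSampler(shuffle=False, drop_last=False)."""
    total = math.ceil(num_samples / world_size) * world_size
    indices = list(range(num_samples))
    # Pad by wrapping around so every rank gets the same number of samples
    # (this simple form assumes world_size <= num_samples).
    indices += indices[: total - num_samples]
    return indices[rank:total:world_size]

world_size = 4
gathered = [i for r in range(world_size) for i in sampler_indices(10, world_size, r)]
print(gathered)  # [0, 4, 8, 1, 5, 9, 2, 6, 0, 3, 7, 1]
```

Samples 0 and 1 appear twice in this example. Metrics computed over the gathered tensors will count them twice.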
================================================
FILE: test_3d.py
================================================
from torchvision import transforms, datasets
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_3d
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_it
from config.train_test_config.train_test_config import save_test_3d
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')
parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/semi/LiTS/best_result1_Jc_0.7677.pth')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test')
parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')
parser.add_argument('--input1', default='image')
parser.add_argument('--threshold', default=None, type=float)
parser.add_argument('--patch_size', default=(112, 112, 32))
parser.add_argument('--patch_overlap', default=(56, 56, 16))
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-n', '--network', default='unet3d_min')
parser.add_argument('-ds', '--deep_supervision', default=False)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = torch.cuda.device_count()
init_seeds(rank + 1)
# Config
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7
print_num_minus = print_num - 2
# Results Save
if not os.path.exists(args.path_seg_results) and rank == args.rank_index:
os.mkdir(args.path_seg_results)
path_seg_results = args.path_seg_results + '/' + str(dataset_name)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_val = dataset_it(
data_dir=args.path_dataset + '/val',
input1=args.input1,
transform_1=data_transform['test'],
)
# Model
model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)
model = model.cuda()
# if rank == args.rank_index:
# state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank))
# model.load_state_dict(state_dict=state_dict)
# model = DistributedDataParallel(model, device_ids=[args.local_rank])
model = DistributedDataParallel(model, device_ids=[args.local_rank])
state_dict = torch.load(args.path_model, map_location='cuda:{}'.format(args.local_rank))
model.load_state_dict(state_dict=state_dict)
dist.barrier()
# Test
since = time.time()
for i, subject in enumerate(dataset_val.dataset_1):
grid_sampler = tio.inference.GridSampler(
subject=subject,
patch_size=args.patch_size,
patch_overlap=args.patch_overlap
)
# val_sampler = torch.utils.data.distributed.DistributedSampler(grid_sampler, shuffle=False)
dataloaders = dict()
dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16)
# dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler)
aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average')
with torch.no_grad():
model.eval()
for data in dataloaders['test']:
inputs_test = Variable(data['image'][tio.DATA].cuda())
location_test = data[tio.LOCATION]
outputs_test = model(inputs_test)
if args.deep_supervision:
outputs_test = outputs_test[0]
aggregator.add_batch(outputs_test, location_test)
outputs_tensor = aggregator.get_output_tensor()
save_test_3d(cfg['NUM_CLASSES'], outputs_tensor, subject['ID'], args.threshold, path_seg_results, subject['image']['affine'])
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('-' * print_num)
print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('=' * print_num)
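`tio.inference.GridSampler` tiles each subject with patches of `patch_size`, whose start positions are `patch_size - patch_overlap` apart. As far as we understand torchio's implementation, it also adds a final patch flush with the far border.

The one-axis sketch below illustrates that placement under this assumption. It is not torchio code. For a 512-voxel axis with patch 112 and overlap 56, it gives nine patch starts.

```python
def patch_starts(length, patch, overlap):
    """Start indices of sliding-window patches along one axis (assumed torchio-like)."""
    if length < patch:
        raise ValueError("volume is smaller than the patch; pad it first")
    step = patch - overlap
    starts = list(range(0, length - patch + 1, step))
    if starts[-1] != length - patch:
        starts.append(length - patch)  # last patch touches the far border
    return starts

print(patch_starts(512, 112, 56))  # [0, 56, 112, 168, 224, 280, 336, 392, 400]
```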
================================================
FILE: test_ConResNet.py
================================================
from torchvision import transforms, datasets
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_3d
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_iit_conresnet
from config.train_test_config.train_test_config import save_test_3d
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')
parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/sup/LiTS/best_conresnet_Jc_0.8545.pth')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test')
parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')
parser.add_argument('--input1', default='image')
parser.add_argument('--input2', default='image_res')
parser.add_argument('--threshold', default=None, type=float)
parser.add_argument('--patch_size', default=(112, 112, 32))
parser.add_argument('--patch_overlap', default=(56, 56, 16))
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-n', '--network', default='conresnet')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = torch.cuda.device_count()
init_seeds(rank + 1)
# Config
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7
print_num_minus = print_num - 2
# Results Save
if not os.path.exists(args.path_seg_results) and rank == args.rank_index:
os.mkdir(args.path_seg_results)
path_seg_results = args.path_seg_results + '/' + str(dataset_name)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_val = dataset_iit_conresnet(
data_dir=args.path_dataset + '/val',
input1=args.input1,
input2=args.input2,
transform_1=data_transform['test'],
)
# Model
model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)
model = model.cuda()
# if rank == args.rank_index:
# state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank))
# model.load_state_dict(state_dict=state_dict)
# model = DistributedDataParallel(model, device_ids=[args.local_rank])
model = DistributedDataParallel(model, device_ids=[args.local_rank])
state_dict = torch.load(args.path_model, map_location='cuda:{}'.format(args.local_rank))
model.load_state_dict(state_dict=state_dict)
dist.barrier()
# Test
since = time.time()
for i, subject in enumerate(dataset_val.dataset_1):
grid_sampler = tio.inference.GridSampler(
subject=subject,
patch_size=args.patch_size,
patch_overlap=args.patch_overlap
)
# val_sampler = torch.utils.data.distributed.DistributedSampler(grid_sampler, shuffle=False)
dataloaders = dict()
dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16)
# dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler)
aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average')
with torch.no_grad():
model.eval()
for data in dataloaders['test']:
inputs_test = Variable(data['image'][tio.DATA].cuda())
inputs_test_2 = Variable(data['image2'][tio.DATA].cuda())
location_test = data[tio.LOCATION]
outputs_test = model(inputs_test, inputs_test_2)
aggregator.add_batch(outputs_test[0], location_test)
outputs_tensor = aggregator.get_output_tensor()
save_test_3d(cfg['NUM_CLASSES'], outputs_tensor, subject['ID'], args.threshold, path_seg_results, subject['image']['affine'])
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('-' * print_num)
print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('=' * print_num)
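`tio.inference.GridAggregator(grid_sampler, overlap_mode='average')` stitches the patch outputs back into a full volume. Wherever patches overlap, the predictions are averaged with equal weights before the result is passed to `save_test_3d`.

The one-dimensional sketch below illustrates that averaging; it is not torchio's implementation. Each patch is a `(start, values)` pair.

```python
def average_overlapping(length, patches):
    """Average 1-D patch predictions where they overlap.

    patches: list of (start, values) pairs; values are written at
    positions start .. start + len(values) - 1.
    """
    total = [0.0] * length
    count = [0] * length
    for start, values in patches:
        for offset, v in enumerate(values):
            total[start + offset] += v
            count[start + offset] += 1
    # Positions never covered by a patch are left at 0.0.
    return [t / c if c else 0.0 for t, c in zip(total, count)]

print(average_overlapping(6, [(0, [1.0, 1.0, 1.0, 1.0]), (2, [3.0, 3.0, 3.0, 3.0])]))
# [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
```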
================================================
FILE: test_DTC.py
================================================
from torchvision import transforms, datasets
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_3d
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_it_dtc
from config.train_test_config.train_test_config import save_test_3d
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')
parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/semi/LiTS/best_DTC_Jc_0.7594.pth')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test')
parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')
parser.add_argument('--input1', default='image')
parser.add_argument('--threshold', default=None, type=float)
parser.add_argument('--patch_size', default=(112, 112, 32))
parser.add_argument('--patch_overlap', default=(56, 56, 16))
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-n', '--network', default='vnet_dtc')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = torch.cuda.device_count()
init_seeds(rank + 1)
# Config
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7
print_num_minus = print_num - 2
# Results Save
if not os.path.exists(args.path_seg_results) and rank == args.rank_index:
os.mkdir(args.path_seg_results)
path_seg_results = args.path_seg_results + '/' + str(dataset_name)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_val = dataset_it_dtc(
data_dir=args.path_dataset + '/val',
input1=args.input1,
num_classes=cfg['NUM_CLASSES'],
transform_1=data_transform['test'],
)
# Model
model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)
model = model.cuda()
# if rank == args.rank_index:
# state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank))
# model.load_state_dict(state_dict=state_dict)
# model = DistributedDataParallel(model, device_ids=[args.local_rank])
model = DistributedDataParallel(model, device_ids=[args.local_rank])
state_dict = torch.load(args.path_model, map_location='cuda:{}'.format(args.local_rank))
model.load_state_dict(state_dict=state_dict)
dist.barrier()
# Test
since = time.time()
for i, subject in enumerate(dataset_val.dataset_1):
grid_sampler = tio.inference.GridSampler(
subject=subject,
patch_size=args.patch_size,
patch_overlap=args.patch_overlap
)
# val_sampler = torch.utils.data.distributed.DistributedSampler(grid_sampler, shuffle=False)
dataloaders = dict()
dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16)
# dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler)
aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average')
with torch.no_grad():
model.eval()
for data in dataloaders['test']:
inputs_test = Variable(data['image'][tio.DATA].cuda())
location_test = data[tio.LOCATION]
outputs_test_sdf, outputs_test_seg = model(inputs_test)
aggregator.add_batch(outputs_test_seg, location_test)
outputs_tensor = aggregator.get_output_tensor()
save_test_3d(cfg['NUM_CLASSES'], outputs_tensor, subject['ID'], args.threshold, path_seg_results, subject['image']['affine'])
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('-' * print_num)
print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('=' * print_num)
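The DTC network returns two outputs, a signed-distance map and a segmentation. This script aggregates only the segmentation. In the dual-task-consistency approach, the distance output can be mapped to a foreground probability with a steep sigmoid, sigmoid(-k * d).

The sketch below makes three assumptions:
- distances are negative inside the object;
- k = 1500, the value we believe the original DTC code used;
- the repository's `tools/mask2sdf.py` uses the same sign convention, which is not verified here.

```python
import math

def sdf_to_prob(d, k=1500.0):
    """Map a signed distance (assumed negative inside the object) to a foreground probability."""
    z = -k * d
    # Numerically stable sigmoid: avoids overflow in math.exp for large |z|.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

print(sdf_to_prob(0.0), sdf_to_prob(-0.01), sdf_to_prob(0.5))
```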
================================================
FILE: test_xnet.py
================================================
from torchvision import transforms, datasets
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_iitnn
from config.train_test_config.train_test_config import print_test_eval, save_test_2d
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
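`test_xnet.py` feeds the network two inputs, `L` and `H`. In XNet these are the low- and high-frequency parts of an image obtained with a wavelet transform. They are prepared offline, presumably by `tools/wavelet2D.py`, which is not shown here.

The pure-Python sketch below performs a single level of an orthonormal 2D Haar transform. It assumes `L` is the approximation band and `H` is the element-wise sum of the three detail bands. The repository's actual wavelet, sign conventions and rescaling may differ.

```python
def haar2d(img):
    """One level of an orthonormal 2-D Haar transform on an even-sized grayscale image.

    Returns (LL, LH, HL, HH), each half the height and width of img.
    Sign conventions differ between libraries; this is one common choice.
    """
    h, w = len(img), len(img[0])
    if h % 2 or w % 2:
        raise ValueError("image height and width must be even")
    ll, lh, hl, hh = [], [], [], []
    for i in range(0, h, 2):
        rows = ([], [], [], [])
        for j in range(0, w, 2):
            a, b = img[i][j], img[i][j + 1]
            c, d = img[i + 1][j], img[i + 1][j + 1]
            rows[0].append((a + b + c + d) / 2)  # approximation
            rows[1].append((a + b - c - d) / 2)  # top row minus bottom row
            rows[2].append((a - b + c - d) / 2)  # left column minus right column
            rows[3].append((a - b - c + d) / 2)  # diagonal difference
        for band, row in zip((ll, lh, hl, hh), rows):
            band.append(row)
    return ll, lh, hl, hh

def low_high(img):
    """Hypothetical L/H split: L = LL, H = LH + HL + HH (element-wise)."""
    ll, lh, hl, hh = haar2d(img)
    high = [[x + y + z for x, y, z in zip(r1, r2, r3)] for r1, r2, r3 in zip(lh, hl, hh)]
    return ll, high

low, high = low_high([[4, 4, 0, 0], [4, 4, 0, 0], [1, 3, 5, 5], [1, 3, 5, 5]])
print(low)   # [[8.0, 0.0], [4.0, 10.0]]
print(high)  # [[0.0, 0.0], [-2.0, 0.0]]
```

Because the transform is orthonormal, the four bands together keep the image's total energy (sum of squares). Summing the three detail bands into one `H` map, however, is lossy.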
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')
parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/semi_xnet/GlaS/best_result2_Jc_0.7898.pth')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test')
parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS')
parser.add_argument('--input1', default='L')
parser.add_argument('--input2', default='H')
parser.add_argument('--if_mask', default=True)
parser.add_argument('--threshold', default=0.5400, type=float, help='e.g. 0.5600, 0.5400')
parser.add_argument('--if_cct', default=False)
parser.add_argument('--result', default='result2', help='result1, result2')
parser.add_argument('-n', '--network', default='xnet')
parser.add_argument('-b', '--batch_size', default=8, type=int)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = dist.get_world_size()  # number of processes taking part in all_gather
init_seeds(rank + 1)
# Config
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7
print_num_minus = print_num - 2
# Results Save
if not os.path.exists(args.path_seg_results) and rank == args.rank_index:
os.mkdir(args.path_seg_results)
path_seg_results = args.path_seg_results + '/' + str(dataset_name)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
# Dataset
if args.input1 == 'image':
input1_mean = 'MEAN'
input1_std = 'STD'
else:
input1_mean = 'MEAN_' + args.input1
input1_std = 'STD_' + args.input1
if args.input2 == 'image':
input2_mean = 'MEAN'
input2_std = 'STD'
else:
input2_mean = 'MEAN_' + args.input2
input2_std = 'STD_' + args.input2
data_transforms = data_transform_2d()
data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])
dataset_val = imagefloder_iitnn(
data_dir=args.path_dataset + '/val',
input1=args.input1,
input2=args.input2,
data_transform_1=data_transforms['val'],
data_normalize_1=data_normalize_1,
data_normalize_2=data_normalize_2,
sup=True,
num_images=None,
)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler)
num_batches = {'val': len(dataloaders['val'])}
# Model
model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model = model.cuda()
# if rank == args.rank_index:
# state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank))
# model.load_state_dict(state_dict=state_dict)
# model = DistributedDataParallel(model, device_ids=[args.local_rank])
model = DistributedDataParallel(model, device_ids=[args.local_rank])
state_dict = torch.load(args.path_model)
model.load_state_dict(state_dict=state_dict)
dist.barrier()
# Test
since = time.time()
with torch.no_grad():
model.eval()
for i, data in enumerate(dataloaders['val']):
inputs_test = Variable(data['image'].cuda(non_blocking=True))
inputs_wavelet_test = Variable(data['image_2'].cuda(non_blocking=True))
name_test = data['ID']
if args.if_mask:
mask_test = Variable(data['mask'].cuda(non_blocking=True))
if args.if_cct:
outputs_test1, outputs_test1_aux1, outputs_test1_aux2, outputs_test1_aux3, outputs_test2, outputs_test2_aux1, outputs_test2_aux2, outputs_test2_aux3 = model(inputs_test, inputs_wavelet_test)
else:
outputs_test1, outputs_test2 = model(inputs_test, inputs_wavelet_test)
if args.result == 'result1':
outputs_test = outputs_test1
else:
outputs_test = outputs_test2
if args.if_mask:
if i == 0:
score_list_test = outputs_test
name_list_test = name_test
mask_list_test = mask_test
else:
# elif 0 < i <= num_batches['val'] / 16:
score_list_test = torch.cat((score_list_test, outputs_test), dim=0)
name_list_test = np.append(name_list_test, name_test, axis=0)
mask_list_test = torch.cat((mask_list_test, mask_test), dim=0)
torch.cuda.empty_cache()
else:
save_test_2d(cfg['NUM_CLASSES'], outputs_test, name_test, args.threshold, path_seg_results, cfg['PALETTE'])
torch.cuda.empty_cache()
if args.if_mask:
score_gather_list_test = [torch.zeros_like(score_list_test) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_test, score_list_test)
score_list_test = torch.cat(score_gather_list_test, dim=0)
mask_gather_list_test = [torch.zeros_like(mask_list_test) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_test, mask_list_test)
mask_list_test = torch.cat(mask_gather_list_test, dim=0)
name_gather_list_test = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_test, name_list_test)
name_list_test = np.concatenate(name_gather_list_test, axis=0)
if args.if_mask and rank == args.rank_index:
print('=' * print_num)
test_eval_list = print_test_eval(cfg['NUM_CLASSES'], score_list_test, mask_list_test, print_num_minus)
save_test_2d(cfg['NUM_CLASSES'], score_list_test, name_list_test, test_eval_list[0], path_seg_results, cfg['PALETTE'])
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('-' * print_num)
print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('=' * print_num)
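Because `DistributedSampler` (default `drop_last=False`) pads its index list by repeating samples until it divides evenly across processes, the gathered `score_list_test`, `mask_list_test` and `name_list_test` can contain the same image more than once, and `print_test_eval` would then count it twice. A minimal sketch of dropping those repeats by image ID; `deduplicate_by_id` is a hypothetical helper, and it assumes IDs are unique per image and that the inputs are NumPy arrays:

```python
import numpy as np


def deduplicate_by_id(names, scores, masks):
    """Keep the first occurrence of each sample ID after gathering from all ranks.

    Hypothetical helper: assumes repeated IDs come from DistributedSampler padding,
    so the repeated entries carry the same image and the same prediction.
    """
    names = np.asarray(names)
    _, first_idx = np.unique(names, return_index=True)
    keep = np.sort(first_idx)  # restore the gathered order instead of sorted-ID order
    return names[keep], scores[keep], masks[keep]


# Example: four gathered entries where 'a' was repeated as padding.
names, scores, masks = deduplicate_by_id(
    ['a', 'b', 'c', 'a'], np.arange(4.0), np.arange(4) * 10
)
```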
================================================
FILE: test_xnet3d.py
================================================
from torchvision import transforms, datasets
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_3d
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_iit
from config.train_test_config.train_test_config import save_test_3d
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')
parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/semi_xnet/LiTS/best_result1_Jc_0.7794.pth')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test')
parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')
parser.add_argument('--input1', default='L')
parser.add_argument('--input2', default='H')
parser.add_argument('--threshold', default=None)
parser.add_argument('--result', default='result1', help='result1, result2')
    parser.add_argument('--patch_size', default=(112, 112, 32), type=int, nargs=3)
    parser.add_argument('--patch_overlap', default=(56, 56, 16), type=int, nargs=3)
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-n', '--network', default='xnet3d')
parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = torch.cuda.device_count()
init_seeds(rank + 1)
# Config
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7
print_num_minus = print_num - 2
# Results Save
if not os.path.exists(args.path_seg_results) and rank == args.rank_index:
os.mkdir(args.path_seg_results)
path_seg_results = args.path_seg_results + '/' + str(dataset_name)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_val = dataset_iit(
data_dir=args.path_dataset + '/val',
input1=args.input1,
input2=args.input2,
transform_1=data_transform['test'],
)
# Model
model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model = model.cuda()
# if rank == args.rank_index:
# state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank))
# model.load_state_dict(state_dict=state_dict)
# model = DistributedDataParallel(model, device_ids=[args.local_rank])
model = DistributedDataParallel(model, device_ids=[args.local_rank])
state_dict = torch.load(args.path_model)
model.load_state_dict(state_dict=state_dict)
dist.barrier()
# Test
since = time.time()
for i, subject in enumerate(dataset_val.dataset_1):
grid_sampler = tio.inference.GridSampler(
subject=subject,
patch_size=args.patch_size,
patch_overlap=args.patch_overlap
)
# val_sampler = torch.utils.data.distributed.DistributedSampler(grid_sampler, shuffle=False)
dataloaders = dict()
dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16)
# dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler)
aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average')
with torch.no_grad():
model.eval()
for data in dataloaders['test']:
inputs_test_1 = Variable(data['image'][tio.DATA].cuda())
inputs_test_2 = Variable(data['image2'][tio.DATA].cuda())
location_test = data[tio.LOCATION]
outputs_test_1, outputs_test_2 = model(inputs_test_1, inputs_test_2)
if args.result == 'result1':
outputs_test = outputs_test_1
else:
outputs_test = outputs_test_2
aggregator.add_batch(outputs_test, location_test)
outputs_tensor = aggregator.get_output_tensor()
save_test_3d(cfg['NUM_CLASSES'], outputs_tensor, subject['ID'], args.threshold, path_seg_results, subject['image']['affine'])
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('-' * print_num)
print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('=' * print_num)
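torchio's `GridSampler` and `GridAggregator(overlap_mode='average')` tile each volume into overlapping patches, run the network on each patch, and average the predictions wherever patches overlap. A NumPy sketch of that idea, not torchio's actual implementation; it assumes `overlap < patch` on every axis and that `predict` returns an array shaped like its input (the real network returns per-class channels):

```python
import itertools

import numpy as np


def grid_starts(size, patch, overlap):
    """Start indices along one axis; stride = patch - overlap, last patch ends at `size`."""
    stride = patch - overlap  # assumes overlap < patch
    starts = list(range(0, max(size - patch, 0) + 1, stride))
    if starts[-1] + patch < size:
        starts.append(size - patch)  # extra patch so the far border is covered
    return starts


def sliding_window_average(volume, predict, patch, overlap):
    """Apply `predict` to overlapping patches and average wherever patches overlap."""
    out = np.zeros(volume.shape, dtype=np.float64)
    count = np.zeros(volume.shape, dtype=np.float64)
    axes = [grid_starts(s, p, o) for s, p, o in zip(volume.shape, patch, overlap)]
    for start in itertools.product(*axes):
        window = tuple(slice(b, b + p) for b, p in zip(start, patch))
        out[window] += predict(volume[window])
        count[window] += 1
    return out / count


vol = np.random.default_rng(0).random((10, 11, 6))
smoothed = sliding_window_average(vol, lambda x: 2 * x, (4, 4, 4), (2, 2, 2))
```

With a voxel-wise `predict`, averaging identical overlapping outputs must reproduce the plain prediction, which is a quick check that every voxel is covered.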
================================================
FILE: tools/Atrial/__init__.py
================================================
================================================
FILE: tools/Atrial/postprocess.py
================================================
import numpy as np
import argparse
import os
import SimpleITK as sitk
from skimage.morphology import remove_small_objects, remove_small_holes
import skimage
def save_max_objects(image):
    # Keep only the largest connected component; an empty prediction stays empty
    # instead of raising ValueError on max() of an empty list.
    labeled_image = skimage.measure.label(image)
    labeled_list = skimage.measure.regionprops(labeled_image)
    if not labeled_list:
        return np.zeros_like(labeled_image)
    largest = max(labeled_list, key=lambda region: region.area)
    return (labeled_image == largest.label).astype(labeled_image.dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
    parser.add_argument('--pred_path', default='//10.0.5.233/shared_data/XNet/seg_pred/test/Atrial/best_DTC_Jc_0.8730')
    parser.add_argument('--save_path', default='//10.0.5.233/shared_data/XNet/seg_pred/test/Atrial/best_DTC_Jc_0.8730_mor')
    parser.add_argument('--fill_hole_thr', default=500, type=int, help='300-500')
args = parser.parse_args()
if not os.path.exists(args.save_path):
os.mkdir(args.save_path)
for i in os.listdir(args.pred_path):
pred_path = os.path.join(args.pred_path, i)
save_path = os.path.join(args.save_path, i)
pred = sitk.ReadImage(pred_path)
pred = sitk.GetArrayFromImage(pred)
pred = pred.astype(bool)
pred = remove_small_holes(pred, args.fill_hole_thr)
pred = pred.astype(np.uint8)
pred = save_max_objects(pred)
pred = pred.astype(np.uint8)
pred = sitk.GetImageFromArray(pred)
sitk.WriteImage(pred, save_path)
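`save_max_objects` keeps only the largest connected component of the binarized prediction. The same step can be sketched with SciPy instead of scikit-image (a swapped-in library, not what the script uses). Note that `scipy.ndimage.label` defaults to face connectivity while `skimage.measure.label` defaults to full connectivity, so the sketch passes a full structuring element to match:

```python
import numpy as np
from scipy import ndimage


def keep_largest_component(mask):
    """Return a uint8 mask holding only the largest connected foreground component."""
    mask = np.asarray(mask, dtype=bool)
    # Full connectivity (26-neighbourhood in 3D) to match skimage.measure.label's default.
    structure = ndimage.generate_binary_structure(mask.ndim, mask.ndim)
    labeled, num = ndimage.label(mask, structure=structure)
    if num == 0:
        return np.zeros(mask.shape, dtype=np.uint8)  # nothing predicted: keep it empty
    sizes = np.bincount(labeled.ravel())
    sizes[0] = 0  # ignore the background label
    return (labeled == sizes.argmax()).astype(np.uint8)
```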
================================================
FILE: tools/Atrial/preprocess.py
================================================
import numpy as np
import torchio as tio
import os
import argparse
from tqdm import tqdm
import SimpleITK as sitk
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', default='E:/Biomedical datasets/2018 Atrial Segmentation Challenge/Training Set')
parser.add_argument('--save_path', default='E:/Biomedical datasets/2018 Atrial Segmentation Challenge/dataset')
args = parser.parse_args()
if not os.path.exists(args.save_path):
os.mkdir(args.save_path)
save_image_path = args.save_path + '/image'
save_mask_path = args.save_path + '/mask'
if not os.path.exists(save_image_path):
os.mkdir(save_image_path)
if not os.path.exists(save_mask_path):
os.mkdir(save_mask_path)
for i in os.listdir(args.data_path):
save_name = i + '.nrrd'
image_path = args.data_path + '/' + i + '/' + 'lgemri.nrrd'
mask_path = args.data_path + '/' + i + '/' + 'laendo.nrrd'
image = tio.ScalarImage(image_path)
mask = tio.LabelMap(mask_path)
_, w, h, d = mask.data.shape
tempL = np.nonzero(np.array(mask.data))
minx, maxx = np.min(tempL[1]), np.max(tempL[1])
miny, maxy = np.min(tempL[2]), np.max(tempL[2])
# minz, maxz = np.min(tempL[3]), np.max(tempL[3])
px = max(112 - (maxx - minx), 0) // 2
py = max(112 - (maxy - miny), 0) // 2
# pz = max(80 - (maxz - minz), 0) // 2
minx = max(minx - np.random.randint(10, 20) - px, 0)
maxx = min(maxx + np.random.randint(10, 20) + px, w)
miny = max(miny - np.random.randint(10, 20) - py, 0)
maxy = min(maxy + np.random.randint(10, 20) + py, h)
# minz = max(minz - np.random.randint(5, 10) - pz, 0)
# maxz = min(maxz + np.random.randint(5, 10) + pz, d)
image_np = image.data[:, minx:maxx, miny:maxy, :]
image.set_data(image_np)
mask_np = mask.data[:, minx:maxx, miny:maxy, :]
mask.set_data(mask_np)
print(image_np.shape)
image.save(os.path.join(save_image_path, save_name))
mask.save(os.path.join(save_mask_path, save_name))
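The crop above widens the mask's in-plane bounding box to at least 112 voxels, adds a random margin of 10 to 19 voxels per side (`np.random.randint(10, 20)`), and clamps the result to the volume; the depth axis is left uncropped. A per-axis sketch of that arithmetic, where `padded_crop_bounds` is a hypothetical helper and the margin is passed in explicitly (fixed here for reproducibility) instead of being drawn at random:

```python
def padded_crop_bounds(lo, hi, size, min_extent=112, margin=15):
    """Return [start, stop) along one axis around the foreground range [lo, hi]."""
    pad = max(min_extent - (hi - lo), 0) // 2  # widen short boxes toward min_extent
    start = max(lo - margin - pad, 0)          # clamp to the volume
    stop = min(hi + margin + pad, size)
    return start, stop
```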
================================================
FILE: tools/LiTS/__init__.py
================================================
================================================
FILE: tools/LiTS/postprocess.py
================================================
import numpy as np
import argparse
import os
import SimpleITK as sitk
from skimage.morphology import remove_small_objects, remove_small_holes
import skimage
def save_max_objects(image):
    # Keep only the largest connected component; an empty prediction stays empty
    # instead of raising ValueError on max() of an empty list.
    labeled_image = skimage.measure.label(image)
    labeled_list = skimage.measure.regionprops(labeled_image)
    if not labeled_list:
        return np.zeros_like(labeled_image)
    largest = max(labeled_list, key=lambda region: region.area)
    return (labeled_image == largest.label).astype(labeled_image.dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--pred_path', default='//10.0.5.233/shared_data/XNet/seg_pred/test/LiTS/best_DTC_Jc_0.7594')
parser.add_argument('--save_path', default='//10.0.5.233/shared_data/XNet/seg_pred/test/LiTS/best_DTC_Jc_0.7594_mor')
    parser.add_argument('--fill_hole_thr', default=100, type=int)
args = parser.parse_args()
if not os.path.exists(args.save_path):
os.mkdir(args.save_path)
for i in os.listdir(args.pred_path):
pred_path = os.path.join(args.pred_path, i)
save_path = os.path.join(args.save_path, i)
pred = sitk.ReadImage(pred_path)
pred = sitk.GetArrayFromImage(pred)
pred_ = pred.copy()
pred_[pred != 0] = 1
pred_ = pred_.astype(bool)
pred_ = remove_small_holes(pred_, args.fill_hole_thr)
pred_ = pred_.astype(np.uint8)
pred_ = save_max_objects(pred_)
pred_[(pred_ == 1) & (pred == 2)] = 2
pred_ = pred_.astype(np.uint8)
pred_ = sitk.GetImageFromArray(pred_)
sitk.WriteImage(pred_, save_path)
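For LiTS the prediction has two foreground labels (1 = liver, 2 = tumor). The script fills small holes in the union of both labels, keeps its largest component, and then restores the tumor label only inside that component. A SciPy-based sketch of the last two steps, with hole filling omitted; `clean_liver_tumor` is a hypothetical name and `scipy.ndimage` is swapped in for scikit-image:

```python
import numpy as np
from scipy import ndimage


def clean_liver_tumor(pred, tumor_label=2):
    """Keep the largest liver+tumor component and re-apply the tumor label inside it."""
    pred = np.asarray(pred)
    foreground = pred != 0  # liver (1) and tumor (2) treated as one organ
    structure = ndimage.generate_binary_structure(pred.ndim, pred.ndim)
    labeled, num = ndimage.label(foreground, structure=structure)
    if num == 0:
        return np.zeros(pred.shape, dtype=np.uint8)
    sizes = np.bincount(labeled.ravel())
    sizes[0] = 0  # ignore the background label
    out = (labeled == sizes.argmax()).astype(np.uint8)
    out[(out == 1) & (pred == tumor_label)] = tumor_label  # tumor only inside the kept organ
    return out
```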
================================================
FILE: tools/LiTS/preprocess.py
================================================
import numpy as np
import os
import argparse
from tqdm import tqdm
import SimpleITK as sitk
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', default='E:/Biomedical datasets/LiTS')
parser.add_argument('--save_path', default='E:/Biomedical datasets/LiTS/dataset')
    parser.add_argument('--min_hu', default=-100, type=int)
    parser.add_argument('--max_hu', default=250, type=int)
    parser.add_argument('--target_spacing', default=[1.00, 1.00, 1.00], type=float, nargs=3)
    parser.add_argument('--crop_pixel', default=25, type=int)
args = parser.parse_args()
if not os.path.exists(args.save_path):
os.mkdir(args.save_path)
save_image_path = args.save_path + '/image'
save_mask_path = args.save_path + '/mask'
if not os.path.exists(save_image_path):
os.mkdir(save_image_path)
if not os.path.exists(save_mask_path):
os.mkdir(save_mask_path)
image_path = args.data_path + '/image'
mask_path = args.data_path + '/mask'
for i in os.listdir(image_path):
image_dir = os.path.join(image_path, i)
mask_dir = os.path.join(mask_path, i)
image = sitk.ReadImage(image_dir)
mask = sitk.ReadImage(mask_dir)
size = np.array(image.GetSize())
spacing = np.array(image.GetSpacing())
new_size = size * spacing / args.target_spacing
new_size = [int(s) for s in new_size]
print(new_size, size)
resample_image = sitk.ResampleImageFilter()
resample_image.SetOutputDirection(image.GetDirection())
resample_image.SetOutputOrigin(image.GetOrigin())
resample_image.SetSize(new_size)
resample_image.SetOutputSpacing(args.target_spacing)
resample_image.SetInterpolator(sitk.sitkLinear)
image = resample_image.Execute(image)
resample_mask = sitk.ResampleImageFilter()
resample_mask.SetOutputDirection(mask.GetDirection())
resample_mask.SetOutputOrigin(mask.GetOrigin())
resample_mask.SetSize(new_size)
resample_mask.SetOutputSpacing(args.target_spacing)
resample_mask.SetInterpolator(sitk.sitkNearestNeighbor)
mask = resample_mask.Execute(mask)
image_np = sitk.GetArrayFromImage(image)
mask_np = sitk.GetArrayFromImage(mask)
w, h, d = mask_np.shape
templ = np.nonzero(mask_np)
w_min = max(np.min(templ[0]) - args.crop_pixel, 0)
w_max = min(np.max(templ[0]) + args.crop_pixel, w)
h_min = max(np.min(templ[1]) - args.crop_pixel, 0)
h_max = min(np.max(templ[1]) + args.crop_pixel, h)
d_min = max(np.min(templ[2]) - args.crop_pixel, 0)
d_max = min(np.max(templ[2]) + args.crop_pixel, d)
image_np = image_np[w_min:w_max, h_min:h_max, d_min:d_max]
# image_np = image.data
image_np[image_np < args.min_hu] = args.min_hu
image_np[image_np > args.max_hu] = args.max_hu
mask_np = mask_np[w_min:w_max, h_min:h_max, d_min:d_max]
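        # The crop moves the first voxel, so the saved origin must be that voxel's physical position
        # (SimpleITK indices are (x, y, z), the reverse of the NumPy axes used for cropping).
        crop_origin = image.TransformIndexToPhysicalPoint((int(d_min), int(h_min), int(w_min)))
        image_save = sitk.GetImageFromArray(image_np)
        image_save.SetSpacing(args.target_spacing)
        image_save.SetDirection(image.GetDirection())
        image_save.SetOrigin(crop_origin)
        mask_save = sitk.GetImageFromArray(mask_np)
        mask_save.SetSpacing(args.target_spacing)
        mask_save.SetDirection(image.GetDirection())
        mask_save.SetOrigin(crop_origin)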
sitk.WriteImage(image_save, os.path.join(save_image_path, i))
sitk.WriteImage(mask_save, os.path.join(save_mask_path, i))
# image_save.save(os.path.join(save_image_path, save_name))
# mask_save.save(os.path.join(save_mask_path, save_name))
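The preprocessing keeps each scan's physical extent while resampling, so the new voxel count per axis is `size * spacing / target_spacing` (truncated to int). It then crops to the mask's bounding box plus `crop_pixel` voxels and clips intensities to the [-100, 250] HU window. A NumPy sketch of that arithmetic with hypothetical helper names; as in the script, the crop stop index is exclusive and not incremented, so `margin` must be at least 1 to keep the last foreground voxel:

```python
import numpy as np


def resampled_size(size, spacing, target_spacing):
    """Voxel count per axis after resampling, preserving physical extent (size * spacing)."""
    new_size = np.asarray(size) * np.asarray(spacing) / np.asarray(target_spacing)
    return [int(s) for s in new_size]


def clip_hu(volume, min_hu=-100, max_hu=250):
    """Clamp CT intensities to an HU window."""
    return np.clip(volume, min_hu, max_hu)


def crop_bounds(mask, margin):
    """Per-axis (start, stop) of the mask's bounding box, padded by `margin` and clamped."""
    nonzero = np.nonzero(mask)
    return [(max(int(idx.min()) - margin, 0), min(int(idx.max()) + margin, dim))
            for idx, dim in zip(nonzero, mask.shape)]
```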
================================================
FILE: tools/LiTS/split_train_val.py
================================================
import numpy as np
import os
import argparse
import shutil
import random
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--image_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/train_sup_100/image')
parser.add_argument('--mask_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/train_sup_100/mask')
parser.add_argument('--save_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/val')
    parser.add_argument('--amount', default=31, type=int)
    parser.add_argument('--random_seed', default=10, type=int)
args = parser.parse_args()
random.seed(args.random_seed)
if not os.path.exists(args.save_path):
os.mkdir(args.save_path)
save_image_path = args.save_path + '/image'
save_mask_path = args.save_path + '/mask'
if not os.path.exists(save_image_path):
os.mkdir(save_image_path)
if not os.path.exists(save_mask_path):
os.mkdir(save_mask_path)
    image_path_list = sorted(os.listdir(args.image_path))  # os.listdir order is filesystem-dependent
image_path_list = random.sample(image_path_list, args.amount)
for i in image_path_list:
shutil.move(os.path.join(args.image_path, i), save_image_path)
shutil.move(os.path.join(args.mask_path, i), save_mask_path)
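`random.seed(args.random_seed)` alone does not make this split reproducible, because `os.listdir` returns entries in an arbitrary, filesystem-dependent order. A sketch that sorts the names first and draws from a private `random.Random` instance; `split_ids` is a hypothetical helper:

```python
import random


def split_ids(names, amount, seed):
    """Pick `amount` names reproducibly, independent of the input order."""
    rng = random.Random(seed)  # private generator: does not disturb the global random state
    return rng.sample(sorted(names), amount)
```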
================================================
FILE: tools/__init__.py
================================================
================================================
FILE: tools/eval.py
================================================
from sklearn.metrics import confusion_matrix
import numpy as np
import argparse
import os
from PIL import Image
from medpy.metric.binary import hd95, assd
import albumentations as A
import SimpleITK as sitk
def eval_distance(mask_list, seg_result_list, num_classes):
print_num = 42 + (num_classes - 3) * 7
print_num_minus = print_num - 2
assert len(mask_list) == len(seg_result_list)
if num_classes == 2:
hd_list = []
sd_list = []
for i in range(len(mask_list)):
if np.any(seg_result_list[i]) and np.any(mask_list[i]):
hd_ = hd95(seg_result_list[i], mask_list[i])
sd_ = assd(seg_result_list[i], mask_list[i])
hd_list.append(hd_)
sd_list.append(sd_)
hd = np.mean(hd_list)
sd = np.mean(sd_list)
print('| Hd: {:.4f}'.format(hd).ljust(print_num_minus, ' '), '|')
print('| Sd: {:.4f}'.format(sd).ljust(print_num_minus, ' '), '|')
else:
hd_list = []
sd_list = []
for cls in range(num_classes-1):
hd_list_ = []
sd_list_ = []
for i in range(len(mask_list)):
mask_list_ = mask_list[i].copy()
seg_result_list_ = seg_result_list[i].copy()
mask_list_[mask_list[i] != (cls + 1)] = 0
seg_result_list_[seg_result_list[i] != (cls + 1)] = 0
if np.any(seg_result_list_) and np.any(mask_list_):
hd_ = hd95(seg_result_list_, mask_list_)
sd_ = assd(seg_result_list_, mask_list_)
hd_list_.append(hd_)
sd_list_.append(sd_)
hd = np.mean(hd_list_)
sd = np.mean(sd_list_)
hd_list.append(hd)
sd_list.append(sd)
hd_list = np.array(hd_list)
sd_list = np.array(sd_list)
m_hd = np.mean(hd_list)
m_sd = np.mean(sd_list)
np.set_printoptions(precision=4, suppress=True)
print('| Hd: {}'.format(hd_list).ljust(print_num_minus, ' '), '|')
print('| Sd: {}'.format(sd_list).ljust(print_num_minus, ' '), '|')
print('| mHd: {:.4f}'.format(m_hd).ljust(print_num_minus, ' '), '|')
print('| mSd: {:.4f}'.format(m_sd).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
def eval_pixel(mask_list, seg_result_list, num_classes):
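    # Fix the label set so class indices stay aligned even if a class is absent from both arrays.
    c = confusion_matrix(mask_list, seg_result_list, labels=list(range(num_classes)))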
hist_diag = np.diag(c)
hist_sum_0 = c.sum(axis=0)
hist_sum_1 = c.sum(axis=1)
jaccard = hist_diag / (hist_sum_1 + hist_sum_0 - hist_diag)
dice = 2 * hist_diag / (hist_sum_1 + hist_sum_0)
print_num = 42 + (num_classes - 3) * 7
print_num_minus = print_num - 2
print('-' * print_num)
if num_classes > 2:
m_jaccard = np.nanmean(jaccard)
m_dice = np.nanmean(dice)
np.set_printoptions(precision=4, suppress=True)
print('| Jc: {}'.format(jaccard).ljust(print_num_minus, ' '), '|')
print('| Dc: {}'.format(dice).ljust(print_num_minus, ' '), '|')
print('| mJc: {:.4f}'.format(m_jaccard).ljust(print_num_minus, ' '), '|')
print('| mDc: {:.4f}'.format(m_dice).ljust(print_num_minus, ' '), '|')
else:
print('| Jc: {:.4f}'.format(jaccard[1]).ljust(print_num_minus, ' '), '|')
print('| Dc: {:.4f}'.format(dice[1]).ljust(print_num_minus, ' '), '|')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--pred_path', default='/mnt/data1/XNet/seg_pred/test/LiTS/best_result1_Jc_0.7677_mor')
parser.add_argument('--mask_path', default='/mnt/data1/XNet/dataset/LiTS/val/mask')
    parser.add_argument('--if_3D', default=True, type=lambda s: str(s).lower() in ('true', '1', 'yes'))
    parser.add_argument('--resize_shape', default=(128, 128), type=int, nargs=2)
    parser.add_argument('--num_classes', default=3, type=int)
args = parser.parse_args()
pred_list = []
mask_list = []
pred_flatten_list = []
mask_flatten_list = []
num = 0
for i in os.listdir(args.pred_path):
pred_path = os.path.join(args.pred_path, i)
mask_path = os.path.join(args.mask_path, i)
if args.if_3D:
pred = sitk.ReadImage(pred_path)
pred = sitk.GetArrayFromImage(pred)
mask = sitk.ReadImage(mask_path)
mask = sitk.GetArrayFromImage(mask)
else:
pred = Image.open(pred_path)
# pred = pred.resize((args.resize_shape[1], args.resize_shape[0]))
pred = np.array(pred)
mask = Image.open(mask_path)
# mask = mask.resize((args.resize_shape[1], args.resize_shape[0]))
mask = np.array(mask)
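            # A.Resize takes (height, width); nearest-neighbour (0 == cv2.INTER_NEAREST) keeps pred a valid label map.
            resize = A.Resize(args.resize_shape[0], args.resize_shape[1], interpolation=0, p=1)(image=pred, mask=mask)
            mask = resize['mask']
            pred = resize['image']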
pred_list.append(pred)
mask_list.append(mask)
if num == 0:
pred_flatten_list = pred.flatten()
mask_flatten_list = mask.flatten()
else:
pred_flatten_list = np.append(pred_flatten_list, pred.flatten())
mask_flatten_list = np.append(mask_flatten_list, mask.flatten())
num += 1
eval_pixel(mask_flatten_list, pred_flatten_list, args.num_classes)
eval_distance(mask_list, pred_list, args.num_classes)
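`eval_pixel` derives per-class Jaccard, TP / (TP + FP + FN), and Dice, 2TP / (2TP + FP + FN), from a confusion matrix whose rows are ground truth and whose columns are predictions. A NumPy-only sketch of the same computation, using `np.bincount` in place of `sklearn.metrics.confusion_matrix` (a swapped-in technique) and fixing the matrix size to `num_classes`, so a class missing from both arrays yields NaN instead of shifting the indices. It assumes integer labels in `[0, num_classes)`:

```python
import numpy as np


def confusion(gt, pred, num_classes):
    """num_classes x num_classes matrix; rows = ground truth, columns = prediction."""
    idx = num_classes * np.asarray(gt).ravel() + np.asarray(pred).ravel()
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def jaccard_dice(gt, pred, num_classes):
    c = confusion(gt, pred, num_classes).astype(np.float64)
    tp = np.diag(c)
    gt_sum = c.sum(axis=1)    # TP + FN per class
    pred_sum = c.sum(axis=0)  # TP + FP per class
    with np.errstate(divide='ignore', invalid='ignore'):
        jaccard = tp / (gt_sum + pred_sum - tp)
        dice = 2 * tp / (gt_sum + pred_sum)
    return jaccard, dice


j, d = jaccard_dice([0, 0, 1, 1], [0, 1, 1, 1], num_classes=3)
```

Class 2 never appears, so its entries come out as NaN, which `np.nanmean` then skips as in `eval_pixel`.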
================================================
FILE: tools/mask2sdf.py
================================================
import numpy as np
import os
import argparse
import SimpleITK as sitk
from scipy.ndimage import distance_transform_edt
from skimage import segmentation
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/val')
    parser.add_argument('--num_classes', default=3, type=int)
args = parser.parse_args()
mask_path = args.data_path + '/mask'
for i in range(args.num_classes-1):
save_sdf_mask_path = args.data_path + '/mask_sdf' + str(i+1)
if not os.path.exists(save_sdf_mask_path):
os.mkdir(save_sdf_mask_path)
for j in os.listdir(mask_path):
mask = sitk.ReadImage(os.path.join(mask_path, j))
mask_np = sitk.GetArrayFromImage(mask)
mask_np[mask_np != (i+1)] = 0
mask_np = mask_np.astype(bool)
if mask_np.any():
mask_neg = ~mask_np
posdis = distance_transform_edt(mask_np)
negdis = distance_transform_edt(mask_neg)
boundary = segmentation.find_boundaries(mask_np, mode='inner').astype(np.uint8)
sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
sdf[boundary == 1] = 0
# sdf = ((sdf - np.min(sdf)) / (np.max(sdf) - np.min(sdf))) * 255
else:
sdf = np.zeros(mask_np.shape)
sdf = sitk.GetImageFromArray(sdf)
sdf.SetSpacing(mask.GetSpacing())
sdf.SetDirection(mask.GetDirection())
sdf.SetOrigin(mask.GetOrigin())
sitk.WriteImage(sdf, os.path.join(save_sdf_mask_path, j))
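For each foreground class the script builds a normalized signed distance map: the min-max normalized distance to the object, which is positive outside it, minus the min-max normalized distance to the background, which makes it negative inside, with the inner boundary set to 0. A sketch of that formula for one binary mask. It approximates `skimage.segmentation.find_boundaries(mode='inner')` with `mask & ~binary_erosion(mask)` (an assumption; the two may differ at the array border) and adds a guard for an all-foreground mask, where the original would divide by zero:

```python
import numpy as np
from scipy.ndimage import binary_erosion, distance_transform_edt


def normalized_sdf(mask):
    """Signed distance in [-1, 1]: positive outside, negative inside, 0 on the inner boundary."""
    mask = np.asarray(mask, dtype=bool)
    if not mask.any() or mask.all():
        return np.zeros(mask.shape, dtype=np.float64)  # no boundary to measure from
    posdis = distance_transform_edt(mask)   # inside: distance to the nearest background voxel
    negdis = distance_transform_edt(~mask)  # outside: distance to the nearest object voxel
    boundary = mask & ~binary_erosion(mask)
    sdf = ((negdis - negdis.min()) / (negdis.max() - negdis.min())
           - (posdis - posdis.min()) / (posdis.max() - posdis.min()))
    sdf[boundary] = 0
    return sdf


m = np.zeros((7, 7), dtype=bool)
m[2:5, 2:5] = True
sdf = normalized_sdf(m)
```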
================================================
FILE: tools/res_image_mask.py
================================================
import numpy as np
import os
import argparse
import SimpleITK as sitk
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', default='//10.0.5.233/shared_data/XNet/dataset/Atrial/train_sup_100')
args = parser.parse_args()
image_path = args.data_path + '/image'
mask_path = args.data_path + '/mask'
save_res_path = args.data_path + '/image_res'
save_res_mask_path = args.data_path + '/mask_res'
if not os.path.exists(save_res_path):
os.mkdir(save_res_path)
if not os.path.exists(save_res_mask_path):
os.mkdir(save_res_mask_path)
for i in os.listdir(image_path):
image = sitk.ReadImage(os.path.join(image_path, i))
image_np = sitk.GetArrayFromImage(image)
mask = sitk.ReadImage(os.path.join(mask_path, i))
mask_np = sitk.GetArrayFromImage(mask)
image_copy = np.zeros(image_np.shape)
image_copy[1:, :, :] = image_np[0:image_np.shape[0] - 1, :, :]
image_res = image_np - image_copy
image_res[0, :, :] = 0
image_res = np.abs(image_res)
image_res = sitk.GetImageFromArray(image_res)
image_res.SetSpacing(image.GetSpacing())
image_res.SetDirection(image.GetDirection())
image_res.SetOrigin(image.GetOrigin())
mask_copy = np.zeros(mask_np.shape)
mask_copy[1:, :, :] = mask_np[0:mask_np.shape[0] - 1, :, :]
mask_res = mask_np - mask_copy
mask_res[0, :, :] = 0
mask_res = np.abs(mask_res)
mask_res = sitk.GetImageFromArray(mask_res)
mask_res.SetSpacing(image.GetSpacing())
mask_res.SetDirection(image.GetDirection())
mask_res.SetOrigin(image.GetOrigin())
sitk.WriteImage(image_res, os.path.join(save_res_path, i))
sitk.WriteImage(mask_res, os.path.join(save_res_mask_path, i))
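The residual image is the absolute difference between consecutive slices along the first array axis (the z axis of a SimpleITK array), with the first slice set to zero. An equivalent NumPy sketch using `np.diff`; casting to float first matters because subtracting unsigned integer arrays would wrap around (the script avoids this by subtracting a float `np.zeros` copy):

```python
import numpy as np


def slice_residual(volume):
    """res[k] = |v[k] - v[k-1]| along axis 0, with res[0] = 0."""
    volume = np.asarray(volume, dtype=np.float64)  # avoid uint8 wrap-around
    res = np.zeros_like(volume)
    res[1:] = np.abs(np.diff(volume, axis=0))
    return res
```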
================================================
FILE: tools/wavelet2D.py
================================================
import numpy as np
from PIL import Image
import pywt
import argparse
import os
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--image_path', default='//10.0.5.233/shared_data/XNet/dataset/CREMI/train_unsup_80/image')
parser.add_argument('--L_path', default='//10.0.5.233/shared_data/XNet/dataset/CREMI/train_unsup_80/L')
parser.add_argument('--H_path', default='//10.0.5.233/shared_data/XNet/dataset/CREMI/train_unsup_80/H')
parser.add_argument('--wavelet_type', default='db2', help='haar, db2, bior1.5, bior2.4, coif1, dmey')
    parser.add_argument('--if_RGB', default=False, type=lambda s: str(s).lower() in ('true', '1', 'yes'))
args = parser.parse_args()
if not os.path.exists(args.L_path):
os.mkdir(args.L_path)
if not os.path.exists(args.H_path):
os.mkdir(args.H_path)
for i in os.listdir(args.image_path):
image_path = os.path.join(args.image_path, i)
L_path = os.path.join(args.L_path, i)
H_path = os.path.join(args.H_path, i)
if args.if_RGB:
image = Image.open(image_path).convert('L')
else:
image = Image.open(image_path)
image = np.array(image)
LL, (LH, HL, HH) = pywt.dwt2(image, args.wavelet_type)
LL = (LL - LL.min()) / (LL.max() - LL.min()) * 255
LL = Image.fromarray(LL.astype(np.uint8))
LL.save(L_path)
LH = (LH - LH.min()) / (LH.max() - LH.min()) * 255
HL = (HL - HL.min()) / (HL.max() - HL.min()) * 255
HH = (HH - HH.min()) / (HH.max() - HH.min()) * 255
merge1 = HH + HL + LH
merge1 = (merge1-merge1.min()) / (merge1.max()-merge1.min()) * 255
merge1 = Image.fromarray(merge1.astype(np.uint8))
merge1.save(H_path)
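The high-frequency `H` image is the sum of the three min-max normalized detail subbands from `pywt.dwt2`, renormalized to 0 to 255, while `L` is the normalized approximation `LL`. A sketch of the normalization and merge with a guard for constant subbands (for example an all-black tile), where the script's min-max scaling would divide by zero; `minmax_255` and `high_frequency_image` are hypothetical helpers that take the `LH, HL, HH` arrays the script computes:

```python
import numpy as np


def minmax_255(x):
    """Scale to [0, 255]; a constant array maps to zeros instead of NaN."""
    x = np.asarray(x, dtype=np.float64)
    span = x.max() - x.min()
    if span == 0:
        return np.zeros_like(x)
    return (x - x.min()) / span * 255


def high_frequency_image(LH, HL, HH):
    """Merge normalized detail subbands into one uint8 image, as the script does for H."""
    merged = minmax_255(LH) + minmax_255(HL) + minmax_255(HH)
    return minmax_255(merged).astype(np.uint8)


band = np.arange(4.0).reshape(2, 2)
h = high_frequency_image(band, band, band)
```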
================================================
FILE: tools/wavelet3D.py
================================================
import numpy as np
from PIL import Image
import pywt
import argparse
import os
import SimpleITK as sitk
import torchio as tio
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--image_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/val/image')
parser.add_argument('--L_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/train_sup_100/L')
parser.add_argument('--H_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/train_sup_100/H')
parser.add_argument('--wavelet_type', default='db2', help='haar, db2, bior1.5, bior2.4, coif1, dmey')
args = parser.parse_args()
if not os.path.exists(args.L_path):
os.mkdir(args.L_path)
if not os.path.exists(args.H_path):
os.mkdir(args.H_path)
for i in os.listdir(args.image_path):
image_path = os.path.join(args.image_path, i)
L_path = os.path.join(args.L_path, i)
H_path = os.path.join(args.H_path, i)
image = sitk.ReadImage(image_path)
image_np = sitk.GetArrayFromImage(image)
image_wave = pywt.dwtn(image_np, args.wavelet_type)
LLL = image_wave['aaa']
LLH = image_wave['aad']
LHL = image_wave['ada']
LHH = image_wave['add']
HLL = image_wave['daa']
HLH = image_wave['dad']
HHL = image_wave['dda']
HHH = image_wave['ddd']
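        LLL = (LLL - LLL.min()) / (LLL.max() - LLL.min()) * 255
        LLL = sitk.GetImageFromArray(LLL)  # ResampleImageFilter.Execute needs a SimpleITK image, not a NumPy array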
resample_image = sitk.ResampleImageFilter()
resample_image.SetSize(image.GetSize())
resample_image.SetOutputSpacing([0.5, 0.5, 0.5])
resample_image.SetInterpolator(sitk.sitkLinear)
LLL = resample_image.Execute(LLL)
LLL.SetSpacing(image.GetSpacing())
LLL.SetDirection(image.GetDirection())
LLL.SetOrigin(image.GetOrigin())
sitk.WriteImage(LLL, L_path)
LLH = (LLH - LLH.min()) / (LLH.max() - LLH.min()) * 255
LHL = (LHL - LHL.min()) / (LHL.max() - LHL.min()) * 255
LHH = (LHH - LHH.min()) / (LHH.max() - LHH.min()) * 255
HLL = (HLL - HLL.min()) / (HLL.max() - HLL.min()) * 255
HLH = (HLH - HLH.min()) / (HLH.max() - HLH.min()) * 255
HHL = (HHL - HHL.min()) / (HHL.max() - HHL.min()) * 255
HHH = (HHH - HHH.min()) / (HHH.max() - HHH.min()) * 255
merge1 = LLH + LHL + LHH + HLL + HLH + HHL + HHH
merge1 = (merge1 - merge1.min()) / (merge1.max() - merge1.min()) * 255
merge1 = sitk.GetImageFromArray(merge1)
resample_image = sitk.ResampleImageFilter()
resample_image.SetSize(image.GetSize())
resample_image.SetOutputSpacing([0.5, 0.5, 0.5])
resample_image.SetInterpolator(sitk.sitkLinear)
merge1 = resample_image.Execute(merge1)
merge1.SetSpacing(image.GetSpacing())
merge1.SetDirection(image.GetDirection())
merge1.SetOrigin(image.GetOrigin())
sitk.WriteImage(merge1, H_path)
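In `pywt.dwtn` output each key has one letter per axis, `'a'` for approximation and `'d'` for detail, so `'aaa'` is the low-frequency volume and the other seven keys are the detail subbands merged into `H`. A sketch of that merge plus resizing the result back to the input grid, swapping in `scipy.ndimage.zoom` for the script's SimpleITK resampling (an assumption, not the repository's method). Pass `image_np.shape`, which is in NumPy (z, y, x) order, not `image.GetSize()`, which SimpleITK reports in (x, y, z) order:

```python
import numpy as np
from scipy.ndimage import zoom


def minmax_255(x):
    """Scale to [0, 255]; a constant array maps to zeros instead of NaN."""
    x = np.asarray(x, dtype=np.float64)
    span = x.max() - x.min()
    return np.zeros_like(x) if span == 0 else (x - x.min()) / span * 255


def merge_details(coeffs):
    """Sum the normalized detail subbands of a pywt.dwtn result (every key except 'aaa')."""
    details = [minmax_255(band) for key, band in coeffs.items() if 'd' in key]
    return minmax_255(np.sum(details, axis=0))


def resize_to(volume, shape):
    """Resample `volume` to `shape` with linear (order=1) spline interpolation."""
    factors = [t / s for t, s in zip(shape, volume.shape)]
    return zoom(volume, factors, order=1)


keys = ['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd']
coeffs = {k: np.arange(27.0).reshape(3, 3, 3) for k in keys}
merged = merge_details(coeffs)
```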
================================================
FILE: train_semi_CCT.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from models.getnetwork import get_network
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
from PIL import Image
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import sys
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from dataload.dataset_2d import imagefloder_itn
from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM, visual_image_sup
from config.warmup_config.warmup import GradualWarmupScheduler
from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
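# Sketch (not part of the original script) of the unsupervised CCT term computed in the
# training loop below: softmax each decoder output, take the mean squared difference between
# the first output (treated as the main decoder) and each auxiliary output, average the three
# terms, and scale by a weight that grows linearly with the epoch. This pure-Python version
# works on lists of per-pixel logits; softmax, cct_consistency and rampup_weight are
# hypothetical names, and the real script does the same arithmetic on CUDA tensors.
import math


def softmax(logits):
    # Numerically stable softmax over one pixel's class logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def cct_consistency(main_logits, aux_logits_list):
    # main_logits: list of pixels, each a list of class logits.
    # aux_logits_list: one such list per auxiliary decoder.
    main_prob = [softmax(p) for p in main_logits]
    losses = []
    for aux_logits in aux_logits_list:
        aux_prob = [softmax(p) for p in aux_logits]
        sq = [(a - b) ** 2 for pm, pa in zip(main_prob, aux_prob) for a, b in zip(pm, pa)]
        losses.append(sum(sq) / len(sq))  # like torch.mean over every element
    return sum(losses) / len(losses)


def rampup_weight(base_weight, epoch, num_epochs):
    # Mirrors unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
    return base_weight * (epoch + 1) / num_epochs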
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')
parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=2, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.5, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-u', '--unsup_weight', default=1, type=float)
parser.add_argument('--loss', default='dice')
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='unet_cct', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int)
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = torch.cuda.device_count()
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
# trained model save
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + 'CCT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
# seg results save
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + 'CCT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
# vis
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-CCT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)
if args.input1 == 'image':
input1_mean = 'MEAN'
input1_std = 'STD'
else:
input1_mean = 'MEAN_' + args.input1
input1_std = 'STD_' + args.input1
data_transforms = data_transform_2d()
data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
dataset_train_unsup = imagefloder_itn(
data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
input1=args.input1,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize,
sup=False,
num_images=None,
)
num_images_unsup = len(dataset_train_unsup)
dataset_train_sup = imagefloder_itn(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize,
sup=True,
num_images=num_images_unsup,
)
dataset_val = imagefloder_itn(
data_dir=args.path_dataset + '/val',
input1=args.input1,
data_transform_1=data_transforms['val'],
data_normalize_1=data_normalize,
sup=True,
num_images=None,
)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model1 = model1.cuda()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
dist.barrier()
criterion = segmentation_loss(args.loss, False).cuda()
kl_distance = nn.KLDivLoss(reduction='none')
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
since = time.time()
count_iter = 0
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
train_loss_sup_1 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup_1 = unsup_index['image']
img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))
optimizer1.zero_grad()
pred_train_unsup1, pred_train_unsup2, pred_train_unsup3, pred_train_unsup4 = model1(img_train_unsup_1)
pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)
pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1)
pred_train_unsup3 = torch.softmax(pred_train_unsup3, 1)
pred_train_unsup4 = torch.softmax(pred_train_unsup4, 1)
consistency_loss_aux1 = torch.mean((pred_train_unsup1 - pred_train_unsup2) ** 2)
consistency_loss_aux2 = torch.mean((pred_train_unsup1 - pred_train_unsup3) ** 2)
consistency_loss_aux3 = torch.mean((pred_train_unsup1 - pred_train_unsup4) ** 2)
loss_train_unsup = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3) / 3
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup = sup_index['image']
img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))
mask_train_sup = sup_index['mask']
mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))
pred_train_sup1, pred_train_sup2, pred_train_sup3, pred_train_sup4 = model1(img_train_sup)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = (criterion(pred_train_sup1, mask_train_sup) + criterion(pred_train_sup2, mask_train_sup) + criterion(pred_train_sup3, mask_train_sup) + criterion(pred_train_sup4, mask_train_sup)) / 4
loss_train_sup = loss_train_sup1
loss_train_sup.backward()
optimizer1.step()
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)
train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val'] / 16:
inputs_val = Variable(data['image'].cuda(non_blocking=True))
mask_val = Variable(data['mask'].cuda(non_blocking=True))
name_val = data['ID']
optimizer1.zero_grad()
outputs_val1, outputs_val2, outputs_val3, outputs_val4 = model1(inputs_val)
torch.cuda.empty_cache()
if i == 0:
score_list_val1 = outputs_val1
mask_list_val = mask_val
name_list_val = name_val
else:
score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
name_list_val = np.append(name_list_val, name_val, axis=0)
loss_val_sup1 = criterion(outputs_val1, mask_val)
val_loss_sup_1 += loss_val_sup1.item()
torch.cuda.empty_cache()
score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val1, score_list_val1)
score_list_val1 = torch.cat(score_gather_list_val1, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
name_gather_list_val = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
name_list_val = np.concatenate(name_gather_list_val, axis=0)
if rank == args.rank_index:
val_epoch_loss_sup1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)
val_eval_list1, val_m_jc1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val1, mask_list_val, print_num_minus)
best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val1, name_list_val, val_eval_list1, path_trained_models, path_seg_results, cfg['PALETTE'], 'CCT')
torch.cuda.empty_cache()
if args.vis:
draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, train_eval_list1, val_eval_list1)
visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_m_jc1)
visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
print('=' * print_num)
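# Sketch (not part of the original script): how many batches each rank sees. With
# DistributedSampler's default drop_last=False, every rank draws ceil(N / world_size) indices
# (the sampler pads by repeating samples), and a DataLoader with drop_last=False then yields
# ceil(per_rank / batch_size) batches. So num_batches above is a per-rank count, and the
# all_gather'ed score lists can contain a few repeated samples when N is not divisible by the
# world size. all_gather also expects one tensor slot per rank; the script sizes the list
# with ngpus_per_node = torch.cuda.device_count(), which equals the world size only when one
# process is launched per visible GPU on a single node. batches_per_rank and padded_total
# are hypothetical helpers.
import math


def batches_per_rank(num_samples, world_size, batch_size):
    per_rank = math.ceil(num_samples / world_size)
    return math.ceil(per_rank / batch_size)


def padded_total(num_samples, world_size):
    # Number of predictions gathered across all ranks, including padding repeats.
    return math.ceil(num_samples / world_size) * world_size


# Hypothetical example: 85 validation images, 4 GPUs, batch_size=2
# -> 22 indices and 11 batches per rank, 88 gathered predictions (3 padding repeats).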
================================================
FILE: train_semi_CCT_3d.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup
from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_it
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
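# Sketch (not part of the original script): rough size of the torchio patch queues configured
# below. Assumption: dataset_it builds a torchio.Queue (exposed as queue_train_set_1) from the
# subjects in data_dir. In the torchio versions we know, one pass over a Queue yields
# num_subjects * samples_per_volume patches, and at most queue_length patches are buffered at
# once. The memory estimate assumes single-channel float32 image patches only; masks and
# extra channels add to it. patches_per_epoch and queue_buffer_bytes are hypothetical helpers.


def patches_per_epoch(num_subjects, samples_per_volume):
    return num_subjects * samples_per_volume


def queue_buffer_bytes(patch_size, queue_length, channels=1, bytes_per_voxel=4):
    # Bytes held by a full queue of image patches.
    voxels = 1
    for s in patch_size:
        voxels *= s
    return voxels * channels * bytes_per_voxel * queue_length


# With the defaults patch_size=(112, 112, 32) and queue_length=48, a full queue holds
# 77,070,336 bytes (about 73.5 MiB) of image data; 20 hypothetical volumes with
# samples_per_volume_train=8 give 160 patches per epoch before DistributedSampler splits them.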
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')
parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.1, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-c', '--unsup_weight', default=1, type=float)
parser.add_argument('--loss', default='dice', type=str)
parser.add_argument('--patch_size', default=(112, 112, 32))
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('--queue_length', default=48, type=int)
parser.add_argument('--samples_per_volume_train', default=8, type=int)
parser.add_argument('--samples_per_volume_val', default=12, type=int)
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='unet3d_cct', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int, help='16672')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = torch.cuda.device_count()
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + 'CCT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + 'CCT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_mask_results = path_seg_results + '/mask'
if not os.path.exists(path_mask_results) and rank == args.rank_index:
os.mkdir(path_mask_results)
path_seg_results = path_seg_results + '/pred'
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-CCT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight)+ '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)
# Dataset
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_train_unsup = dataset_it(
data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
input1=args.input1,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=False,
num_images=None
)
num_images_unsup = len(dataset_train_unsup.dataset_1)
dataset_train_sup = dataset_it(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=True,
num_images=num_images_unsup
)
dataset_val = dataset_it(
data_dir=args.path_dataset + '/val',
input1=args.input1,
transform_1=data_transform['val'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_val,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=False,
shuffle_patches=False,
sup=True,
num_images=None
)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
# Model
model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model1 = model1.cuda()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank], find_unused_parameters=True)
dist.barrier()
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
kl_distance = nn.KLDivLoss(reduction='none')
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
# Train & Val
since = time.time()
count_iter = 0
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
train_loss_sup_1 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())
optimizer1.zero_grad()
pred_train_unsup1, pred_train_unsup2, pred_train_unsup3, pred_train_unsup4 = model1(img_train_unsup_1)
pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)
pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1)
pred_train_unsup3 = torch.softmax(pred_train_unsup3, 1)
pred_train_unsup4 = torch.softmax(pred_train_unsup4, 1)
consistency_loss_aux1 = torch.mean((pred_train_unsup1 - pred_train_unsup2) ** 2)
consistency_loss_aux2 = torch.mean((pred_train_unsup1 - pred_train_unsup3) ** 2)
consistency_loss_aux3 = torch.mean((pred_train_unsup1 - pred_train_unsup4) ** 2)
loss_train_unsup = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3) / 3
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())
mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())
pred_train_sup1, pred_train_sup2, pred_train_sup3, pred_train_sup4 = model1(img_train_sup_1)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = (criterion(pred_train_sup1, mask_train_sup)+criterion(pred_train_sup2, mask_train_sup)+criterion(pred_train_sup3, mask_train_sup)+criterion(pred_train_sup4, mask_train_sup)) / 4
loss_train_sup = loss_train_sup1
loss_train_sup.backward()
optimizer1.step()
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)
train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val_1 = Variable(data['image'][tio.DATA].cuda())
mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
optimizer1.zero_grad()
outputs_val_1, outputs_val_2, outputs_val_3, outputs_val_4 = model1(inputs_val_1)
torch.cuda.empty_cache()
if i == 0:
score_list_val_1 = outputs_val_1
mask_list_val = mask_val
else:
score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
loss_val_sup_1 = criterion(outputs_val_1, mask_val)
val_loss_sup_1 += loss_val_sup_1.item()
torch.cuda.empty_cache()
score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss_sup_1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)
val_eval_list_1, val_m_jc_1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val_1, mask_list_val, print_num_minus)
best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val_1, mask_list_val, val_eval_list_1, path_trained_models, path_seg_results, path_mask_results, 'CCT', cfg['FORMAT'])
torch.cuda.empty_cache()
if args.vis:
visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_m_jc_1)
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
print('=' * print_num)
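# Sketch (not part of the original script) of the learning-rate curve the scheduler pair above
# is meant to produce: a linear warm-up over warm_up_duration epochs (GradualWarmupScheduler
# with multiplier=1.0), then StepLR decay by gamma every step_size epochs, counted from the
# end of the warm-up. This is a simplified model; the repository's GradualWarmupScheduler
# (config/warmup_config/warmup.py) may differ by an epoch at the boundaries. The SGD
# weight decay in these scripts is 5 * 10 ** args.wd, i.e. 5e-5 for the default wd=-5.
# lr_at_epoch is a hypothetical helper.


def lr_at_epoch(epoch, base_lr, warmup, step_size, gamma):
    if epoch < warmup:
        # Linear ramp reaching base_lr on the last warm-up epoch.
        return base_lr * (epoch + 1) / warmup
    # StepLR phase: multiply by gamma once every step_size epochs after warm-up.
    return base_lr * gamma ** ((epoch - warmup) // step_size)


# With the 3D defaults (lr=0.1, warm-up 20, step_size=50, gamma=0.5, 200 epochs) the model
# predicts decays at epochs 70, 120 and 170, ending at 0.0125.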
================================================
FILE: train_semi_CPS.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from models.getnetwork import get_network
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
from PIL import Image
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import sys
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from dataload.dataset_2d import imagefloder_itn
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
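# Sketch (not part of the original script): argparse passes a non-string default through
# unchanged and applies `type` only to strings (command-line values and string defaults).
# A flag declared with an int or bool default but no `type` therefore arrives as a string
# when overridden. A string rank such as '0' never equals the integer from get_rank(), and
# any non-empty string, including 'False', is truthy. str2bool and build_parser are
# hypothetical helpers showing typed versions of flags used in these scripts.
import argparse


def str2bool(value):
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('1', 'true', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, type=str2bool, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, type=int)
    return parser


# build_parser().parse_args(['-v', 'False', '--rank_index', '1']) gives vis=False and
# rank_index=1 (an int), while omitted flags keep their typed defaults.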
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')
parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=2, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.5, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-u', '--unsup_weight', default=5, type=float)
parser.add_argument('--loss', default='dice')
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='xnet_sb', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int)
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
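# all_gather needs one output slot per process, so size the gather lists by the world size
ngpus_per_node = torch.distributed.get_world_size()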
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
# trained model save
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models+'/'+'CPS'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
# seg results save
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results+'/'+'CPS'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
# vis
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-CPS-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1))
visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)
if args.input1 == 'image':
input1_mean = 'MEAN'
input1_std = 'STD'
else:
input1_mean = 'MEAN_' + args.input1
input1_std = 'STD_' + args.input1
data_transforms = data_transform_2d()
data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
dataset_train_unsup = imagefloder_itn(
data_dir=args.path_dataset + '/train_unsup_'+args.unsup_mark,
input1=args.input1,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize,
sup=False,
num_images=None,
)
num_images_unsup = len(dataset_train_unsup)
dataset_train_sup = imagefloder_itn(
data_dir=args.path_dataset + '/train_sup_'+args.sup_mark,
input1=args.input1,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize,
sup=True,
num_images=num_images_unsup,
)
dataset_val = imagefloder_itn(
data_dir=args.path_dataset + '/val',
input1=args.input1,
data_transform_1=data_transforms['val'],
data_normalize_1=data_normalize,
sup=True,
num_images=None,
)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model1 = model1.cuda()
model2 = model2.cuda()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])
dist.barrier()
criterion = segmentation_loss(args.loss, False).cuda()
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)
exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2)
since = time.time()
count_iter = 0
best_model = model1
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter-1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
model2.train()
train_loss_sup_1 = 0.0
train_loss_sup_2 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
val_loss_sup_2 = 0.0
unsup_weight = args.unsup_weight * (epoch+1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup = unsup_index['image']
img_train_unsup = Variable(img_train_unsup.cuda(non_blocking=True))
optimizer1.zero_grad()
optimizer2.zero_grad()
pred_train_unsup1 = model1(img_train_unsup)
pred_train_unsup2 = model2(img_train_unsup)
max_train1 = torch.max(pred_train_unsup1, dim=1)[1]
max_train2 = torch.max(pred_train_unsup2, dim=1)[1]
max_train1 = max_train1.long()
max_train2 = max_train2.long()
loss_train_unsup = criterion(pred_train_unsup1, max_train2) + criterion(pred_train_unsup2, max_train1)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup = sup_index['image']
img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))
mask_train_sup = sup_index['mask']
mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))
pred_train_sup1 = model1(img_train_sup)
pred_train_sup2 = model2(img_train_sup)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
score_list_train2 = pred_train_sup2
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)
loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)
loss_train_sup = loss_train_sup1 + loss_train_sup2
loss_train_sup.backward()
optimizer1.step()
optimizer2.step()
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss_sup_2 += loss_train_sup2.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
scheduler_warmup2.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train2, score_list_train2)
score_list_train2 = torch.cat(score_gather_list_train2, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch+1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
model2.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val'] / 16:
inputs_val = Variable(data['image'].cuda(non_blocking=True))
mask_val = Variable(data['mask'].cuda(non_blocking=True))
name_val = data['ID']
optimizer1.zero_grad()
optimizer2.zero_grad()
outputs_val1 = model1(inputs_val)
outputs_val2 = model2(inputs_val)
torch.cuda.empty_cache()
if i == 0:
score_list_val1 = outputs_val1
score_list_val2 = outputs_val2
mask_list_val = mask_val
name_list_val = name_val
else:
score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
name_list_val = np.append(name_list_val, name_val, axis=0)
loss_val_sup1 = criterion(outputs_val1, mask_val)
loss_val_sup2 = criterion(outputs_val2, mask_val)
val_loss_sup_1 += loss_val_sup1.item()
val_loss_sup_2 += loss_val_sup2.item()
torch.cuda.empty_cache()
score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val1, score_list_val1)
score_list_val1 = torch.cat(score_gather_list_val1, dim=0)
score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val2, score_list_val2)
score_list_val2 = torch.cat(score_gather_list_val2, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
name_gather_list_val = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
name_list_val = np.concatenate(name_gather_list_val, axis=0)
if rank == args.rank_index:
val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)
best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])
torch.cuda.empty_cache()
if args.vis:
draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, pred_train_sup2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2)
visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)
visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5])
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time()-begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
print('=' * print_num)
================================================
FILE: train_semi_CPS_3d.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_3d, print_best
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_it
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
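# Sketch of how torch.utils.data.distributed.DistributedSampler splits indices across
# processes (shown for shuffle=False, drop_last=False): the index list is padded by
# repeating it from the start until its length is a multiple of the world size, then
# rank r takes every world_size-th index starting at r. Every rank therefore gets the
# same number of samples, so the per-rank score/mask tensors passed to all_gather below
# have matching shapes. all_gather expects one output slot per process, i.e.
# dist.get_world_size(), which equals torch.cuda.device_count() only when one process is
# launched per visible GPU on a single node. `shard_indices` is a hypothetical name.
import math


def shard_indices(num_samples, world_size, rank):
    """Indices that process `rank` would see, mirroring DistributedSampler without shuffling."""
    num_per_rank = math.ceil(num_samples / world_size)
    total_size = num_per_rank * world_size
    indices = list(range(num_samples))
    padding = total_size - len(indices)
    # Pad by repeating the list from its start, as the sampler does.
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    return indices[rank:total_size:world_size]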
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
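# Sketch of the cross pseudo supervision (CPS) term computed in the training loop below:
# each network's hard prediction (argmax over the class axis) is the target for the other
# network. This NumPy version swaps in plain softmax cross-entropy for clarity; the script
# itself uses segmentation_loss(args.loss, ...) (dice by default). In PyTorch the argmax
# labels carry no gradient, so they act as constant targets. `cps_loss` and
# `softmax_cross_entropy` are hypothetical names.
import numpy as np


def softmax_cross_entropy(logits, target):
    """Mean cross-entropy; logits shaped (N, C, ...), integer target shaped (N, ...)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick log p(target class) at every pixel/voxel.
    picked = np.take_along_axis(log_probs, np.expand_dims(target, 1), axis=1)
    return -picked.mean()


def cps_loss(logits1, logits2, weight=1.0):
    pseudo1 = logits1.argmax(axis=1)  # hard labels from model 1
    pseudo2 = logits2.argmax(axis=1)  # hard labels from model 2
    # Model 1 is supervised by model 2's labels and vice versa.
    loss = softmax_cross_entropy(logits1, pseudo2) + softmax_cross_entropy(logits2, pseudo1)
    return weight * loss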
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')
parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.1, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-c', '--unsup_weight', default=1, type=float)
parser.add_argument('--loss', default='dice', type=str)
parser.add_argument('--patch_size', default=(112, 112, 32), type=int, nargs=3)
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('--queue_length', default=48, type=int)
parser.add_argument('--samples_per_volume_train', default=8, type=int)
parser.add_argument('--samples_per_volume_val', default=12, type=int)
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='unet3d_min', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int, help='visdom server port')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
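# all_gather needs one output slot per process, so size the gather lists by the world size
ngpus_per_node = torch.distributed.get_world_size()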
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
path_trained_models = args.path_trained_models+'/'+str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models+'/'+'CPS'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s=' + str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w=' + str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results+'/' +str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results+'/'+'CPS'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_mask_results = path_seg_results + '/mask'
if not os.path.exists(path_mask_results) and rank == args.rank_index:
os.mkdir(path_mask_results)
path_seg_results = path_seg_results + '/pred'
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-CPS-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1))
visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)
# Dataset
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_train_unsup = dataset_it(
data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
input1=args.input1,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=False,
num_images=None
)
num_images_unsup = len(dataset_train_unsup.dataset_1)
dataset_train_sup = dataset_it(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=True,
num_images=num_images_unsup
)
dataset_val = dataset_it(
data_dir=args.path_dataset + '/val',
input1=args.input1,
transform_1=data_transform['val'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_val,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=False,
shuffle_patches=False,
sup=True,
num_images=None
)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
# Model
model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model1 = model1.cuda()
model2 = model2.cuda()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])
dist.barrier()
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)
exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2)
# Train & Val
since = time.time()
count_iter = 0
best_model = model1
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
model2.train()
train_loss_sup_1 = 0.0
train_loss_sup_2 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
val_loss_sup_2 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())
optimizer1.zero_grad()
optimizer2.zero_grad()
pred_train_unsup1 = model1(img_train_unsup_1)
pred_train_unsup2 = model2(img_train_unsup_1)
max_train_unsup1 = torch.max(pred_train_unsup1, dim=1)[1]
max_train_unsup2 = torch.max(pred_train_unsup2, dim=1)[1]
max_train_unsup1 = max_train_unsup1.long()
max_train_unsup2 = max_train_unsup2.long()
loss_train_unsup = criterion(pred_train_unsup1, max_train_unsup2) + criterion(pred_train_unsup2, max_train_unsup1)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())
mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())
pred_train_sup1 = model1(img_train_sup_1)
pred_train_sup2 = model2(img_train_sup_1)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
score_list_train2 = pred_train_sup2
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)
loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)
loss_train_sup = loss_train_sup1 + loss_train_sup2
loss_train_sup.backward()
optimizer1.step()
optimizer2.step()
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss_sup_2 += loss_train_sup2.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
scheduler_warmup2.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train2, score_list_train2)
score_list_train2 = torch.cat(score_gather_list_train2, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
train_eval_list_1, train_eval_list_2, train_m_jc_1, train_m_jc_2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
model2.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val_1 = Variable(data['image'][tio.DATA].cuda())
mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
optimizer1.zero_grad()
optimizer2.zero_grad()
outputs_val_1 = model1(inputs_val_1)
outputs_val_2 = model2(inputs_val_1)
torch.cuda.empty_cache()
if i == 0:
score_list_val_1 = outputs_val_1
score_list_val_2 = outputs_val_2
mask_list_val = mask_val
else:
score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)
score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
loss_val_sup_1 = criterion(outputs_val_1, mask_val)
loss_val_sup_2 = criterion(outputs_val_2, mask_val)
val_loss_sup_1 += loss_val_sup_1.item()
val_loss_sup_2 += loss_val_sup_2.item()
torch.cuda.empty_cache()
score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)
score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)
score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)
best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])
torch.cuda.empty_cache()
if args.vis:
visualization_XNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_m_jc_1, train_m_jc_2, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
print('=' * print_num)
================================================
FILE: train_semi_CT.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
from PIL import Image
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import sys
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_itn
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
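# Hypothetical helper (not part of the repo): the checkpoint directory, the segmentation
# result directory and the visdom environment below are each built by concatenating the
# same hyperparameters by hand, so a key can silently drift between them (for example
# '-cw' versus '-cw='). Building the suffix once keeps the names identical, e.g.
#   path_trained_models + '/' + build_run_name('CT', args.lr, args.num_epochs, ...)
def build_run_name(prefix, lr, num_epochs, step_size, gamma, batch_size,
                   unsup_weight, warm_up_duration, sup_mark, unsup_mark, input1):
    """build_run_name('CT', 0.5, 200, 50, 0.5, 2, 1, 20, '20', '80', 'image')
    returns 'CT-l=0.5-e=200-s=50-g=0.5-b=2-cw=1-w=20-20-80-image'."""
    fields = [('l', lr), ('e', num_epochs), ('s', step_size), ('g', gamma),
              ('b', batch_size), ('cw', unsup_weight), ('w', warm_up_duration)]
    parts = [prefix] + ['{}={}'.format(key, value) for key, value in fields]
    parts += [str(sup_mark), str(unsup_mark), str(input1)]
    return '-'.join(parts)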
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')
parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=2, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.5, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-u', '--unsup_weight', default=1, type=float)
parser.add_argument('--loss', default='dice')
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n1', '--network1', default='unet', type=str)
parser.add_argument('-n2', '--network2', default='swinunet', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int)
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
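# all_gather needs one output slot per process, so size the gather lists by the world size
ngpus_per_node = torch.distributed.get_world_size()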
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
# trained model save
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + 'CT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
# seg results save
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + 'CT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
# vis
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-CT-' + str(os.path.split(args.path_dataset)[1]) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)
if args.input1 == 'image':
input1_mean = 'MEAN'
input1_std = 'STD'
else:
input1_mean = 'MEAN_' + args.input1
input1_std = 'STD_' + args.input1
data_transforms = data_transform_2d()
data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
dataset_train_unsup = imagefloder_itn(
data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
input1=args.input1,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize,
sup=False,
num_images=None,
)
num_images_unsup = len(dataset_train_unsup)
dataset_train_sup = imagefloder_itn(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize,
sup=True,
num_images=num_images_unsup,
)
dataset_val = imagefloder_itn(
data_dir=args.path_dataset + '/val',
input1=args.input1,
data_transform_1=data_transforms['val'],
data_normalize_1=data_normalize,
sup=True,
num_images=None,
)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
model1 = get_network(args.network1, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model2 = get_network(args.network2, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model1 = model1.cuda()
model2 = model2.cuda()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])
dist.barrier()
criterion = segmentation_loss(args.loss, False).cuda()
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2)
since = time.time()
count_iter = 0
best_model = model1
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
model2.train()
train_loss_sup_1 = 0.0
train_loss_sup_2 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
val_loss_sup_2 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup = unsup_index['image']
img_train_unsup = Variable(img_train_unsup.cuda(non_blocking=True))
optimizer1.zero_grad()
optimizer2.zero_grad()
pred_train_unsup1 = model1(img_train_unsup)
pred_train_unsup2 = model2(img_train_unsup)
max_train1 = torch.max(pred_train_unsup1, dim=1)[1]
max_train2 = torch.max(pred_train_unsup2, dim=1)[1]
max_train1 = max_train1.long()
max_train2 = max_train2.long()
loss_train_unsup = criterion(pred_train_unsup1, max_train2) + criterion(pred_train_unsup2, max_train1)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup = sup_index['image']
img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))
mask_train_sup = sup_index['mask']
mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))
pred_train_sup1 = model1(img_train_sup)
pred_train_sup2 = model2(img_train_sup)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
score_list_train2 = pred_train_sup2
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)
loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)
loss_train_sup = loss_train_sup1 + loss_train_sup2
loss_train_sup.backward()
optimizer1.step()
optimizer2.step()
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss_sup_2 += loss_train_sup2.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
scheduler_warmup2.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train2, score_list_train2)
score_list_train2 = torch.cat(score_gather_list_train2, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
model2.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val'] / 16:
inputs_val = Variable(data['image'].cuda(non_blocking=True))
mask_val = Variable(data['mask'].cuda(non_blocking=True))
name_val = data['ID']
optimizer1.zero_grad()
optimizer2.zero_grad()
outputs_val1 = model1(inputs_val)
outputs_val2 = model2(inputs_val)
torch.cuda.empty_cache()
if i == 0:
score_list_val1 = outputs_val1
score_list_val2 = outputs_val2
mask_list_val = mask_val
name_list_val = name_val
else:
score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
name_list_val = np.append(name_list_val, name_val, axis=0)
loss_val_sup1 = criterion(outputs_val1, mask_val)
loss_val_sup2 = criterion(outputs_val2, mask_val)
val_loss_sup_1 += loss_val_sup1.item()
val_loss_sup_2 += loss_val_sup2.item()
torch.cuda.empty_cache()
score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val1, score_list_val1)
score_list_val1 = torch.cat(score_gather_list_val1, dim=0)
score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val2, score_list_val2)
score_list_val2 = torch.cat(score_gather_list_val2, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
name_gather_list_val = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
name_list_val = np.concatenate(name_gather_list_val, axis=0)
if rank == args.rank_index:
val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)
best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])
torch.cuda.empty_cache()
if args.vis:
draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, pred_train_sup2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2)
visualization_XNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)
visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5])
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
print('=' * print_num)
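The unsupervised term in this script follows a cross-pseudo-supervision pattern:

- Each network's argmax prediction on the unlabelled batch (`max_train1`, `max_train2`) serves as the hard target for the other network.
- The summed loss is scaled by the ramped `unsup_weight`.

Below is a dependency-free sketch of that pattern on per-pixel class logits. It swaps in plain cross-entropy for readability, whereas the script's default `criterion` is the Dice loss built by `segmentation_loss`. The helper names are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def argmax(xs):
    return max(range(len(xs)), key=xs.__getitem__)


def cross_entropy(logits_per_pixel, labels):
    """Mean pixel-wise cross-entropy; logits_per_pixel[p] holds the class logits of pixel p."""
    total = 0.0
    for logits, y in zip(logits_per_pixel, labels):
        total -= math.log(softmax(logits)[y])
    return total / len(labels)


def cps_loss(logits_a, logits_b, weight):
    """Each model is supervised by the other's hard (argmax) labels, like torch.max(pred, dim=1)[1].
    In PyTorch the argmax carries no gradient, so the pseudo-labels act as constants."""
    pseudo_a = [argmax(l) for l in logits_a]
    pseudo_b = [argmax(l) for l in logits_b]
    return weight * (cross_entropy(logits_a, pseudo_b) + cross_entropy(logits_b, pseudo_a))


agree = [[2.0, 0.0], [0.0, 2.0]]     # two pixels, two classes
disagree = [[0.0, 2.0], [2.0, 0.0]]  # opposite predictions
print(cps_loss(agree, agree, 1.0) < cps_loss(agree, disagree, 1.0))
```

When the two networks agree, the loss only pushes them toward more confident predictions. When they disagree, it pulls each toward the other's labels.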
================================================
FILE: train_semi_CT_3d.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_3d, print_best
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_it
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
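In the 3D scripts, each `DataLoader` wraps a torchio patch queue that is sharded by `DistributedSampler`. The per-rank batch counts in `num_batches` therefore depend on:

- the number of volumes,
- `samples_per_volume`,
- the world size,
- `batch_size`.

The sketch below estimates that count. It assumes `len(queue)` equals the number of subjects times `samples_per_volume`, as in torchio's `Queue`. The name `batches_per_rank` is hypothetical.

```python
import math


def batches_per_rank(num_subjects, samples_per_volume, world_size, batch_size, loader_drop_last=False):
    """Hypothetical estimate of len(DataLoader) on one rank for a sharded torchio queue."""
    # Assumption: the queue yields samples_per_volume patches per subject each epoch.
    patches = num_subjects * samples_per_volume
    # DistributedSampler (drop_last=False) pads so every rank receives the same number of indices.
    per_rank = math.ceil(patches / world_size)
    # DataLoader batching: a final partial batch is kept unless the loader drops it.
    return per_rank // batch_size if loader_drop_last else math.ceil(per_rank / batch_size)


print(batches_per_rank(num_subjects=10, samples_per_volume=8, world_size=2, batch_size=1))
```

The training loop iterates over `num_batches['train_sup']` and calls `next()` on the unlabelled iterator in step. The unlabelled loader must therefore yield at least as many batches as the labelled one. Passing `num_images=num_images_unsup` to the labelled dataset appears intended to equalize the two, though `dataset_it` is not shown here.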
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
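Validation uses `DistributedSampler(..., shuffle=False)` and then concatenates the per-rank outputs with `all_gather`. PyTorch's sampler works in two steps:

1. It pads the index list by repeating its first entries until the length divides evenly by the number of replicas.
2. It hands rank `r` every `num_replicas`-th index, starting at `r`.

The pure-Python sketch below mirrors that behaviour, without shuffling. It shows that the gathered predictions are interleaved and may contain duplicated samples, which then count twice in pooled metrics. The names `shard_indices` and `gathered_order` are hypothetical.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Mirror of DistributedSampler(shuffle=False, drop_last=False) index assignment."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    padding = total_size - dataset_len
    # Pad by repeating indices from the start so every rank gets num_samples items.
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # Rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


def gathered_order(dataset_len, num_replicas):
    """Sample order after all_gather + torch.cat(dim=0) over ranks 0..num_replicas-1."""
    return [i for r in range(num_replicas) for i in shard_indices(dataset_len, num_replicas, r)]


print(gathered_order(5, 2))
```

With 5 validation items on 2 ranks, item 0 appears twice and the order is interleaved. This is one reason the gathered `name_list_val` is gathered alongside the scores rather than reconstructed from dataset order.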
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')
parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.1, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-c', '--unsup_weight', default=1, type=float)
parser.add_argument('--loss', default='dice', type=str)
parser.add_argument('--patch_size', default=(112, 112, 32), type=int, nargs=3, help='patch size, e.g. 112 112 32')
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('--queue_length', default=48, type=int)
parser.add_argument('--samples_per_volume_train', default=8, type=int)
parser.add_argument('--samples_per_volume_val', default=12, type=int)
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n1', '--network1', default='unet3d_min', type=str)
parser.add_argument('-n2', '--network2', default='unertr', type=str)
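parser.add_argument('--local_rank', '--local-rank', default=int(os.environ.get('LOCAL_RANK', -1)), type=int, help='set by torch.distributed.launch / torchrun')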
parser.add_argument('--rank_index', default=0, type=int, help='rank that logs and saves results: 0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, type=lambda s: str(s).lower() in ('true', '1', 'yes'), help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int, help='port of the visdom server')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = dist.get_world_size()  # all_gather needs one slot per process, not per visible GPU
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
path_trained_models = args.path_trained_models+'/'+str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models+'/'+'CT'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s=' + str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+ '-cw' + str(args.unsup_weight)+'-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results+'/' +str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results+'/'+'CT'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+ '-cw' + str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_mask_results = path_seg_results + '/mask'
if not os.path.exists(path_mask_results) and rank == args.rank_index:
os.mkdir(path_mask_results)
path_seg_results = path_seg_results + '/pred'
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-CT-' + str(os.path.split(args.path_dataset)[1]) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)
# Dataset
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_train_unsup = dataset_it(
data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
input1=args.input1,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=False,
num_images=None
)
num_images_unsup = len(dataset_train_unsup.dataset_1)
dataset_train_sup = dataset_it(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=True,
num_images=num_images_unsup
)
dataset_val = dataset_it(
data_dir=args.path_dataset + '/val',
input1=args.input1,
transform_1=data_transform['val'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_val,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=False,
shuffle_patches=False,
sup=True,
num_images=None
)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
# Model
model1 = get_network(args.network1, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model2 = get_network(args.network2, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)
model1 = model1.cuda()
model2 = model2.cuda()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
model2 = DistributedDataParallel(model2, device_ids=[args.local_rank], find_unused_parameters=True)
dist.barrier()
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)
exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2)
# Train & Val
since = time.time()
count_iter = 0
best_model = model1
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
model2.train()
train_loss_sup_1 = 0.0
train_loss_sup_2 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
val_loss_sup_2 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())
optimizer1.zero_grad()
optimizer2.zero_grad()
pred_train_unsup1 = model1(img_train_unsup_1)
pred_train_unsup2 = model2(img_train_unsup_1)
max_train_unsup1 = torch.max(pred_train_unsup1, dim=1)[1]
max_train_unsup2 = torch.max(pred_train_unsup2, dim=1)[1]
max_train_unsup1 = max_train_unsup1.long()
max_train_unsup2 = max_train_unsup2.long()
loss_train_unsup = criterion(pred_train_unsup1, max_train_unsup2) + criterion(pred_train_unsup2, max_train_unsup1)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())
mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())
pred_train_sup1 = model1(img_train_sup_1)
pred_train_sup2 = model2(img_train_sup_1)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
score_list_train2 = pred_train_sup2
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)
loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)
loss_train_sup = loss_train_sup1 + loss_train_sup2
loss_train_sup.backward()
optimizer1.step()
optimizer2.step()
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss_sup_2 += loss_train_sup2.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
scheduler_warmup2.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train2, score_list_train2)
score_list_train2 = torch.cat(score_gather_list_train2, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
train_eval_list_1, train_eval_list_2, train_m_jc_1, train_m_jc_2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
model2.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val_1 = Variable(data['image'][tio.DATA].cuda())
mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
optimizer1.zero_grad()
optimizer2.zero_grad()
outputs_val_1 = model1(inputs_val_1)
outputs_val_2 = model2(inputs_val_1)
torch.cuda.empty_cache()
if i == 0:
score_list_val_1 = outputs_val_1
score_list_val_2 = outputs_val_2
mask_list_val = mask_val
else:
score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)
score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
loss_val_sup_1 = criterion(outputs_val_1, mask_val)
loss_val_sup_2 = criterion(outputs_val_2, mask_val)
val_loss_sup_1 += loss_val_sup_1.item()
val_loss_sup_2 += loss_val_sup_2.item()
torch.cuda.empty_cache()
score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)
score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)
score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)
best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])
torch.cuda.empty_cache()
if args.vis:
visualization_XNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_m_jc_1, train_m_jc_2, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
print('=' * print_num)
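Two schedules interact in these training loops:

- The consistency weight grows linearly as `unsup_weight * (epoch + 1) / num_epochs`.
- Each optimizer's learning rate passes through `GradualWarmupScheduler` before handing over to `StepLR`.

`config/warmup_config/warmup.py` is not part of this excerpt. The LR function below therefore assumes the widely used warm-up behaviour for `multiplier=1.0`: a linear rise from 0 to the base LR over `warm_up_duration` epochs, after which `StepLR` counts epochs from the end of warm-up. It is an approximation for reasoning about the schedule, not a copy of the repository's scheduler, and off-by-one details may differ.

```python
def unsup_weight_at(epoch, max_weight=1.0, num_epochs=200):
    """Linear ramp used by the training loops (epoch is 0-based)."""
    return max_weight * (epoch + 1) / num_epochs


def approx_lr_at(epoch, base_lr=0.1, warm_up=20, step_size=50, gamma=0.5):
    """Approximate LR during 0-based `epoch`, ASSUMING linear warm-up followed by StepLR."""
    if epoch < warm_up:
        # Assumed warm-up: linear rise from 0 toward base_lr.
        return base_lr * epoch / warm_up
    # StepLR after warm-up: multiply by gamma every step_size epochs.
    return base_lr * gamma ** ((epoch - warm_up) // step_size)


for e in (0, 10, 20, 70, 120, 170, 199):
    print(e, round(approx_lr_at(e), 6), round(unsup_weight_at(e), 3))
```

Under this assumption, the 3D defaults (`lr=0.1`, `warm_up_duration=20`, `step_size=50`, `gamma=0.5`) give:

| Epoch | Learning rate |
| --- | --- |
| from 20 | 0.1 |
| from 70 | 0.05 |
| from 120 | 0.025 |
| from 170 | 0.0125 |

Meanwhile the consistency weight only reaches `args.unsup_weight` in the final epoch.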
================================================
FILE: train_semi_DTC.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup
from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_it_dtc
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
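The checkpoint and prediction directories in these scripts are built by long string concatenations, so small inconsistencies are easy to introduce. For example, the 3D CT script writes `-cw` where the DTC script writes `-cw=`.

A hypothetical helper like the one below keeps the same fields, in the same order, in one place. `os.makedirs(..., exist_ok=True)` also creates missing parent directories in a single call. This is a sketch of a possible refactor, not code from the repository.

```python
import os


def run_name(method, lr, num_epochs, step_size, gamma, batch_size,
             unsup_weight, warm_up, sup_mark, unsup_mark, input1):
    """Hypothetical helper: same fields and order as the DTC directory names."""
    return (f"{method}-l={lr}-e={num_epochs}-s={step_size}-g={gamma}-b={batch_size}"
            f"-cw={unsup_weight}-w={warm_up}-{sup_mark}-{unsup_mark}-{input1}")


def run_dir(root, dataset_dir, name, create=False):
    """Join root/<dataset folder>/<name>; os.path.basename matches os.path.split(...)[1]."""
    path = os.path.join(root, os.path.basename(dataset_dir), name)
    if create:  # e.g. only on the rank that writes results
        os.makedirs(path, exist_ok=True)
    return path


name = run_name("DTC", 0.1, 200, 50, 0.5, 1, 1.0, 20, "20", "80", "image")
print(run_dir("/mnt/data1/XNet/checkpoints/semi", "/mnt/data1/XNet/dataset/LiTS", name))
```

In a script, `create=(rank == args.rank_index)` would reproduce the existing "only one rank creates directories" behaviour.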
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')
parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.1, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-c', '--unsup_weight', default=1, type=float)
parser.add_argument('--beta', default=0.3, type=float)
parser.add_argument('--loss', default='dice', type=str)
parser.add_argument('--patch_size', default=(112, 112, 32), type=int, nargs=3, help='patch size, e.g. 112 112 32')
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('--queue_length', default=48, type=int)
parser.add_argument('--samples_per_volume_train', default=8, type=int)
parser.add_argument('--samples_per_volume_val', default=12, type=int)
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='unet3d_dtc', type=str)
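parser.add_argument('--local_rank', '--local-rank', default=int(os.environ.get('LOCAL_RANK', -1)), type=int, help='set by torch.distributed.launch / torchrun')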
parser.add_argument('--rank_index', default=0, type=int, help='rank that logs and saves results: 0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, help='16672')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = torch.cuda.device_count()
init_seeds(rank + 1)
    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)
    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
    print_num_minus = print_num - 2
    print_num_half = int(print_num / 2 - 1)
    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models + '/' + 'DTC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results + '/' + 'DTC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_mask_results = path_seg_results + '/mask'
    if not os.path.exists(path_mask_results) and rank == args.rank_index:
        os.mkdir(path_mask_results)
    path_seg_results = path_seg_results + '/pred'
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    if args.vis and rank == args.rank_index:
        visdom_env = str('Semi-DTC-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
        visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)
    # Dataset
    data_transform = data_transform_3d(cfg['NORMALIZE'])
    dataset_train_unsup = dataset_it_dtc(
        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
        input1=args.input1,
        num_classes=cfg['NUM_CLASSES'],
        transform_1=data_transform['train'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_train,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=True,
        shuffle_patches=True,
        sup=False,
        num_images=None
    )
    num_images_unsup = len(dataset_train_unsup.dataset_1)
    dataset_train_sup = dataset_it_dtc(
        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
        input1=args.input1,
        num_classes=cfg['NUM_CLASSES'],
        transform_1=data_transform['train'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_train,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=True,
        shuffle_patches=True,
        sup=True,
        num_images=num_images_unsup
    )
    dataset_val = dataset_it_dtc(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        num_classes=cfg['NUM_CLASSES'],
        transform_1=data_transform['val'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_val,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=False,
        shuffle_patches=False,
        sup=True,
        num_images=None
    )
    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)
    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)
    dataloaders = dict()
    dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)
    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)
    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)
    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
    # Model
    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
    model1 = model1.cuda()
    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
    dist.barrier()
    # Training Strategy
    criterion = segmentation_loss(args.loss, False).cuda()
    mseloss = torch.nn.MSELoss().cuda()
    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
    # Train & Val
    since = time.time()
    count_iter = 0
    best_val_eval_list = [0 for i in range(4)]
    for epoch in range(args.num_epochs):
        count_iter += 1
        if (count_iter - 1) % args.display_iter == 0:
            begin_time = time.time()
        dataloaders['train_sup'].sampler.set_epoch(epoch)
        dataloaders['train_unsup'].sampler.set_epoch(epoch)
        model1.train()
        train_loss_sup = 0.0
        train_loss_unsup = 0.0
        train_loss = 0.0
        val_loss_sup_1 = 0.0
        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
        dist.barrier()
        dataset_train_sup = iter(dataloaders['train_sup'])
        dataset_train_unsup = iter(dataloaders['train_unsup'])
        for i in range(num_batches['train_sup']):
            unsup_index = next(dataset_train_unsup)
            img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())
            optimizer1.zero_grad()
            pred_train_unsup_sdf, pred_train_unsup_seg = model1(img_train_unsup_1)
            pred_train_unsup_seg_soft = torch.sigmoid(pred_train_unsup_seg)
            dis_to_mask = torch.sigmoid(-1500 * pred_train_unsup_sdf)
            loss_train_unsup = torch.mean((dis_to_mask - pred_train_unsup_seg_soft) ** 2)
            loss_train_unsup = loss_train_unsup * unsup_weight
            loss_train_unsup.backward(retain_graph=True)
            torch.cuda.empty_cache()
            sup_index = next(dataset_train_sup)
            img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())
            mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())
            mask_train_sup_sdf1 = Variable(sup_index['mask2'][tio.DATA].squeeze(1).float().cuda())
            if cfg['NUM_CLASSES'] == 3:
                mask_train_sup_sdf2 = Variable(sup_index['mask3'][tio.DATA].squeeze(1).float().cuda())
            pred_train_sup_sdf, pred_train_sup_seg = model1(img_train_sup_1)
            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train1 = pred_train_sup_seg
                    mask_list_train = mask_train_sup
                # else:
                elif 0 < i <= num_batches['train_sup'] / 32:
                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup_seg), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
            if cfg['NUM_CLASSES'] == 3:
                loss_train_sdf = mseloss(pred_train_sup_sdf[:, 1, ...], mask_train_sup_sdf1) + mseloss(pred_train_sup_sdf[:, 2, ...], mask_train_sup_sdf2)
            else:
                loss_train_sdf = mseloss(pred_train_sup_sdf[:, 1, ...], mask_train_sup_sdf1)
            loss_train_seg = criterion(pred_train_sup_seg, mask_train_sup)
            loss_train_sup = loss_train_seg + args.beta * loss_train_sdf
            loss_train_sup.backward()
            optimizer1.step()
            torch.cuda.empty_cache()
            loss_train = loss_train_unsup + loss_train_sup
            train_loss_unsup += loss_train_unsup.item()
            train_loss_sup += loss_train_sup.item()
            train_loss += loss_train.item()
        scheduler_warmup1.step()
        torch.cuda.empty_cache()
        if count_iter % args.display_iter == 0:
            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)
            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)
                train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)
                torch.cuda.empty_cache()
            with torch.no_grad():
                model1.eval()
                for i, data in enumerate(dataloaders['val']):
                    # if 0 <= i <= num_batches['val']:
                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())
                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
                    optimizer1.zero_grad()
                    outputs_val_sdf, outputs_val_seg = model1(inputs_val_1)
                    torch.cuda.empty_cache()
                    if i == 0:
                        score_list_val_1 = outputs_val_seg
                        mask_list_val = mask_val
                    else:
                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_seg), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
                    loss_val_sup_1 = criterion(outputs_val_seg, mask_val)
                    val_loss_sup_1 += loss_val_sup_1.item()
                torch.cuda.empty_cache()
                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)
                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)
                torch.cuda.empty_cache()
                if rank == args.rank_index:
                    val_epoch_loss_sup_1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)
                    val_eval_list_1, val_m_jc_1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val_1, mask_list_val, print_num_minus)
                    best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val_1, mask_list_val, val_eval_list_1, path_trained_models, path_seg_results, path_mask_results, 'DTC', cfg['FORMAT'])
                    torch.cuda.empty_cache()
                    if args.vis:
                        visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_m_jc_1)
                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
                    torch.cuda.empty_cache()
            torch.cuda.empty_cache()
    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)
        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
        print('=' * print_num)
================================================
FILE: train_semi_EM.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from models.getnetwork import get_network
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
from PIL import Image
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import sys
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss, entropy_loss
from dataload.dataset_2d import imagefloder_itn
from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM, visual_image_sup
from config.warmup_config.warmup import GradualWarmupScheduler
from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
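# Illustrative sketch (hypothetical helper, not used by this script): the checkpoint,
# prediction and visdom names in __main__ are built by concatenating hyperparameters
# into one tag by hand. Factoring the concatenation into a function would keep those
# names from drifting apart. The field order mirrors the path_seg_results string. A
# possible call is make_run_tag('EM', **vars(args)); extra argparse fields are ignored.
def make_run_tag(method, lr, num_epochs, step_size, gamma, batch_size, unsup_weight,
                 warm_up_duration, sup_mark, unsup_mark, input1, **_unused):
    fields = [('l', lr), ('e', num_epochs), ('s', step_size), ('g', gamma),
              ('b', batch_size), ('cw', unsup_weight), ('w', warm_up_duration)]
    tag = method + ''.join('-{}={}'.format(key, value) for key, value in fields)
    return tag + '-{}-{}-{}'.format(sup_mark, unsup_mark, input1)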
def init_seeds(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
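# Illustrative sketch (hypothetical name; the repository's own entropy_loss in
# loss/loss_function.py may be implemented differently): entropy minimization
# penalizes uncertain predictions on unlabeled images. This version computes the
# per-pixel Shannon entropy of the class probabilities and divides it by log(C).
# A uniform prediction therefore scores about 1 and a one-hot prediction about 0.
# The result is averaged over all pixels.
import numpy as np
def normalized_entropy_sketch(probs, C=2, eps=1e-6):
    # probs: array of shape (N, C, H, W) whose class axis (axis 1) sums to 1,
    # e.g. a softmax over dim 1 as applied in the training loop of __main__.
    probs = np.asarray(probs, dtype=np.float64)
    per_pixel = -np.sum(probs * np.log(probs + eps), axis=1) / np.log(C)
    return float(np.mean(per_pixel))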
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')
    parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')
    parser.add_argument('--input1', default='image')
    parser.add_argument('--sup_mark', default='20')
    parser.add_argument('--unsup_mark', default='80')
    parser.add_argument('-b', '--batch_size', default=2, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.5, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('-u', '--unsup_weight', default=1, type=float)
    parser.add_argument('--loss', default='dice')
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='unet', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)
    # type=int so a value passed on the command line compares equal to the integer rank
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, type=int)
    args = parser.parse_args()
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    rank = torch.distributed.get_rank()
    # all_gather needs one buffer per process in the group, which may be fewer than the visible GPUs
    ngpus_per_node = dist.get_world_size()
    init_seeds(rank + 1)
    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)
    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
    print_num_minus = print_num - 2
    print_num_half = int(print_num / 2 - 1)
    # trained model save
    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models + '/' + 'EM' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    # seg results save
    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results + '/' + 'EM' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    # vis
    if args.vis and rank == args.rank_index:
        visdom_env = str('Semi-EM-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
        visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)
    if args.input1 == 'image':
        input1_mean = 'MEAN'
        input1_std = 'STD'
    else:
        input1_mean = 'MEAN_' + args.input1
        input1_std = 'STD_' + args.input1
    data_transforms = data_transform_2d()
    data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
    dataset_train_unsup = imagefloder_itn(
        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
        input1=args.input1,
        data_transform_1=data_transforms['train'],
        data_normalize_1=data_normalize,
        sup=False,
        num_images=None,
    )
    num_images_unsup = len(dataset_train_unsup)
    dataset_train_sup = imagefloder_itn(
        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
        input1=args.input1,
        data_transform_1=data_transforms['train'],
        data_normalize_1=data_normalize,
        sup=True,
        num_images=num_images_unsup,
    )
    dataset_val = imagefloder_itn(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        data_transform_1=data_transforms['val'],
        data_normalize_1=data_normalize,
        sup=True,
        num_images=None,
    )
    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)
    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
    dataloaders = dict()
    dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)
    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)
    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
    model1 = model1.cuda()
    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
    dist.barrier()
    criterion = segmentation_loss(args.loss, False).cuda()
    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
    since = time.time()
    count_iter = 0
    best_val_eval_list = [0 for i in range(4)]
    for epoch in range(args.num_epochs):
        count_iter += 1
        if (count_iter - 1) % args.display_iter == 0:
            begin_time = time.time()
        dataloaders['train_sup'].sampler.set_epoch(epoch)
        dataloaders['train_unsup'].sampler.set_epoch(epoch)
        model1.train()
        train_loss_sup_1 = 0.0
        train_loss_unsup = 0.0
        train_loss = 0.0
        val_loss_sup_1 = 0.0
        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
        dist.barrier()
        dataset_train_sup = iter(dataloaders['train_sup'])
        dataset_train_unsup = iter(dataloaders['train_unsup'])
        for i in range(num_batches['train_sup']):
            unsup_index = next(dataset_train_unsup)
            img_train_unsup_1 = unsup_index['image']
            img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))
            optimizer1.zero_grad()
            pred_train_unsup1 = model1(img_train_unsup_1)
            pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)
            loss_train_unsup = entropy_loss(pred_train_unsup1, C=2)
            loss_train_unsup = loss_train_unsup * unsup_weight
            loss_train_unsup.backward(retain_graph=True)
            torch.cuda.empty_cache()
            sup_index = next(dataset_train_sup)
            img_train_sup = sup_index['image']
            img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))
            mask_train_sup = sup_index['mask']
            mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))
            pred_train_sup1 = model1(img_train_sup)
            pred_train_sup1_soft = torch.softmax(pred_train_sup1, 1)
            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train1 = pred_train_sup1
                    mask_list_train = mask_train_sup
                # else:
                elif 0 < i <= num_batches['train_sup'] / 32:
                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) + entropy_loss(pred_train_sup1_soft, C=2)
            loss_train_sup = loss_train_sup1
            loss_train_sup.backward()
            optimizer1.step()
            torch.cuda.empty_cache()
            loss_train = loss_train_unsup + loss_train_sup
            train_loss_unsup += loss_train_unsup.item()
            train_loss_sup_1 += loss_train_sup1.item()
            train_loss += loss_train.item()
        scheduler_warmup1.step()
        torch.cuda.empty_cache()
        if count_iter % args.display_iter == 0:
            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)
            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)
                train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)
                torch.cuda.empty_cache()
            with torch.no_grad():
                model1.eval()
                for i, data in enumerate(dataloaders['val']):
                    # if 0 <= i <= num_batches['val'] / 16:
                    inputs_val = Variable(data['image'].cuda(non_blocking=True))
                    mask_val = Variable(data['mask'].cuda(non_blocking=True))
                    name_val = data['ID']
                    optimizer1.zero_grad()
                    outputs_val1 = model1(inputs_val)
                    torch.cuda.empty_cache()
                    if i == 0:
                        score_list_val1 = outputs_val1
                        mask_list_val = mask_val
                        name_list_val = name_val
                    else:
                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
                        name_list_val = np.append(name_list_val, name_val, axis=0)
                    loss_val_sup1 = criterion(outputs_val1, mask_val)
                    val_loss_sup_1 += loss_val_sup1.item()
                torch.cuda.empty_cache()
                score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val1, score_list_val1)
                score_list_val1 = torch.cat(score_gather_list_val1, dim=0)
                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)
                name_gather_list_val = [None for _ in range(ngpus_per_node)]
                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
                name_list_val = np.concatenate(name_gather_list_val, axis=0)
                if rank == args.rank_index:
                    val_epoch_loss_sup1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)
                    val_eval_list1, val_m_jc1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val1, mask_list_val, print_num_minus)
                    best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val1, name_list_val, val_eval_list1, path_trained_models, path_seg_results, cfg['PALETTE'], 'EM')
                    torch.cuda.empty_cache()
                    if args.vis:
                        draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, train_eval_list1, val_eval_list1)
                        visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_m_jc1)
                        visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])
                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
                    torch.cuda.empty_cache()
            torch.cuda.empty_cache()
    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)
        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
        print('=' * print_num)
================================================
FILE: train_semi_EM_3d.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup
from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss, entropy_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_it
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
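# Illustrative sketch (hypothetical helpers, not used by this script): these training
# scripts ramp the unlabeled-loss weight linearly with
# unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs, so it reaches the
# full value only at the last epoch. For comparison, this also shows the Mean-Teacher
# style sigmoid ramp-up exp(-5 * (1 - t)^2). config/ramps/ramps.py may provide a
# similar function, but this is an independent sketch, not that file's code.
import math
def linear_unsup_weight(max_weight, epoch, num_epochs):
    # Same formula as the training loop: epoch is 0-based.
    return max_weight * (epoch + 1) / num_epochs
def sigmoid_rampup_weight(max_weight, epoch, rampup_length):
    # t is the ramp-up progress clipped to [0, 1]; the weight saturates at max_weight.
    if rampup_length == 0:
        return max_weight
    t = min(max(epoch / rampup_length, 0.0), 1.0)
    return max_weight * math.exp(-5.0 * (1.0 - t) ** 2)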
def init_seeds(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
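# Illustrative sketch (hypothetical helper, not used by this script): an approximate
# per-epoch learning-rate curve for the SGD + GradualWarmupScheduler + StepLR setup
# configured in __main__. It *assumes* two behaviours. With multiplier=1.0, warm-up
# ramps linearly from 0 to the base rate over warm_up_duration epochs. StepLR then
# multiplies the rate by gamma every step_size epochs, counted from the end of warm-up.
# Check config/warmup_config/warmup.py for the exact off-by-one behaviour first.
def approx_lr(epoch, base_lr=0.1, warm_up_duration=20, step_size=50, gamma=0.5):
    if epoch < warm_up_duration:
        return base_lr * epoch / warm_up_duration
    return base_lr * gamma ** ((epoch - warm_up_duration) // step_size)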
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')
    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')
    parser.add_argument('--input1', default='image')
    parser.add_argument('--sup_mark', default='20')
    parser.add_argument('--unsup_mark', default='80')
    parser.add_argument('-b', '--batch_size', default=1, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.1, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('-c', '--unsup_weight', default=50, type=float)
    parser.add_argument('--loss', default='dice', type=str)
    parser.add_argument('--patch_size', default=(96, 96, 80))
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
    parser.add_argument('--queue_length', default=48, type=int)
    parser.add_argument('--samples_per_volume_train', default=4, type=int)
    parser.add_argument('--samples_per_volume_val', default=8, type=int)
    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='unet3d', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)
    # type=int so a value passed on the command line compares equal to the integer rank
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, type=int, help='16672')
    args = parser.parse_args()
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    rank = torch.distributed.get_rank()
    # all_gather needs one buffer per process in the group, which may be fewer than the visible GPUs
    ngpus_per_node = dist.get_world_size()
    init_seeds(rank + 1)
    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)
    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
    print_num_minus = print_num - 2
    print_num_half = int(print_num / 2 - 1)
    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models + '/' + 'EM' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results + '/' + 'EM' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_mask_results = path_seg_results + '/mask'
    if not os.path.exists(path_mask_results) and rank == args.rank_index:
        os.mkdir(path_mask_results)
    path_seg_results = path_seg_results + '/pred'
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    if args.vis and rank == args.rank_index:
        visdom_env = str('Semi-EM-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
        visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)
    # Dataset
    data_transform = data_transform_3d(cfg['NORMALIZE'])
    dataset_train_unsup = dataset_it(
        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
        input1=args.input1,
        transform_1=data_transform['train'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_train,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=True,
        shuffle_patches=True,
        sup=False,
        num_images=None
    )
    num_images_unsup = len(dataset_train_unsup.dataset_1)
    dataset_train_sup = dataset_it(
        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
        input1=args.input1,
        transform_1=data_transform['train'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_train,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=True,
        shuffle_patches=True,
        sup=True,
        num_images=num_images_unsup
    )
    dataset_val = dataset_it(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        transform_1=data_transform['val'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_val,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=False,
        shuffle_patches=False,
        sup=True,
        num_images=None
    )
    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)
    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)
    dataloaders = dict()
    dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)
    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)
    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)
    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
# Model
model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model1 = model1.cuda()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
dist.barrier()
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
# Train & Val
since = time.time()
count_iter = 0
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
train_loss_sup_1 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())
optimizer1.zero_grad()
pred_train_unsup1 = model1(img_train_unsup_1)
pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)
loss_train_unsup = entropy_loss(pred_train_unsup1, C=2)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())
mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())
pred_train_sup1 = model1(img_train_sup_1)
pred_train_sup1_soft = torch.softmax(pred_train_sup1, 1)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) + entropy_loss(pred_train_sup1_soft, C=2)
loss_train_sup = loss_train_sup1
loss_train_sup.backward()
optimizer1.step()
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)
train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val_1 = Variable(data['image'][tio.DATA].cuda())
mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
optimizer1.zero_grad()
outputs_val_1 = model1(inputs_val_1)
torch.cuda.empty_cache()
if i == 0:
score_list_val_1 = outputs_val_1
mask_list_val = mask_val
else:
score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
loss_val_sup_1 = criterion(outputs_val_1, mask_val)
val_loss_sup_1 += loss_val_sup_1.item()
torch.cuda.empty_cache()
score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss_sup_1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)
val_eval_list_1, val_m_jc_1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val_1, mask_list_val, print_num_minus)
best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val_1, mask_list_val, val_eval_list_1, path_trained_models, path_seg_results, path_mask_results, 'EM', cfg['FORMAT'])
torch.cuda.empty_cache()
if args.vis:
visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_m_jc_1)
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(
print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
print('=' * print_num)
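The entropy-minimization (EM) baseline above calls `entropy_loss(prob, C=2)` on softmax outputs, but its definition is not shown in this excerpt. Below is a minimal sketch of such a penalty. It assumes the loss is the per-voxel Shannon entropy averaged over all voxels and normalised by log C, which is a common convention but not confirmed here. `entropy_loss_sketch` and its `eps` parameter are hypothetical names.

```python
import math

import torch


def entropy_loss_sketch(prob: torch.Tensor, C: int = 2, eps: float = 1e-6) -> torch.Tensor:
    """Mean per-voxel Shannon entropy of class probabilities, normalised by log(C).

    prob: tensor of shape (N, C, ...) whose channel dimension already sums to 1,
    e.g. torch.softmax(logits, 1). Hypothetical stand-in for the repo's entropy_loss.
    """
    # -sum_c p_c * log(p_c) over the class dimension -> shape (N, ...)
    ent = -torch.sum(prob * torch.log(prob + eps), dim=1)
    # Divide by log(C): a uniform prediction scores ~1, a one-hot prediction ~0
    return (ent / math.log(C)).mean()
```

Minimizing this term pushes predictions on unlabeled patches toward confident, low-entropy class assignments. With the normalisation, the value lies roughly in [0, 1] for any C, so its scale does not depend on the number of classes.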
================================================
FILE: train_semi_MT.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from models.getnetwork import get_network
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
from PIL import Image
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import sys
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from dataload.dataset_2d import imagefloder_itn
from config.visdom_config.visual_visdom import visdom_initialization_MT, visualization_MT, visual_image_MT
from config.warmup_config.warmup import GradualWarmupScheduler
from config.train_test_config.train_test_config import print_train_loss_MT, print_val_loss, print_train_eval_sup, print_val_eval, save_val_best_2d, draw_pred_MT, print_best
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def update_ema_variables(model, ema_model, alpha, global_step):
# Use the true average until the exponential average is more correct
alpha = min(1 - 1 / (global_step + 1), alpha)
for ema_param, param in zip(ema_model.parameters(), model.parameters()):
ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')
parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=2, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.5, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-u', '--unsup_weight', default=5, type=float)
parser.add_argument('--loss', default='dice')
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--ema_decay', default=0.99, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='unet', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, type=lambda s: s.lower() in ('true', '1', 'yes'), help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int)
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = dist.get_world_size()  # all_gather needs one output slot per process, not per visible GPU
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
# trained model save
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + 'MT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
# seg results save
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + 'MT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
# vis
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-MT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
visdom = visdom_initialization_MT(env=visdom_env, port=args.visdom_port)
if args.input1 == 'image':
input1_mean = 'MEAN'
input1_std = 'STD'
else:
input1_mean = 'MEAN_' + args.input1
input1_std = 'STD_' + args.input1
data_transforms = data_transform_2d()
data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
dataset_train_unsup = imagefloder_itn(
data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
input1=args.input1,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize,
sup=False,
num_images=None,
)
num_images_unsup = len(dataset_train_unsup)
dataset_train_sup = imagefloder_itn(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize,
sup=True,
num_images=num_images_unsup,
)
dataset_val = imagefloder_itn(
data_dir=args.path_dataset + '/val',
input1=args.input1,
data_transform_1=data_transforms['val'],
data_normalize_1=data_normalize,
sup=True,
num_images=None,
)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model1 = model1.cuda()
model2 = model2.cuda()
# for param in model2.parameters():
# param.detach_()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])
dist.barrier()
criterion = segmentation_loss(args.loss, False).cuda()
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
since = time.time()
count_iter = 0
best_model = model1
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
model2.train()
train_loss_sup_1 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
val_loss_sup_2 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup_1 = unsup_index['image']
img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))
noise = torch.clamp(torch.randn_like(img_train_unsup_1) * 0.1, -0.2, 0.2)
img_train_unsup_2 = img_train_unsup_1 + noise
optimizer1.zero_grad()
pred_train_unsup1 = model1(img_train_unsup_1)
pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)
with torch.no_grad():
pred_train_unsup2 = model2(img_train_unsup_2)
pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1)
loss_train_unsup = torch.mean((pred_train_unsup1 - pred_train_unsup2)**2)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup = sup_index['image']
img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))
mask_train_sup = sup_index['mask']
mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))
pred_train_sup1 = model1(img_train_sup)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)
loss_train_sup = loss_train_sup1
loss_train_sup.backward()
optimizer1.step()
update_ema_variables(model1, model2, args.ema_decay, epoch * num_batches['train_sup'] + i)
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_MT(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_half, print_num_minus)
train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
model2.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val'] / 16:
inputs_val = Variable(data['image'].cuda(non_blocking=True))
mask_val = Variable(data['mask'].cuda(non_blocking=True))
name_val = data['ID']
optimizer1.zero_grad()
outputs_val1 = model1(inputs_val)
outputs_val2 = model2(inputs_val)
torch.cuda.empty_cache()
if i == 0:
score_list_val1 = outputs_val1
score_list_val2 = outputs_val2
mask_list_val = mask_val
name_list_val = name_val
else:
score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
name_list_val = np.append(name_list_val, name_val, axis=0)
loss_val_sup1 = criterion(outputs_val1, mask_val)
loss_val_sup2 = criterion(outputs_val2, mask_val)
val_loss_sup_1 += loss_val_sup1.item()
val_loss_sup_2 += loss_val_sup2.item()
torch.cuda.empty_cache()
score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val1, score_list_val1)
score_list_val1 = torch.cat(score_gather_list_val1, dim=0)
score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val2, score_list_val2)
score_list_val2 = torch.cat(score_gather_list_val2, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
name_gather_list_val = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
name_list_val = np.concatenate(name_gather_list_val, axis=0)
if rank == args.rank_index:
val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)
best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])
torch.cuda.empty_cache()
if args.vis:
draw_img = draw_pred_MT(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, outputs_val2, train_eval_list1, val_eval_list1, val_eval_list2)
visualization_MT(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)
visual_image_MT(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4])
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
print('=' * print_num)
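`update_ema_variables` implements the Mean Teacher update. The teacher's weights become an exponential moving average of the student's, and the decay ramps from 0 toward `ema_decay` via `1 - 1/(global_step + 1)`. The sketch below reproduces that rule inside `torch.no_grad()` instead of going through `.data`. The function name `update_ema` is illustrative, not part of the repo.

```python
import copy

import torch
import torch.nn as nn


def update_ema(student: nn.Module, teacher: nn.Module, alpha: float, global_step: int) -> None:
    """Mean Teacher update: teacher <- alpha * teacher + (1 - alpha) * student."""
    # Ramp the decay: 0 at step 0 (teacher copies the student), 0.5 at step 1,
    # 0.9 at step 9, ... and never above the configured alpha.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    with torch.no_grad():
        for t_param, s_param in zip(teacher.parameters(), student.parameters()):
            t_param.mul_(alpha).add_(s_param, alpha=1 - alpha)
```

The ramp depends on what is passed as `global_step`. If the function is called once per optimizer step with a per-iteration counter such as `epoch * num_batches['train_sup'] + i`, the decay reaches 0.99 after 99 steps. If the epoch index is passed instead, the teacher tracks the student almost directly for the first epochs: decay 0 throughout epoch 0, 0.5 throughout epoch 1, and so on.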
================================================
FILE: train_semi_MT_3d.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_MT, print_val_loss, print_train_eval_sup, print_val_eval, save_val_best_3d, print_best
from config.visdom_config.visual_visdom import visdom_initialization_MT, visualization_MT
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_it
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def update_ema_variables(model, ema_model, alpha, global_step):
# Use the true average until the exponential average is more correct
alpha = min(1 - 1 / (global_step + 1), alpha)
for ema_param, param in zip(ema_model.parameters(), model.parameters()):
ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')
parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.1, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-c', '--unsup_weight', default=5, type=float)
parser.add_argument('--loss', default='dice', type=str)
parser.add_argument('--patch_size', default=(96, 96, 80), nargs=3, type=int)
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--ema_decay', default=0.99, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('--queue_length', default=48, type=int)
parser.add_argument('--samples_per_volume_train', default=4, type=int)
parser.add_argument('--samples_per_volume_val', default=8, type=int)
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='unet3d', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, type=lambda s: s.lower() in ('true', '1', 'yes'), help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int, help='16672')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = dist.get_world_size()  # all_gather needs one output slot per process, not per visible GPU
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + 'MT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + 'MT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_mask_results = path_seg_results + '/mask'
if not os.path.exists(path_mask_results) and rank == args.rank_index:
os.mkdir(path_mask_results)
path_seg_results = path_seg_results + '/pred'
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-MT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
visdom = visdom_initialization_MT(env=visdom_env, port=args.visdom_port)
# Dataset
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_train_unsup = dataset_it(
data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
input1=args.input1,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=False,
num_images=None
)
num_images_unsup = len(dataset_train_unsup.dataset_1)
dataset_train_sup = dataset_it(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=True,
num_images=num_images_unsup
)
dataset_val = dataset_it(
data_dir=args.path_dataset + '/val',
input1=args.input1,
transform_1=data_transform['val'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_val,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=False,
shuffle_patches=False,
sup=True,
num_images=None
)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
# Model
model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model1 = model1.cuda()
model2 = model2.cuda()
# for param in model2.parameters():
# param.detach_()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])
dist.barrier()
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
# Train & Val
since = time.time()
count_iter = 0
best_model = model1
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
model2.train()
train_loss_sup_1 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
val_loss_sup_2 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())
noise = torch.clamp(torch.randn_like(img_train_unsup_1) * 0.1, -0.2, 0.2)
img_train_unsup_2 = img_train_unsup_1 + noise
optimizer1.zero_grad()
pred_train_unsup1 = model1(img_train_unsup_1)
pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)
with torch.no_grad():
pred_train_unsup2 = model2(img_train_unsup_2)
pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1)
loss_train_unsup = torch.mean((pred_train_unsup1 - pred_train_unsup2)**2)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())
mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())
pred_train_sup1 = model1(img_train_sup_1)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)
loss_train_sup = loss_train_sup1
loss_train_sup.backward()
optimizer1.step()
update_ema_variables(model1, model2, args.ema_decay, epoch)
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_MT(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_half, print_num_minus)
train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
model2.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val_1 = Variable(data['image'][tio.DATA].cuda())
mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
optimizer1.zero_grad()
outputs_val_1 = model1(inputs_val_1)
outputs_val_2 = model2(inputs_val_1)
torch.cuda.empty_cache()
if i == 0:
score_list_val_1 = outputs_val_1
score_list_val_2 = outputs_val_2
mask_list_val = mask_val
else:
score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)
score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
loss_val_sup_1 = criterion(outputs_val_1, mask_val)
loss_val_sup_2 = criterion(outputs_val_2, mask_val)
val_loss_sup_1 += loss_val_sup_1.item()
val_loss_sup_2 += loss_val_sup_2.item()
torch.cuda.empty_cache()
score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)
score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)
score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)
best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])
torch.cuda.empty_cache()
if args.vis:
visualization_MT(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
print('=' * print_num)
================================================
FILE: train_semi_UAMT.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from models.getnetwork import get_network
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
from PIL import Image
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import sys
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss, softmax_mse_loss
from dataload.dataset_2d import imagefloder_itn
from config.visdom_config.visual_visdom import visdom_initialization_MT, visualization_MT, visual_image_MT
from config.warmup_config.warmup import GradualWarmupScheduler
from config.train_test_config.train_test_config import print_train_loss_MT, print_val_loss, print_train_eval_sup, print_val_eval, save_val_best_2d, draw_pred_MT, print_best
from warnings import simplefilter
from config.ramps import ramps
simplefilter(action='ignore', category=FutureWarning)
def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
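# Sketch (hypothetical helpers, not called anywhere in this script): the EMA
# decay schedule applied by update_ema_variables above, isolated so it can be
# checked without models. The training loop passes `epoch` (not an iteration
# counter) as `global_step`, so throughout epoch 0 the decay is 0 and the
# teacher is overwritten with a copy of the student after every optimizer step.
def ema_alpha(global_step, ema_decay):
    """Decay actually used at `global_step`: min(1 - 1/(step + 1), ema_decay)."""
    return min(1.0 - 1.0 / (global_step + 1), ema_decay)


def ema_step(teacher, student, alpha):
    """One EMA update on plain lists: teacher <- alpha * teacher + (1 - alpha) * student."""
    return [alpha * t + (1.0 - alpha) * s for t, s in zip(teacher, student)]


# Usage: the decay ramps 0.0, 0.5, 0.667 (rounded), ... until it reaches the cap.
# ema_alpha(0, 0.99) -> 0.0, ema_alpha(1, 0.99) -> 0.5, ema_alpha(500, 0.99) -> 0.99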
def init_seeds(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
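# Sketch (hypothetical helpers, not called by this script): the uncertainty-
# aware consistency term that the training loop below computes inline, written
# so the same code serves 2D (B, C, H, W) and 3D (B, C, D, H, W) inputs.
# Like the loop, it runs the teacher on `num_samples` clamped-Gaussian-noise
# copies of the batch, two copies per forward pass, averages the softmax
# outputs and takes the predictive entropy over the class channel. The masked
# loss divides by num_classes * kept_voxels; the loop hard-codes a factor of 2,
# which matches num_classes only for binary segmentation. softmax_mse_loss from
# loss.loss_function is assumed here to be an unreduced squared difference of
# softmax outputs.
import torch


def mc_entropy_uncertainty(teacher, x, num_samples=8, noise_std=0.1, noise_clip=0.2):
    """Entropy of the teacher's mean softmax over noisy copies of x, shape (B, 1, ...)."""
    if num_samples % 2:
        raise ValueError('two noisy copies are evaluated per forward pass')
    batch = x.shape[0]
    # Stack the batch twice along dim 0, like volume_batch_r in the loop.
    doubled = x.repeat(2, *([1] * (x.dim() - 1)))
    chunks = []
    with torch.no_grad():
        for _ in range(num_samples // 2):
            noise = torch.clamp(torch.randn_like(doubled) * noise_std, -noise_clip, noise_clip)
            chunks.append(teacher(doubled + noise))
    probs = torch.softmax(torch.cat(chunks, dim=0), dim=1)
    # Rows are ordered [copy 0 of images 0..B-1, copy 1 of images 0..B-1, ...],
    # so this reshape puts the noisy copies of each image along dim 0.
    probs = probs.reshape(num_samples, batch, *probs.shape[1:]).mean(dim=0)
    return -torch.sum(probs * torch.log(probs + 1e-6), dim=1, keepdim=True)


def uncertainty_masked_mse(student_logits, teacher_logits, uncertainty, threshold):
    """Softmax MSE averaged over the voxels whose uncertainty is below threshold."""
    num_classes = student_logits.shape[1]
    sq_diff = (torch.softmax(student_logits, dim=1) - torch.softmax(teacher_logits, dim=1)) ** 2
    mask = (uncertainty < threshold).float()  # (B, 1, ...) broadcasts over classes
    return torch.sum(mask * sq_diff) / (num_classes * torch.sum(mask) + 1e-16)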
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')
parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=2, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.5, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-u', '--unsup_weight', default=0.05, type=float)
parser.add_argument('--loss', default='dice')
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--ema_decay', default=0.99, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay exponent: weight_decay = 5 * 10 ** wd')
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='unet', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='rank that prints logs and saves results: 0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int)
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = torch.cuda.device_count()
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
# trained model save
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + 'UAMT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
# seg results save
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + 'UAMT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
# vis
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-UAMT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
visdom = visdom_initialization_MT(env=visdom_env, port=args.visdom_port)
if args.input1 == 'image':
input1_mean = 'MEAN'
input1_std = 'STD'
else:
input1_mean = 'MEAN_' + args.input1
input1_std = 'STD_' + args.input1
data_transforms = data_transform_2d()
data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
dataset_train_unsup = imagefloder_itn(
data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
input1=args.input1,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize,
sup=False,
num_images=None,
)
num_images_unsup = len(dataset_train_unsup)
dataset_train_sup = imagefloder_itn(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize,
sup=True,
num_images=num_images_unsup,
)
dataset_val = imagefloder_itn(
data_dir=args.path_dataset + '/val',
input1=args.input1,
data_transform_1=data_transforms['val'],
data_normalize_1=data_normalize,
sup=True,
num_images=None,
)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model1 = model1.cuda()
model2 = model2.cuda()
# for param in model2.parameters():
# param.detach_()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])
dist.barrier()
criterion = segmentation_loss(args.loss, False).cuda()
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
since = time.time()
count_iter = 0
best_model = model1
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
model2.train()
train_loss_sup_1 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
val_loss_sup_2 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup_1 = unsup_index['image']
img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))
noise = torch.clamp(torch.randn_like(img_train_unsup_1) * 0.1, -0.2, 0.2)
img_train_unsup_2 = img_train_unsup_1 + noise
optimizer1.zero_grad()
pred_train_unsup1 = model1(img_train_unsup_1)
with torch.no_grad():
pred_train_unsup2 = model2(img_train_unsup_2)
T = 8
_, _, w, h = img_train_unsup_1.shape
volume_batch_r = img_train_unsup_1.repeat(2, 1, 1, 1)
stride = volume_batch_r.shape[0] // 2
preds = torch.zeros([stride * T, cfg['NUM_CLASSES'], w, h]).cuda()
for i_ in range(T // 2):
ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2)
with torch.no_grad():
preds[2 * stride * i_:2 * stride * (i_ + 1)] = model2(ema_inputs)
preds = torch.softmax(preds, dim=1)
preds = preds.reshape(T, stride, cfg['NUM_CLASSES'], w, h)
preds = torch.mean(preds, dim=0)
uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True)
consistency_dist = softmax_mse_loss(pred_train_unsup1, pred_train_unsup2) # (batch, num_classes, H, W)
threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(epoch, args.num_epochs)) * np.log(2)
mask = (uncertainty < threshold).float()
loss_train_unsup = torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup = sup_index['image']
img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))
mask_train_sup = sup_index['mask']
mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))
pred_train_sup1 = model1(img_train_sup)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)
loss_train_sup = loss_train_sup1
loss_train_sup.backward()
optimizer1.step()
update_ema_variables(model1, model2, args.ema_decay, epoch)
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_MT(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_half, print_num_minus)
train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
model2.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val'] / 16:
inputs_val = Variable(data['image'].cuda(non_blocking=True))
mask_val = Variable(data['mask'].cuda(non_blocking=True))
name_val = data['ID']
optimizer1.zero_grad()
outputs_val1 = model1(inputs_val)
outputs_val2 = model2(inputs_val)
torch.cuda.empty_cache()
if i == 0:
score_list_val1 = outputs_val1
score_list_val2 = outputs_val2
mask_list_val = mask_val
name_list_val = name_val
else:
score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
name_list_val = np.append(name_list_val, name_val, axis=0)
loss_val_sup1 = criterion(outputs_val1, mask_val)
loss_val_sup2 = criterion(outputs_val2, mask_val)
val_loss_sup_1 += loss_val_sup1.item()
val_loss_sup_2 += loss_val_sup2.item()
torch.cuda.empty_cache()
score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val1, score_list_val1)
score_list_val1 = torch.cat(score_gather_list_val1, dim=0)
score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val2, score_list_val2)
score_list_val2 = torch.cat(score_gather_list_val2, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
name_gather_list_val = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
name_list_val = np.concatenate(name_gather_list_val, axis=0)
if rank == args.rank_index:
val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)
best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])
torch.cuda.empty_cache()
if args.vis:
draw_img = draw_pred_MT(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, outputs_val2, train_eval_list1, val_eval_list1, val_eval_list2)
visualization_MT(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)
visual_image_MT(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4])
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
print('=' * print_num)
================================================
FILE: train_semi_UAMT_3d.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_MT, print_val_loss, print_train_eval_sup, print_val_eval, save_val_best_3d, print_best
from config.visdom_config.visual_visdom import visdom_initialization_MT, visualization_MT
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from config.ramps import ramps
from loss.loss_function import segmentation_loss, softmax_mse_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_it
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # Keyword form of add_; the positional add_(scalar, tensor) overload is deprecated.
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
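# Sketch (hypothetical helper, not used by this script): `--vis` below is
# declared with default=True and no `type`, so passing `--vis False` on the
# command line stores the non-empty string 'False', which is truthy, and
# visualization stays on. argparse has no built-in boolean type; a small
# converter like this one could be passed as type=str2bool.
import argparse


def str2bool(value):
    """Map common yes/no spellings to bool; reject anything else."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


# Usage:
# parser.add_argument('-v', '--vis', default=True, type=str2bool)
# parser.parse_args(['--vis', 'False']).vis -> False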
def init_seeds(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
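# Sketch (hypothetical helpers, not called by this script): the two ramps used
# in the training loop below. The consistency weight grows linearly,
# unsup_weight * (epoch + 1) / num_epochs, and the entropy threshold rises
# from about 0.75 * ln 2 toward ln 2, where ln 2 is the maximum entropy of a
# two-class prediction (for C classes the maximum is ln C, which could be
# passed as max_entropy; the scripts keep ln 2). config.ramps.sigmoid_rampup is
# assumed to be the usual mean-teacher ramp exp(-5 * (1 - t)^2) with t clipped
# to [0, 1]; check config/ramps/ramps.py before relying on the exact values.
import math


def sigmoid_rampup(current, rampup_length):
    """exp(-5 (1 - t)^2) with t = current / rampup_length clipped to [0, 1]."""
    if rampup_length == 0:
        return 1.0
    t = min(max(current, 0.0), rampup_length) / rampup_length
    return math.exp(-5.0 * (1.0 - t) ** 2)


def consistency_weight(epoch, num_epochs, unsup_weight):
    """Linear ramp used for unsup_weight in the loop; full value at the last epoch."""
    return unsup_weight * (epoch + 1) / num_epochs


def uncertainty_threshold(epoch, num_epochs, max_entropy=math.log(2)):
    """Entropy threshold; voxels with uncertainty below it enter the consistency loss."""
    return (0.75 + 0.25 * sigmoid_rampup(epoch, num_epochs)) * max_entropy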
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')
parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.1, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-c', '--unsup_weight', default=5, type=float)
parser.add_argument('--loss', default='dice', type=str)
parser.add_argument('--patch_size', default=(96, 96, 80), nargs=3, type=int)
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--ema_decay', default=0.99, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay exponent: weight_decay = 5 * 10 ** wd')
parser.add_argument('--queue_length', default=48, type=int)
parser.add_argument('--samples_per_volume_train', default=4, type=int)
parser.add_argument('--samples_per_volume_val', default=8, type=int)
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='unet3d', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='rank that prints logs and saves results: 0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int, help='visdom server port')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = torch.cuda.device_count()
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + 'UAMT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + 'UAMT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_mask_results = path_seg_results + '/mask'
if not os.path.exists(path_mask_results) and rank == args.rank_index:
os.mkdir(path_mask_results)
path_seg_results = path_seg_results + '/pred'
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-UAMT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
visdom = visdom_initialization_MT(env=visdom_env, port=args.visdom_port)
# Dataset
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_train_unsup = dataset_it(
data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
input1=args.input1,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=False,
num_images=None
)
num_images_unsup = len(dataset_train_unsup.dataset_1)
dataset_train_sup = dataset_it(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=True,
num_images=num_images_unsup
)
dataset_val = dataset_it(
data_dir=args.path_dataset + '/val',
input1=args.input1,
transform_1=data_transform['val'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_val,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=False,
shuffle_patches=False,
sup=True,
num_images=None
)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
# Model
model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model1 = model1.cuda()
model2 = model2.cuda()
# for param in model2.parameters():
# param.detach_()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])
dist.barrier()
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
# Train & Val
since = time.time()
count_iter = 0
best_model = model1
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
model2.train()
train_loss_sup_1 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
val_loss_sup_2 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())
noise = torch.clamp(torch.randn_like(img_train_unsup_1) * 0.1, -0.2, 0.2)
img_train_unsup_2 = img_train_unsup_1 + noise
optimizer1.zero_grad()
pred_train_unsup1 = model1(img_train_unsup_1)
with torch.no_grad():
pred_train_unsup2 = model2(img_train_unsup_2)
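# UA-MT uncertainty: T noisy teacher (model2) predictions per voxel, computed two copies of the batch at a time
T = 8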
_, _, d, w, h = img_train_unsup_1.shape
volume_batch_r = img_train_unsup_1.repeat(2, 1, 1, 1, 1)
stride = volume_batch_r.shape[0] // 2
preds = torch.zeros([stride * T, cfg['NUM_CLASSES'], d, w, h]).cuda()
for i_ in range(T // 2):
ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2)
with torch.no_grad():
preds[2 * stride * i_:2 * stride * (i_ + 1)] = model2(ema_inputs)
preds = torch.softmax(preds, dim=1)
preds = preds.reshape(T, stride, cfg['NUM_CLASSES'], d, w, h)
preds = torch.mean(preds, dim=0)
uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True)
consistency_dist = softmax_mse_loss(pred_train_unsup1, pred_train_unsup2) # (batch, 2, 112,112,80)
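# Entropy threshold ramps from about 0.75*ln(2) up to ln(2), the maximum entropy of a two-class prediction
threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(epoch, args.num_epochs)) * np.log(2)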
mask = (uncertainty < threshold).float()
loss_train_unsup = torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())
mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())
pred_train_sup1 = model1(img_train_sup_1)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)
loss_train_sup = loss_train_sup1
loss_train_sup.backward()
optimizer1.step()
# Teacher model2 tracks an exponential moving average of the student model1 weights
update_ema_variables(model1, model2, args.ema_decay, epoch)
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_MT(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_half, print_num_minus)
train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
model2.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val_1 = Variable(data['image'][tio.DATA].cuda())
mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
optimizer1.zero_grad()
outputs_val_1 = model1(inputs_val_1)
outputs_val_2 = model2(inputs_val_1)
torch.cuda.empty_cache()
if i == 0:
score_list_val_1 = outputs_val_1
score_list_val_2 = outputs_val_2
mask_list_val = mask_val
else:
score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)
score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
loss_val_sup_1 = criterion(outputs_val_1, mask_val)
loss_val_sup_2 = criterion(outputs_val_2, mask_val)
val_loss_sup_1 += loss_val_sup_1.item()
val_loss_sup_2 += loss_val_sup_2.item()
torch.cuda.empty_cache()
score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)
score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)
score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)
best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])
torch.cuda.empty_cache()
if args.vis:
visualization_MT(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(
print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
print('=' * print_num)
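The uncertainty-aware consistency term in the loop above (the `T = 8` noisy teacher passes, the entropy `uncertainty`, the ramped `threshold`, and the masked `softmax_mse_loss`) can be checked without a GPU. The sketch below is a plain-Python approximation working on per-voxel lists of class probabilities instead of tensors. It assumes `ramps.sigmoid_rampup` follows the common Mean Teacher definition `exp(-5 * (1 - t / L) ** 2)` and that `softmax_mse_loss` returns the element-wise squared difference of the two softmax outputs; `sigmoid_rampup`, `predictive_entropy` and `masked_consistency` are hypothetical helper names, not functions from this repository.

```python
import math


def sigmoid_rampup(current, rampup_length):
    # Assumed Mean Teacher ramp: exp(-5 * (1 - t)^2), t = current / length clipped to [0, 1]
    if rampup_length == 0:
        return 1.0
    t = min(max(current, 0.0), rampup_length) / rampup_length
    phase = 1.0 - t
    return math.exp(-5.0 * phase * phase)


def predictive_entropy(samples):
    # samples: T probability vectors for one voxel (teacher passes on noisy inputs)
    num_classes = len(samples[0])
    mean = [sum(s[c] for s in samples) / len(samples) for c in range(num_classes)]
    # Same form as -sum(p * log(p + 1e-6)) in the training loop
    return -sum(p * math.log(p + 1e-6) for p in mean)


def masked_consistency(student, teacher, teacher_samples, epoch, num_epochs):
    # student/teacher: one probability vector per voxel (already softmaxed)
    # teacher_samples: the T sampled teacher vectors for each voxel
    threshold = (0.75 + 0.25 * sigmoid_rampup(epoch, num_epochs)) * math.log(2)
    masked_sum, mask_sum = 0.0, 0.0
    for stu, tea, samples in zip(student, teacher, teacher_samples):
        keep = 1.0 if predictive_entropy(samples) < threshold else 0.0
        # Squared error summed over classes, counted only for confident voxels
        masked_sum += keep * sum((a - b) ** 2 for a, b in zip(stu, tea))
        mask_sum += keep
    # Dividing by 2 * sum(mask) averages over both channels of a two-class task
    return masked_sum / (2 * mask_sum + 1e-16)


# Usage: the second voxel is maximally uncertain, so only the first one contributes
student = [[0.9, 0.1], [0.5, 0.5]]
teacher = [[0.8, 0.2], [0.1, 0.9]]
samples = [[[0.95, 0.05]] * 8, [[0.5, 0.5]] * 8]
print(masked_consistency(student, teacher, samples, epoch=0, num_epochs=200))
```

Masking by entropy means voxels where the teacher's noisy predictions disagree are simply excluded from the consistency target, and the threshold ramp lets more voxels through as training progresses.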
================================================
FILE: train_semi_URPC.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from models.getnetwork import get_network
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
from PIL import Image
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import sys
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss, entropy_loss
from dataload.dataset_2d import imagefloder_itn
from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM, visual_image_sup
from config.warmup_config.warmup import GradualWarmupScheduler
from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')
parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=2, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.5, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-u', '--unsup_weight', default=1, type=float)
parser.add_argument('--loss', default='dice')
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='unet_urpc', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, type=lambda s: str(s).lower() in ('true', '1', 'yes'), help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int)
args = parser.parse_args()
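Declaring integer and boolean options explicitly matters in a distributed launcher like this one: argparse only converts values with a `type` callable, and without one a value passed on the command line arrives as a string. A string `--rank_index` never equals the integer `rank`, so no process would print or save, and any non-empty string such as `'False'` is truthy for a flag like `--vis`. The standalone sketch below uses only the standard `argparse` module; `str2bool`, `build_parser` and the `--untyped_rank` option are hypothetical names used for illustration.

```python
import argparse


def str2bool(value):
    # Hypothetical helper: map common spellings to a real bool
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', '1', 'yes', 'y'):
        return True
    if value.lower() in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


def build_parser():
    parser = argparse.ArgumentParser()
    # No type: a value given on the command line stays a str
    parser.add_argument('--untyped_rank', default=0)
    # type=int: the command-line string is converted to an int
    parser.add_argument('--rank_index', default=0, type=int)
    # Non-string defaults are not passed through type, so the default stays True
    parser.add_argument('--vis', default=True, type=str2bool)
    return parser


args = build_parser().parse_args(['--untyped_rank', '0', '--rank_index', '0', '--vis', 'False'])
print(repr(args.untyped_rank), repr(args.rank_index), args.vis)  # prints: '0' 0 False
```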
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
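# all_gather expects one output buffer per process in the group, so size the gather lists by world size rather than by visible GPUs
ngpus_per_node = dist.get_world_size()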
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
# trained model save
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + 'URPC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
# seg results save
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + 'URPC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
# vis
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-URPC-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)
if args.input1 == 'image':
input1_mean = 'MEAN'
input1_std = 'STD'
else:
input1_mean = 'MEAN_' + args.input1
input1_std = 'STD_' + args.input1
data_transforms = data_transform_2d()
data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
dataset_train_unsup = imagefloder_itn(
data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
input1=args.input1,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize,
sup=False,
num_images=None,
)
num_images_unsup = len(dataset_train_unsup)
dataset_train_sup = imagefloder_itn(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize,
sup=True,
num_images=num_images_unsup,
)
dataset_val = imagefloder_itn(
data_dir=args.path_dataset + '/val',
input1=args.input1,
data_transform_1=data_transforms['val'],
data_normalize_1=data_normalize,
sup=True,
num_images=None,
)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
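The training loop iterates `num_batches['train_sup']` times and calls `next()` on the unlabeled iterator each time, so it only works while the labeled loader is no longer than the unlabeled one. Passing `num_images=num_images_unsup` to the labeled `imagefloder_itn` appears intended to equalize the two lengths; that reading of the dataset class is an assumption. The arithmetic below reproduces how many batches each rank sees, using the padding behaviour of `DistributedSampler` and `DataLoader` with their default `drop_last=False`; `per_rank_batches` is a hypothetical helper.

```python
import math


def per_rank_batches(num_items, world_size, batch_size):
    # DistributedSampler (drop_last=False) pads to ceil(N / world_size) indices per rank
    per_rank = math.ceil(num_items / world_size)
    # DataLoader (drop_last=False) keeps the last partial batch
    return math.ceil(per_rank / batch_size)


# Example: 80 unlabeled and 20 labeled images on 4 GPUs with batch_size=2
unsup = per_rank_batches(80, 4, 2)
sup_raw = per_rank_batches(20, 4, 2)
sup_resized = per_rank_batches(80, 4, 2)  # if the labeled set is resampled to 80 items
print(unsup, sup_raw, sup_resized)
```

With matching lengths, every labeled batch is paired with exactly one unlabeled batch per step.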
model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model1 = model1.cuda()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
dist.barrier()
criterion = segmentation_loss(args.loss, False).cuda()
kl_distance = nn.KLDivLoss(reduction='none')
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
since = time.time()
count_iter = 0
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
train_loss_sup_1 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup_1 = unsup_index['image']
img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))
optimizer1.zero_grad()
pred_train_unsup1, pred_train_unsup2, pred_train_unsup3, pred_train_unsup4 = model1(img_train_unsup_1)
pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)
pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1)
pred_train_unsup3 = torch.softmax(pred_train_unsup3, 1)
pred_train_unsup4 = torch.softmax(pred_train_unsup4, 1)
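# Mean of the four decoder probabilities; each decoder's KL divergence from this mean serves as its uncertainty
preds = (pred_train_unsup1 + pred_train_unsup2 + pred_train_unsup3 + pred_train_unsup4) / 4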
# nn.KLDivLoss(input=log(q), target=p) returns p * (log(p) - log(q)); with q = decoder output and p = preds this is KL(preds || aux), the same direction as train_semi_URPC_3d.py
variance_aux1 = torch.sum(kl_distance(torch.log(pred_train_unsup1), preds), dim=1, keepdim=True)
exp_variance_aux1 = torch.exp(-variance_aux1)
variance_aux2 = torch.sum(kl_distance(torch.log(pred_train_unsup2), preds), dim=1, keepdim=True)
exp_variance_aux2 = torch.exp(-variance_aux2)
variance_aux3 = torch.sum(kl_distance(torch.log(pred_train_unsup3), preds), dim=1, keepdim=True)
exp_variance_aux3 = torch.exp(-variance_aux3)
variance_aux4 = torch.sum(kl_distance(torch.log(pred_train_unsup4), preds), dim=1, keepdim=True)
exp_variance_aux4 = torch.exp(-variance_aux4)
consistency_dist_aux1 = (preds - pred_train_unsup1) ** 2
consistency_loss_aux1 = torch.mean(consistency_dist_aux1 * exp_variance_aux1) / (torch.mean(exp_variance_aux1) + 1e-8) + torch.mean(variance_aux1)
consistency_dist_aux2 = (preds - pred_train_unsup2) ** 2
consistency_loss_aux2 = torch.mean(consistency_dist_aux2 * exp_variance_aux2) / (torch.mean(exp_variance_aux2) + 1e-8) + torch.mean(variance_aux2)
consistency_dist_aux3 = (preds - pred_train_unsup3) ** 2
consistency_loss_aux3 = torch.mean(consistency_dist_aux3 * exp_variance_aux3) / (torch.mean(exp_variance_aux3) + 1e-8) + torch.mean(variance_aux3)
consistency_dist_aux4 = (preds - pred_train_unsup4) ** 2
consistency_loss_aux4 = torch.mean(consistency_dist_aux4 * exp_variance_aux4) / (torch.mean(exp_variance_aux4) + 1e-8) + torch.mean(variance_aux4)
loss_train_unsup = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3 + consistency_loss_aux4) / 4
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup = sup_index['image']
img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))
mask_train_sup = sup_index['mask']
mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))
pred_train_sup1, pred_train_sup2, pred_train_sup3, pred_train_sup4 = model1(img_train_sup)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = (criterion(pred_train_sup1, mask_train_sup)+criterion(pred_train_sup2, mask_train_sup)+criterion(pred_train_sup3, mask_train_sup)+criterion(pred_train_sup4, mask_train_sup)) / 4
loss_train_sup = loss_train_sup1
loss_train_sup.backward()
optimizer1.step()
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)
train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val'] / 16:
inputs_val = Variable(data['image'].cuda(non_blocking=True))
mask_val = Variable(data['mask'].cuda(non_blocking=True))
name_val = data['ID']
optimizer1.zero_grad()
outputs_val1, outputs_val2, outputs_val3, outputs_val4 = model1(inputs_val)
torch.cuda.empty_cache()
if i == 0:
score_list_val1 = outputs_val1
mask_list_val = mask_val
name_list_val = name_val
else:
score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
name_list_val = np.append(name_list_val, name_val, axis=0)
loss_val_sup1 = criterion(outputs_val1, mask_val)
val_loss_sup_1 += loss_val_sup1.item()
torch.cuda.empty_cache()
score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val1, score_list_val1)
score_list_val1 = torch.cat(score_gather_list_val1, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
name_gather_list_val = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
name_list_val = np.concatenate(name_gather_list_val, axis=0)
if rank == args.rank_index:
val_epoch_loss_sup1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)
val_eval_list1, val_m_jc1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val1, mask_list_val, print_num_minus)
best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val1, name_list_val, val_eval_list1, path_trained_models, path_seg_results, cfg['PALETTE'], 'URPC')
torch.cuda.empty_cache()
if args.vis:
draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, train_eval_list1, val_eval_list1)
visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_m_jc1)
visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
print('=' * print_num)
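The unsupervised URPC term above combines the four decoder outputs in three steps. First it takes their mean `preds`. Next it computes a per-pixel KL divergence `variance_aux*` for each decoder as an uncertainty estimate. Finally it adds a squared error, reweighted by `exp(-variance)` and normalized by the mean weight, to the mean variance itself. The plain-Python sketch below mirrors that formula for a few pixels, with lists standing in for tensors. It follows the KL direction used in `train_semi_URPC_3d.py`, where `nn.KLDivLoss` receives `torch.log(aux)` as input and `preds` as target, giving KL(preds || aux). `kl` and `rectified_consistency` are hypothetical names, and the `unsup_weight` scaling is left out.

```python
import math


def kl(p, q):
    # KL(p || q) for one pixel: sum_c p_c * (log p_c - log q_c)
    return sum(pc * (math.log(pc) - math.log(qc)) for pc, qc in zip(p, q))


def rectified_consistency(aux_outputs):
    # aux_outputs: K decoder outputs, each a list of per-pixel probability vectors
    k = len(aux_outputs)
    n = len(aux_outputs[0])
    c = len(aux_outputs[0][0])
    # Mean prediction across decoders (preds in the training loop)
    mean = [[sum(out[i][j] for out in aux_outputs) / k for j in range(c)] for i in range(n)]
    total = 0.0
    for out in aux_outputs:
        variance = [kl(mean[i], out[i]) for i in range(n)]
        weight = [math.exp(-v) for v in variance]
        # torch.mean over pixels x classes of the uncertainty-weighted squared error
        weighted_sq = sum(weight[i] * (mean[i][j] - out[i][j]) ** 2
                          for i in range(n) for j in range(c)) / (n * c)
        # Normalize by the mean weight, then penalize the mean variance itself
        total += weighted_sq / (sum(weight) / n + 1e-8) + sum(variance) / n
    return total / k


# Usage: two decoders that disagree on a single pixel
print(rectified_consistency([[[0.75, 0.25]], [[0.25, 0.75]]]))
```

Pixels where a decoder strays far from the mean get a small weight in the squared-error part, while the added variance term keeps the decoders from drifting apart to escape that penalty.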
================================================
FILE: train_semi_URPC_3d.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup
from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_it
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')
parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')
parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')
parser.add_argument('--input1', default='image')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.1, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-c', '--unsup_weight', default=5, type=float)
parser.add_argument('--patch_size', default=(96, 96, 80), nargs=3, type=int)
parser.add_argument('--loss', default='dice', type=str)
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('--queue_length', default=48, type=int)
parser.add_argument('--samples_per_volume_train', default=4, type=int)
parser.add_argument('--samples_per_volume_val', default=8, type=int)
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='unet3d_urpc', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, type=lambda s: str(s).lower() in ('true', '1', 'yes'), help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int, help='16672')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
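# all_gather expects one output buffer per process in the group, so size the gather lists by world size rather than by visible GPUs
ngpus_per_node = dist.get_world_size()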
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + 'URPC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + 'URPC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_mask_results = path_seg_results + '/mask'
if not os.path.exists(path_mask_results) and rank == args.rank_index:
os.mkdir(path_mask_results)
path_seg_results = path_seg_results + '/pred'
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-URPC-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))
visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)
# Dataset
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_train_unsup = dataset_it(
data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
input1=args.input1,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=False,
num_images=None
)
num_images_unsup = len(dataset_train_unsup.dataset_1)
dataset_train_sup = dataset_it(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=True,
num_images=num_images_unsup
)
dataset_val = dataset_it(
data_dir=args.path_dataset + '/val',
input1=args.input1,
transform_1=data_transform['val'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_val,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=False,
shuffle_patches=False,
sup=True,
num_images=None
)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
# Model
model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model1 = model1.cuda()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
dist.barrier()
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
kl_distance = nn.KLDivLoss(reduction='none')
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
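The training strategy above wraps `StepLR(step_size=args.step_size, gamma=args.gamma)` in `GradualWarmupScheduler(multiplier=1.0, total_epoch=args.warm_up_duration)`, and the training loops step it once per epoch. The loops also ramp the unsupervised weight linearly as `args.unsup_weight * (epoch + 1) / args.num_epochs`. The sketch below tabulates both with this script's defaults (lr 0.1, warmup 20, step 50, gamma 0.5, unsup_weight 5, 200 epochs). The unsupervised ramp is copied from the loop. The learning-rate curve is a hypothetical approximation, not the exact behaviour of `config/warmup_config/warmup.py`: it assumes linear warmup to the base rate, then step decay counted from the end of warmup. `lr_at_epoch` and `unsup_weight_at_epoch` are illustrative names.

```python
def lr_at_epoch(epoch, base_lr=0.1, warmup=20, step_size=50, gamma=0.5):
    # Hypothetical schedule: linear warmup to base_lr, then StepLR-style decay
    # counted from the end of warmup (an assumption about GradualWarmupScheduler)
    if epoch < warmup:
        return base_lr * (epoch + 1) / warmup
    return base_lr * gamma ** ((epoch - warmup) // step_size)


def unsup_weight_at_epoch(epoch, unsup_weight=5.0, num_epochs=200):
    # Same linear ramp as the training loop: unsup_weight * (epoch + 1) / num_epochs
    return unsup_weight * (epoch + 1) / num_epochs


for epoch in (0, 19, 20, 70, 120, 199):
    print(epoch, round(lr_at_epoch(epoch), 5), round(unsup_weight_at_epoch(epoch), 3))
```

Under these assumptions the unsupervised term only reaches its full weight in the last epoch, while the learning rate has already been halved several times.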
# Train & Val
since = time.time()
count_iter = 0
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model1.train()
train_loss_sup_1 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())
optimizer1.zero_grad()
pred_train_unsup1, pred_train_unsup2, pred_train_unsup3, pred_train_unsup4 = model1(img_train_unsup_1)
pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)
pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1)
pred_train_unsup3 = torch.softmax(pred_train_unsup3, 1)
pred_train_unsup4 = torch.softmax(pred_train_unsup4, 1)
preds = (pred_train_unsup1 + pred_train_unsup2 + pred_train_unsup3 + pred_train_unsup4) / 4
variance_aux1 = torch.sum(kl_distance(torch.log(pred_train_unsup1), preds), dim=1, keepdim=True)
exp_variance_aux1 = torch.exp(-variance_aux1)
variance_aux2 = torch.sum(kl_distance(torch.log(pred_train_unsup2), preds), dim=1, keepdim=True)
exp_variance_aux2 = torch.exp(-variance_aux2)
variance_aux3 = torch.sum(kl_distance(torch.log(pred_train_unsup3), preds), dim=1, keepdim=True)
exp_variance_aux3 = torch.exp(-variance_aux3)
variance_aux4 = torch.sum(kl_distance(torch.log(pred_train_unsup4), preds), dim=1, keepdim=True)
exp_variance_aux4 = torch.exp(-variance_aux4)
consistency_dist_aux1 = (preds - pred_train_unsup1) ** 2
consistency_loss_aux1 = torch.mean(consistency_dist_aux1 * exp_variance_aux1) / (torch.mean(exp_variance_aux1) + 1e-8) + torch.mean(variance_aux1)
consistency_dist_aux2 = (preds - pred_train_unsup2) ** 2
consistency_loss_aux2 = torch.mean(consistency_dist_aux2 * exp_variance_aux2) / (torch.mean(exp_variance_aux2) + 1e-8) + torch.mean(variance_aux2)
consistency_dist_aux3 = (preds - pred_train_unsup3) ** 2
consistency_loss_aux3 = torch.mean(consistency_dist_aux3 * exp_variance_aux3) / (torch.mean(exp_variance_aux3) + 1e-8) + torch.mean(variance_aux3)
consistency_dist_aux4 = (preds - pred_train_unsup4) ** 2
consistency_loss_aux4 = torch.mean(consistency_dist_aux4 * exp_variance_aux4) / (torch.mean(exp_variance_aux4) + 1e-8) + torch.mean(variance_aux4)
loss_train_unsup = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3 + consistency_loss_aux4) / 4
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())
mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())
pred_train_sup1, pred_train_sup2, pred_train_sup3, pred_train_sup4 = model1(img_train_sup_1)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = (criterion(pred_train_sup1, mask_train_sup)+criterion(pred_train_sup2, mask_train_sup)+criterion(pred_train_sup3, mask_train_sup)+criterion(pred_train_sup4, mask_train_sup)) / 4
loss_train_sup = loss_train_sup1
loss_train_sup.backward()
optimizer1.step()
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)
train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val_1 = Variable(data['image'][tio.DATA].cuda())
mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
optimizer1.zero_grad()
outputs_val_1, outputs_val_2, outputs_val_3, outputs_val_4 = model1(inputs_val_1)
torch.cuda.empty_cache()
if i == 0:
score_list_val_1 = outputs_val_1
mask_list_val = mask_val
else:
score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
loss_val_sup_1 = criterion(outputs_val_1, mask_val)
val_loss_sup_1 += loss_val_sup_1.item()
torch.cuda.empty_cache()
score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss_sup_1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)
val_eval_list_1, val_m_jc_1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val_1, mask_list_val, print_num_minus)
best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val_1, mask_list_val, val_eval_list_1, path_trained_models, path_seg_results, path_mask_results, 'URPC', cfg['FORMAT'])
torch.cuda.empty_cache()
if args.vis:
visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_m_jc_1)
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
print('=' * print_num)
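# Illustrative sketch (not called by this script): a NumPy restatement of the
# uncertainty-rectified consistency loss computed in the unsupervised branch
# above, useful for checking the arithmetic on small arrays. It assumes each
# element of `aux_probs` is a softmax output shaped (N, C, ...) with classes on
# axis 1, and that nn.KLDivLoss(reduction='none')(log p, q) equals
# q * (log q - log p) elementwise. The training loop additionally multiplies the
# result by unsup_weight. The function name is hypothetical.
import numpy as np


def rectified_consistency_loss_sketch(aux_probs, eps=1e-8):
    # Mean prediction over all auxiliary heads (`preds` in the training loop).
    preds = sum(aux_probs) / len(aux_probs)
    losses = []
    for p in aux_probs:
        # Per-voxel KL(preds || p) summed over classes: the "variance" map.
        variance = np.sum(preds * (np.log(preds) - np.log(p)), axis=1, keepdims=True)
        exp_variance = np.exp(-variance)
        # Squared error to the mean prediction, down-weighted where heads disagree.
        dist = (preds - p) ** 2
        weighted = np.mean(dist * exp_variance) / (np.mean(exp_variance) + eps)
        # The mean variance is added as a regulariser, as in the original code.
        losses.append(weighted + np.mean(variance))
    return float(sum(losses) / len(losses))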
================================================
FILE: train_semi_XNet.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from models.getnetwork import get_network
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
from PIL import Image
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import sys
from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from dataload.dataset_2d import imagefloder_iitnn
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
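# Illustrative sketch (not called by this script): which samples end up in the
# tensors that the validation loop below collects with all_gather. It assumes
# torch.utils.data.DistributedSampler with shuffle=False and drop_last=False,
# which pads the index list by repeating indices from the start until its length
# divides evenly by the number of processes, then gives rank r every
# world_size-th index starting at r. Concatenating the gathered per-rank tensors
# therefore yields a rank-major order and, when the dataset size is not a
# multiple of the world size, a few duplicated samples that also enter the
# validation metrics. Both function names are hypothetical.
import math


def shard_indices_sketch(num_samples, world_size, rank):
    total_size = math.ceil(num_samples / world_size) * world_size
    indices = list(range(num_samples))
    padding = total_size - num_samples
    # Repeat the index list as often as needed, then trim to the padding size.
    indices += (indices * math.ceil(padding / num_samples))[:padding]
    # Strided split: rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total_size:world_size]


def gathered_order_sketch(num_samples, world_size):
    # all_gather fills its output list in rank order, and torch.cat(dim=0) keeps it.
    order = []
    for rank in range(world_size):
        order.extend(shard_indices_sketch(num_samples, world_size, rank))
    return order


# Example: 5 validation images on 2 processes -> [0, 2, 4, 1, 3, 0]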
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi_xnet')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi_xnet')
parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')
parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS')
parser.add_argument('--input1', default='L')
parser.add_argument('--input2', default='H')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=2, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.5, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-u', '--unsup_weight', default=5, type=float)
parser.add_argument('--loss', default='dice')
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='xnet', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, type=lambda s: str(s).lower() in ('true', '1', 'yes'), help='need visualization or not (True/False)')
parser.add_argument('--visdom_port', default=16672, type=int)
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = dist.get_world_size()  # all_gather needs one buffer per process in the group, not per visible GPU
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
# trained model save
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+ str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
# seg results save
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+ str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
# vis
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-XNet-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+ str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)+'-'+str(args.input2))
visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)
if args.input1 == 'image':
input1_mean = 'MEAN'
input1_std = 'STD'
else:
input1_mean = 'MEAN_' + args.input1
input1_std = 'STD_' + args.input1
if args.input2 == 'image':
input2_mean = 'MEAN'
input2_std = 'STD'
else:
input2_mean = 'MEAN_' + args.input2
input2_std = 'STD_' + args.input2
data_transforms = data_transform_2d()
data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])
dataset_train_unsup = imagefloder_iitnn(
data_dir=args.path_dataset + '/train_unsup_'+args.unsup_mark,
input1=args.input1,
input2=args.input2,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize_1,
data_normalize_2=data_normalize_2,
sup=False,
num_images=None,
)
num_images_unsup = len(dataset_train_unsup)
dataset_train_sup = imagefloder_iitnn(
data_dir=args.path_dataset + '/train_sup_'+args.sup_mark,
input1=args.input1,
input2=args.input2,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize_1,
data_normalize_2=data_normalize_2,
sup=True,
num_images=num_images_unsup,
)
dataset_val = imagefloder_iitnn(
data_dir=args.path_dataset + '/val',
input1=args.input1,
input2=args.input2,
data_transform_1=data_transforms['val'],
data_normalize_1=data_normalize_1,
data_normalize_2=data_normalize_2,
sup=True,
num_images=None,
)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model = model.cuda()
model = DistributedDataParallel(model, device_ids=[args.local_rank])
dist.barrier()
criterion = segmentation_loss(args.loss, False).cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)
since = time.time()
count_iter = 0
best_model = model
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter-1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train_sup'].sampler.set_epoch(epoch)
dataloaders['train_unsup'].sampler.set_epoch(epoch)
model.train()
train_loss_sup_1 = 0.0
train_loss_sup_2 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
val_loss_sup_2 = 0.0
unsup_weight = args.unsup_weight * (epoch+1) / args.num_epochs
dist.barrier()
dataset_train_sup = iter(dataloaders['train_sup'])
dataset_train_unsup = iter(dataloaders['train_unsup'])
for i in range(num_batches['train_sup']):
unsup_index = next(dataset_train_unsup)
img_train_unsup_1 = unsup_index['image']
img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))
img_train_unsup_2 = unsup_index['image_2']
img_train_unsup_2 = Variable(img_train_unsup_2.cuda(non_blocking=True))
optimizer.zero_grad()
pred_train_unsup1, pred_train_unsup2 = model(img_train_unsup_1, img_train_unsup_2)
max_train1 = torch.max(pred_train_unsup1, dim=1)[1]
max_train2 = torch.max(pred_train_unsup2, dim=1)[1]
max_train1 = max_train1.long()
max_train2 = max_train2.long()
loss_train_unsup = criterion(pred_train_unsup1, max_train2) + criterion(pred_train_unsup2, max_train1)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train_unsup.backward(retain_graph=True)
torch.cuda.empty_cache()
sup_index = next(dataset_train_sup)
img_train_sup_1 = sup_index['image']
img_train_sup_1 = Variable(img_train_sup_1.cuda(non_blocking=True))
img_train_sup_2 = sup_index['image_2']
img_train_sup_2 = Variable(img_train_sup_2.cuda(non_blocking=True))
mask_train_sup = sup_index['mask']
mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))
pred_train_sup1, pred_train_sup2 = model(img_train_sup_1, img_train_sup_2)
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = pred_train_sup1
score_list_train2 = pred_train_sup2
mask_list_train = mask_train_sup
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)
loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)
loss_train_sup = loss_train_sup1 + loss_train_sup2
loss_train_sup.backward()
optimizer.step()
torch.cuda.empty_cache()
loss_train = loss_train_unsup + loss_train_sup
train_loss_unsup += loss_train_unsup.item()
train_loss_sup_1 += loss_train_sup1.item()
train_loss_sup_2 += loss_train_sup2.item()
train_loss += loss_train.item()
scheduler_warmup.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train2, score_list_train2)
score_list_train2 = torch.cat(score_gather_list_train2, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch+1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)
torch.cuda.empty_cache()
with torch.no_grad():
model.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val'] / 16:
inputs_val_1 = Variable(data['image'].cuda(non_blocking=True))
inputs_val_2 = Variable(data['image_2'].cuda(non_blocking=True))
mask_val = Variable(data['mask'].cuda(non_blocking=True))
name_val = data['ID']
optimizer.zero_grad()
outputs_val1, outputs_val2 = model(inputs_val_1, inputs_val_2)
torch.cuda.empty_cache()
if i == 0:
score_list_val1 = outputs_val1
score_list_val2 = outputs_val2
mask_list_val = mask_val
name_list_val = name_val
else:
score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
name_list_val = np.append(name_list_val, name_val, axis=0)
loss_val_sup1 = criterion(outputs_val1, mask_val)
loss_val_sup2 = criterion(outputs_val2, mask_val)
val_loss_sup_1 += loss_val_sup1.item()
val_loss_sup_2 += loss_val_sup2.item()
torch.cuda.empty_cache()
score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val1, score_list_val1)
score_list_val1 = torch.cat(score_gather_list_val1, dim=0)
score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val2, score_list_val2)
score_list_val2 = torch.cat(score_gather_list_val2, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
name_gather_list_val = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
name_list_val = np.concatenate(name_gather_list_val, axis=0)
if rank == args.rank_index:
val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)
best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model, model, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])
torch.cuda.empty_cache()
if args.vis:
draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, pred_train_sup2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2)
visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)
visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5])
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time()-begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
print('=' * print_num)
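# Illustrative sketch (not called by this script): the cross pseudo supervision
# term used in the unsupervised branch above, where each branch is trained
# toward the other branch's argmax prediction and the sum is scaled by the
# linearly ramped unsup_weight. The repository's segmentation_loss('dice', False)
# is not reproduced here; a plain soft Dice loss over softmax probabilities is
# substituted as an assumption, and all function names are hypothetical.
import numpy as np


def softmax_sketch(logits, axis=1):
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def soft_dice_loss_sketch(logits, labels, eps=1e-5):
    # logits: (N, C, ...) raw scores; labels: (N, ...) integer class map.
    num_classes = logits.shape[1]
    probs = softmax_sketch(logits, axis=1)
    one_hot = np.moveaxis(np.eye(num_classes)[labels], -1, 1)
    axes = tuple(a for a in range(probs.ndim) if a != 1)  # every axis except classes
    inter = np.sum(probs * one_hot, axis=axes)
    denom = np.sum(probs, axis=axes) + np.sum(one_hot, axis=axes)
    return float(1.0 - np.mean((2.0 * inter + eps) / (denom + eps)))


def cross_pseudo_supervision_sketch(logits1, logits2, unsup_weight):
    # Hard pseudo-labels; in PyTorch the argmax output carries no gradient.
    pseudo1 = np.argmax(logits1, axis=1)
    pseudo2 = np.argmax(logits2, axis=1)
    loss = soft_dice_loss_sketch(logits1, pseudo2) + soft_dice_loss_sketch(logits2, pseudo1)
    return loss * unsup_weight


def unsup_weight_sketch(epoch, max_weight, num_epochs):
    # Same linear ramp as the training loop: max_weight is reached at the last epoch.
    return max_weight * (epoch + 1) / num_epochs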
================================================
FILE: train_semi_XNet3d.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_3d, print_best
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_iit
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
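# Illustrative sketch (not called by this script): argparse converts a
# command-line value only when the option declares type=; otherwise the value
# arrives as a string, while non-string defaults are used as-is. For an option
# such as --rank_index, a string '1' never equals the integer returned by
# torch.distributed.get_rank(), so every rank-guarded print and checkpoint save
# would be skipped without any error. The helper name is hypothetical.
import argparse


def parse_rank_index_sketch(argv, typed):
    parser = argparse.ArgumentParser()
    if typed:
        parser.add_argument('--rank_index', default=0, type=int)
    else:
        parser.add_argument('--rank_index', default=0)
    return parser.parse_args(argv).rank_index


# parse_rank_index_sketch(['--rank_index', '1'], typed=False) -> '1' (a str)
# parse_rank_index_sketch(['--rank_index', '1'], typed=True)  -> 1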
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi_xnet')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi_xnet')
parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')
parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')
parser.add_argument('--input1', default='image')
parser.add_argument('--input2', default='DB2_H')
parser.add_argument('--sup_mark', default='20')
parser.add_argument('--unsup_mark', default='80')
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.1, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-u', '--unsup_weight', default=5, type=float)
parser.add_argument('--loss', default='dice', type=str)
parser.add_argument('--patch_size', default=(96, 96, 80), type=int, nargs=3)
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('--queue_length', default=48, type=int)
parser.add_argument('--samples_per_volume_train', default=4, type=int)
parser.add_argument('--samples_per_volume_val', default=8, type=int)
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='xnet3d', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, type=lambda s: str(s).lower() in ('true', '1', 'yes'), help='need visualization or not (True/False)')
parser.add_argument('--visdom_port', default=16672, type=int, help='visdom server port')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = dist.get_world_size()  # all_gather needs one buffer per process in the group, not per visible GPU
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) + '-' + str(args.input2)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) + '-' + str(args.input2)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_mask_results = path_seg_results + '/mask'
if not os.path.exists(path_mask_results) and rank == args.rank_index:
os.mkdir(path_mask_results)
path_seg_results_1 = path_seg_results + '/pred'
if not os.path.exists(path_seg_results_1) and rank == args.rank_index:
os.mkdir(path_seg_results_1)
if args.vis and rank == args.rank_index:
visdom_env = str('Semi-XNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) + '-' + str(args.input2))
visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)
# Dataset
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_train_unsup = dataset_iit(
data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
input1=args.input1,
input2=args.input2,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=False,
num_images=None
)
num_images_unsup = len(dataset_train_unsup.dataset_1)
dataset_train_sup = dataset_iit(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
input2=args.input2,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=True,
num_images=num_images_unsup
)
dataset_val = dataset_iit(
data_dir=args.path_dataset + '/val',
input1=args.input1,
input2=args.input2,
transform_1=data_transform['val'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_val,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=False,
shuffle_patches=False,
sup=True,
num_images=None
)
train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)
train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)
dataloaders = dict()
dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)
dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)
dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}
# Model
model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model = model.cuda()
model = DistributedDataParallel(model, device_ids=[args.local_rank])
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)
# Train & Val
since = time.time()
count_iter = 0
best_model = model
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
    for epoch in range(args.num_epochs):
        count_iter += 1
        if (count_iter - 1) % args.display_iter == 0:
            begin_time = time.time()
        dataloaders['train_sup'].sampler.set_epoch(epoch)
        dataloaders['train_unsup'].sampler.set_epoch(epoch)
        model.train()
        train_loss_sup_1 = 0.0
        train_loss_sup_2 = 0.0
        train_loss_unsup = 0.0
        train_loss = 0.0
        val_loss_sup_1 = 0.0
        val_loss_sup_2 = 0.0
        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
        dist.barrier()
        dataset_train_sup = iter(dataloaders['train_sup'])
        dataset_train_unsup = iter(dataloaders['train_unsup'])
        for i in range(num_batches['train_sup']):
            unsup_index = next(dataset_train_unsup)
            img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())
            img_train_unsup_2 = Variable(unsup_index['image2'][tio.DATA].cuda())
            optimizer.zero_grad()
            pred_train_unsup1, pred_train_unsup2 = model(img_train_unsup_1, img_train_unsup_2)
            max_train_unsup1 = torch.max(pred_train_unsup1, dim=1)[1]
            max_train_unsup2 = torch.max(pred_train_unsup2, dim=1)[1]
            max_train_unsup1 = max_train_unsup1.long()
            max_train_unsup2 = max_train_unsup2.long()
            loss_train_unsup = criterion(pred_train_unsup1, max_train_unsup2) + criterion(pred_train_unsup2, max_train_unsup1)
            loss_train_unsup = loss_train_unsup * unsup_weight
            loss_train_unsup.backward(retain_graph=True)
            torch.cuda.empty_cache()
            sup_index = next(dataset_train_sup)
            img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())
            img_train_sup_2 = Variable(sup_index['image2'][tio.DATA].cuda())
            mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())
            pred_train_sup1, pred_train_sup2 = model(img_train_sup_1, img_train_sup_2)
            torch.cuda.empty_cache()
            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train1 = pred_train_sup1
                    score_list_train2 = pred_train_sup2
                    mask_list_train = mask_train_sup
                # else:
                elif 0 < i <= num_batches['train_sup'] / 32:
                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
                    score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)
            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)
            loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)
            loss_train_sup = loss_train_sup1 + loss_train_sup2
            loss_train_sup.backward()
            optimizer.step()
            torch.cuda.empty_cache()
            loss_train = loss_train_unsup + loss_train_sup
            train_loss_unsup += loss_train_unsup.item()
            train_loss_sup_1 += loss_train_sup1.item()
            train_loss_sup_2 += loss_train_sup2.item()
            train_loss += loss_train.item()
        scheduler_warmup.step()
        torch.cuda.empty_cache()
        if count_iter % args.display_iter == 0:
            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
            score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train2, score_list_train2)
            score_list_train2 = torch.cat(score_gather_list_train2, dim=0)
            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)
            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
                train_eval_list_1, train_eval_list_2, train_m_jc_1, train_m_jc_2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)
                torch.cuda.empty_cache()
            with torch.no_grad():
                model.eval()
                for i, data in enumerate(dataloaders['val']):
                    # if 0 <= i <= num_batches['val']:
                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())
                    inputs_val_2 = Variable(data['image2'][tio.DATA].cuda())
                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
                    optimizer.zero_grad()
                    outputs_val_1, outputs_val_2 = model(inputs_val_1, inputs_val_2)
                    torch.cuda.empty_cache()
                    if i == 0:
                        score_list_val_1 = outputs_val_1
                        score_list_val_2 = outputs_val_2
                        mask_list_val = mask_val
                    else:
                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)
                        score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)
                    loss_val_sup_2 = criterion(outputs_val_2, mask_val)
                    val_loss_sup_1 += loss_val_sup_1.item()
                    val_loss_sup_2 += loss_val_sup_2.item()
                    torch.cuda.empty_cache()
                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)
                score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)
                score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)
                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)
                torch.cuda.empty_cache()
                if rank == args.rank_index:
                    val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
                    val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)
                    best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model, model, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])
                    torch.cuda.empty_cache()
                    if args.vis:
                        visualization_XNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_m_jc_1, train_m_jc_2, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)
                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
                    torch.cuda.empty_cache()
        torch.cuda.empty_cache()
    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)
        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
        print('=' * print_num)
================================================
FILE: train_sup.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_sup, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup
from config.visdom_config.visual_visdom import visdom_initialization_sup, visualization_sup, visual_image_sup
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_itn
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
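    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    # PYTHONHASHSEED is read when an interpreter starts, so setting it here only
    # affects Python processes launched afterwards, not the current one
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True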
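# Sketch (hypothetical helper, not part of the XNet code base): the checkpoint,
# segmentation-result and visdom names in this script all repeat the suffix
# '<network>-l=<lr>-e=<epochs>-s=<step>-g=<gamma>-b=<batch>-w=<warmup>-<sup_mark>'
# built by string concatenation. Assuming the argparse attribute names defined
# below, one helper can build that suffix once so the three copies cannot drift.
def run_name(args):
    """Return the experiment suffix used for checkpoint and result folders."""
    # '{}'.format(x) gives the same text as str(x) for the int/float/str flags used here
    return '{}-l={}-e={}-s={}-g={}-b={}-w={}-{}'.format(
        args.network, args.lr, args.num_epochs, args.step_size,
        args.gamma, args.batch_size, args.warm_up_duration, args.sup_mark)
# Usage sketch: path_trained_models + '/' + run_name(args)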
if __name__ == '__main__':
    def str2bool(value):
        # without an explicit type, '-v False' would arrive as the truthy string 'False'
        return str(value).lower() in ('true', '1', 'yes', 'y')
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup')
    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup')
    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')
    parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')
    parser.add_argument('--input1', default='image')
    parser.add_argument('--sup_mark', default='100')
    parser.add_argument('-b', '--batch_size', default=4, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.5, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('--loss', default='dice', type=str)
    parser.add_argument('-ds', '--deep_supervision', default=False, type=str2bool)
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay exponent (weight_decay = 5 * 10 ** wd)')
    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='unet', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, type=str2bool, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, type=int, help='16672')
    args = parser.parse_args()
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    rank = torch.distributed.get_rank()
    # all_gather needs one receive buffer per process (the world size); device_count()
    # differs from it on multi-node runs or when fewer processes than GPUs are launched
    ngpus_per_node = dist.get_world_size()
    init_seeds(rank + 1)
    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)
    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7
    print_num_minus = print_num - 2
    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    if args.vis and rank == args.rank_index:
        visdom_env = str('Sup-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark))
        visdom = visdom_initialization_sup(env=visdom_env, port=args.visdom_port)
    if args.input1 == 'image':
        input1_mean = 'MEAN'
        input1_std = 'STD'
    else:
        input1_mean = 'MEAN_' + args.input1
        input1_std = 'STD_' + args.input1
    # Dataset
    data_transforms = data_transform_2d()
    data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
    dataset_train = imagefloder_itn(
        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
        input1=args.input1,
        data_transform_1=data_transforms['train'],
        data_normalize_1=data_normalize,
        sup=True,
        num_images=None,
    )
    dataset_val = imagefloder_itn(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        data_transform_1=data_transforms['val'],
        data_normalize_1=data_normalize,
        sup=True,
        num_images=None,
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
    dataloaders = dict()
    dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)
    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}
    # Model
    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
    model = model.cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank])
    # Training Strategy
    criterion = segmentation_loss(args.loss, False).cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)
    # Train & Val
    since = time.time()
    count_iter = 0
    best_val_eval_list = [0 for i in range(4)]
    for epoch in range(args.num_epochs):
        count_iter += 1
        if (count_iter - 1) % args.display_iter == 0:
            begin_time = time.time()
        dataloaders['train'].sampler.set_epoch(epoch)
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        dist.barrier()
        for i, data in enumerate(dataloaders['train']):
            inputs_train = Variable(data['image'].cuda())
            mask_train = Variable(data['mask'].cuda())
            optimizer.zero_grad()
            outputs_train = model(inputs_train)
            torch.cuda.empty_cache()
            if args.deep_supervision:
                loss_train = 0
                for output_train in outputs_train:
                    loss_train += criterion(output_train, mask_train)
                loss_train /= len(outputs_train)
                outputs_train = outputs_train[0]
            else:
                loss_train = criterion(outputs_train, mask_train)
            loss_train.backward()
            optimizer.step()
            train_loss += loss_train.item()
            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train = outputs_train
                    mask_list_train = mask_train
                # else:
                elif 0 < i <= num_batches['train_sup'] / 4:
                    score_list_train = torch.cat((score_list_train, outputs_train), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)
        scheduler_warmup.step()
        torch.cuda.empty_cache()
        if count_iter % args.display_iter == 0:
            score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train, score_list_train)
            score_list_train = torch.cat(score_gather_list_train, dim=0)
            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)
            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss = print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus)
                train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)
                torch.cuda.empty_cache()
            with torch.no_grad():
                model.eval()
                for i, data in enumerate(dataloaders['val']):
                    # if 0 <= i <= num_batches['val']:
                    inputs_val = Variable(data['image'].cuda())
                    mask_val = Variable(data['mask'].cuda())
                    name_val = data['ID']
                    optimizer.zero_grad()
                    outputs_val = model(inputs_val)
                    torch.cuda.empty_cache()
                    if args.deep_supervision:
                        loss_val = 0
                        for output_val in outputs_val:
                            loss_val += criterion(output_val, mask_val)
                        loss_val /= len(outputs_val)
                        outputs_val = outputs_val[0]
                    else:
                        loss_val = criterion(outputs_val, mask_val)
                    val_loss += loss_val.item()
                    if i == 0:
                        score_list_val = outputs_val
                        mask_list_val = mask_val
                        name_list_val = name_val
                    else:
                        score_list_val = torch.cat((score_list_val, outputs_val), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
                        name_list_val = np.append(name_list_val, name_val, axis=0)
                    torch.cuda.empty_cache()
                score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val, score_list_val)
                score_list_val = torch.cat(score_gather_list_val, dim=0)
                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)
                name_gather_list_val = [None for _ in range(ngpus_per_node)]
                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
                name_list_val = np.concatenate(name_gather_list_val, axis=0)
                torch.cuda.empty_cache()
                if rank == args.rank_index:
                    val_epoch_loss = print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus)
                    val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)
                    best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, name_list_val, val_eval_list, path_trained_models, path_seg_results, cfg['PALETTE'], args.network)
                    torch.cuda.empty_cache()
                    if args.vis:
                        draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train, outputs_val, train_eval_list, val_eval_list)
                        visualization_sup(visdom, epoch + 1, train_epoch_loss, train_m_jc, val_epoch_loss, val_m_jc)
                        visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])
                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
                    torch.cuda.empty_cache()
        torch.cuda.empty_cache()
    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)
        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
        print('=' * print_num)
================================================
FILE: train_sup_3d.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_sup, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup
from config.visdom_config.visual_visdom import visdom_initialization_sup, visualization_sup
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_it
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
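    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    # PYTHONHASHSEED is read when an interpreter starts, so setting it here only
    # affects Python processes launched afterwards, not the current one
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True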
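# Sketch (hypothetical helper, not part of XNet): with drop_last=False,
# torch.utils.data.distributed.DistributedSampler pads its index list by
# repeating leading indices until the length divides evenly by the world size,
# then gives rank r every world_size-th index starting at r. This pure-Python
# mirror of the non-shuffled case shows why validation outputs collected with
# all_gather below can contain a few duplicated patches, which slightly skews
# metrics computed on the gathered tensors.
import math


def distributed_indices(num_samples, world_size, rank):
    """Indices a non-shuffled DistributedSampler would hand to `rank`."""
    total_size = math.ceil(num_samples / world_size) * world_size
    indices = list(range(num_samples))
    padding = total_size - num_samples
    if padding <= len(indices):
        # repeat the first `padding` indices
        indices += indices[:padding]
    else:
        # tiny datasets: cycle the whole list until it is long enough
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    return indices[rank:total_size:world_size]
# Example: 10 samples on 4 GPUs -> [0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1],
# so samples 0 and 1 are evaluated twice once the four ranks are gathered.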
if __name__ == '__main__':
    def str2bool(value):
        # without an explicit type, '-v False' would arrive as the truthy string 'False'
        return str(value).lower() in ('true', '1', 'yes', 'y')
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup')
    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup')
    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')
    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')
    parser.add_argument('--input1', default='image')
    parser.add_argument('--sup_mark', default='100')
    parser.add_argument('-b', '--batch_size', default=1, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.005, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('--patch_size', default=(96, 96, 80), nargs=3, type=int)
    parser.add_argument('--loss', default='dice', type=str)
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay exponent (weight_decay = 5 * 10 ** wd)')
    parser.add_argument('--queue_length', default=48, type=int)
    parser.add_argument('--samples_per_volume_train', default=4, type=int)
    parser.add_argument('--samples_per_volume_val', default=8, type=int)
    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='vnet', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, type=str2bool, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, type=int, help='16672')
    args = parser.parse_args()
    # nargs=3 yields a list when given on the command line; keep the tuple form of the default
    args.patch_size = tuple(args.patch_size)
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    rank = torch.distributed.get_rank()
    # all_gather needs one receive buffer per process (the world size); device_count()
    # differs from it on multi-node runs or when fewer processes than GPUs are launched
    ngpus_per_node = dist.get_world_size()
    init_seeds(rank + 1)
    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)
    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7
    print_num_minus = print_num - 2
    path_trained_models = args.path_trained_models+'/'+str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s=' + str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w=' + str(args.warm_up_duration)+'-'+str(args.sup_mark)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_seg_results = args.path_seg_results+'/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_mask_results = path_seg_results + '/mask'
    if not os.path.exists(path_mask_results) and rank == args.rank_index:
        os.mkdir(path_mask_results)
    path_seg_results = path_seg_results + '/pred'
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    if args.vis and rank == args.rank_index:
        visdom_env = str('Sup-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration)+'-'+str(args.sup_mark))
        visdom = visdom_initialization_sup(env=visdom_env, port=args.visdom_port)
    # Dataset
    data_transform = data_transform_3d(cfg['NORMALIZE'])
    dataset_train_sup = dataset_it(
        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
        input1=args.input1,
        transform_1=data_transform['train'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_train,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=True,
        shuffle_patches=True,
        sup=True,
        num_images=None
    )
    dataset_val = dataset_it(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        transform_1=data_transform['val'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_val,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=False,
        shuffle_patches=False,
        sup=True,
        num_images=None
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)
    dataloaders = dict()
    dataloaders['train'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler)
    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)
    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}
    # Model
    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)
    model = model.cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
    # Training Strategy
    criterion = segmentation_loss(args.loss, False).cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)
    # Train & Val
    since = time.time()
    count_iter = 0
    best_val_eval_list = [0 for i in range(4)]
    for epoch in range(args.num_epochs):
        count_iter += 1
        if (count_iter - 1) % args.display_iter == 0:
            begin_time = time.time()
        dataloaders['train'].sampler.set_epoch(epoch)
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        dist.barrier()
        for i, data in enumerate(dataloaders['train']):
            inputs_train = Variable(data['image'][tio.DATA].cuda())
            mask_train = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
            optimizer.zero_grad()
            outputs_train = model(inputs_train)
            torch.cuda.empty_cache()
            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train = outputs_train
                    mask_list_train = mask_train
                # else:
                elif 0 < i <= num_batches['train_sup'] / 32:
                    score_list_train = torch.cat((score_list_train, outputs_train), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)
            loss_train = criterion(outputs_train, mask_train)
            loss_train.backward()
            optimizer.step()
            train_loss += loss_train.item()
        scheduler_warmup.step()
        torch.cuda.empty_cache()
        if count_iter % args.display_iter == 0:
            score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train, score_list_train)
            score_list_train = torch.cat(score_gather_list_train, dim=0)
            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)
            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss = print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus)
                train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)
                torch.cuda.empty_cache()
            with torch.no_grad():
                model.eval()
                for i, data in enumerate(dataloaders['val']):
                    # if 0 <= i <= num_batches['val']:
                    inputs_val = Variable(data['image'][tio.DATA].cuda())
                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
                    optimizer.zero_grad()
                    outputs_val = model(inputs_val)
                    torch.cuda.empty_cache()
                    if i == 0:
                        score_list_val = outputs_val
                        mask_list_val = mask_val
                    else:
                        score_list_val = torch.cat((score_list_val, outputs_val), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
                    loss_val = criterion(outputs_val, mask_val)
                    val_loss += loss_val.item()
                    torch.cuda.empty_cache()
                score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val, score_list_val)
                score_list_val = torch.cat(score_gather_list_val, dim=0)
                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)
                torch.cuda.empty_cache()
                if rank == args.rank_index:
                    val_epoch_loss = print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus)
                    val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)
                    best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, mask_list_val, val_eval_list, path_trained_models, path_seg_results, path_mask_results, args.network, cfg['FORMAT'])
                    torch.cuda.empty_cache()
                    if args.vis:
                        visualization_sup(visdom, epoch + 1, train_epoch_loss, train_m_jc, val_epoch_loss, val_m_jc)
                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
                    torch.cuda.empty_cache()
        torch.cuda.empty_cache()
    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)
        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
        print('=' * print_num)
================================================
FILE: train_sup_ConResNet.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_ConResNet, print_val_loss_ConResNet, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup
from config.visdom_config.visual_visdom import visdom_initialization_ConResNet, visualization_ConResNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_iit_conresnet
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
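    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    # PYTHONHASHSEED is read when an interpreter starts, so setting it here only
    # affects Python processes launched afterwards, not the current one
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True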
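# Sketch (hypothetical helpers, not part of XNet): after warm-up, the scripts hand
# the learning rate to torch.optim.lr_scheduler.StepLR, which multiplies it by
# `gamma` every `step_size` epochs, and the SGD weight decay is passed as a power
# of ten via '--wd'. This closed form ignores the warm-up phase handled by
# GradualWarmupScheduler and assumes the StepLR epoch counter starts at 0.
def step_lr(base_lr, epoch, step_size, gamma):
    """Learning rate StepLR would give `epoch` steps after it starts."""
    return base_lr * gamma ** (epoch // step_size)


def weight_decay_from_pow(wd_pow):
    """'--wd' is an exponent: the optimizer receives 5 * 10 ** wd."""
    return 5 * 10 ** wd_pow
# With this script's defaults (lr=0.1, step_size=50, gamma=0.5, wd=-5):
# step_lr(0.1, 120, 50, 0.5) is about 0.025 and weight_decay_from_pow(-5) about 5e-05.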
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup')
parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')
parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')
parser.add_argument('--input1', default='image')
parser.add_argument('--input2', default='image_res')
parser.add_argument('--sup_mark', default='100')
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.1, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('--patch_size', default=(96, 96, 80))
parser.add_argument('-w', '--warm_up_duration', default=20)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('--queue_length', default=48, type=int)
parser.add_argument('--samples_per_volume_train', default=4, type=int)
parser.add_argument('--samples_per_volume_val', default=8, type=int)
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='conresnet', type=str)
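    # torch.distributed.launch passes --local_rank (--local-rank in newer PyTorch); torchrun exports LOCAL_RANK instead.
    parser.add_argument('--local_rank', '--local-rank', default=int(os.environ.get('LOCAL_RANK', -1)), type=int)
    parser.add_argument('--rank_index', default=0, type=int, help='rank that logs and saves results: 0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, type=lambda s: s.lower() in ('true', '1', 'yes'), help='enable visdom visualization (true/false)')
    parser.add_argument('--visdom_port', default=16672, type=int, help='16672')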
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
    ngpus_per_node = dist.get_world_size()  # all_gather needs one buffer per process, not per visible GPU
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_mask_results = path_seg_results + '/mask'
if not os.path.exists(path_mask_results) and rank == args.rank_index:
os.mkdir(path_mask_results)
path_seg_results = path_seg_results + '/pred'
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
if args.vis and rank == args.rank_index:
visdom_env = str('Sup-ConResNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2))
visdom = visdom_initialization_ConResNet(env=visdom_env, port=args.visdom_port)
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_train_sup = dataset_iit_conresnet(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
input2=args.input2,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=True,
num_images=None,
)
dataset_val = dataset_iit_conresnet(
data_dir=args.path_dataset + '/val',
input1=args.input1,
input2=args.input2,
transform_1=data_transform['val'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_val,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=False,
shuffle_patches=False,
sup=True,
num_images=None
)
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)
dataloaders = dict()
dataloaders['train'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler)
dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}
# Model
model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)
model = model.cuda()
model = DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
# Training Strategy
criterion_dice = segmentation_loss('dice', False).cuda()
criterion_ce = segmentation_loss('CE', False).cuda()
criterion_bound = segmentation_loss('bcebound', False, num_classes=cfg['NUM_CLASSES']).cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)
# Train & Val
since = time.time()
count_iter = 0
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train'].sampler.set_epoch(epoch)
model.train()
train_loss_seg = 0.0
train_loss_res = 0.0
train_loss = 0.0
val_loss_seg = 0.0
val_loss_res = 0.0
dist.barrier()
for i, data in enumerate(dataloaders['train']):
inputs_train_1 = Variable(data['image'][tio.DATA].cuda())
inputs_train_2 = Variable(data['image2'][tio.DATA].cuda())
mask_train = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
mask_train_2 = Variable(data['mask2'][tio.DATA].squeeze(1).long().cuda())
optimizer.zero_grad()
outputs_train = model(inputs_train_1, inputs_train_2)
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train = outputs_train[0]
mask_list_train = mask_train
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train = torch.cat((score_list_train, outputs_train[0]), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)
loss_train_seg = criterion_dice(outputs_train[0], mask_train) + criterion_ce(outputs_train[0], mask_train)
loss_train_res = criterion_bound(outputs_train[1], mask_train_2) + 0.5 * (criterion_bound(outputs_train[2], mask_train_2) + criterion_bound(outputs_train[3], mask_train_2))
loss_train = loss_train_seg + loss_train_res
loss_train.backward()
optimizer.step()
train_loss_seg += loss_train_seg.item()
train_loss_res += loss_train_res.item()
train_loss += loss_train.item()
scheduler_warmup.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train, score_list_train)
score_list_train = torch.cat(score_gather_list_train, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_seg, train_epoch_loss_res, train_epoch_loss = print_train_loss_ConResNet(train_loss_seg, train_loss_res, train_loss, num_batches, print_num, print_num_half, print_num_minus)
train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)
torch.cuda.empty_cache()
with torch.no_grad():
model.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val = Variable(data['image'][tio.DATA].cuda())
inputs_val_2 = Variable(data['image2'][tio.DATA].cuda())
mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
mask_val_2 = Variable(data['mask2'][tio.DATA].squeeze(1).long().cuda())
optimizer.zero_grad()
outputs_val = model(inputs_val, inputs_val_2)
torch.cuda.empty_cache()
if i == 0:
score_list_val = outputs_val[0]
mask_list_val = mask_val
else:
score_list_val = torch.cat((score_list_val, outputs_val[0]), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
loss_val_seg = criterion_dice(outputs_val[0], mask_val) + criterion_ce(outputs_val[0], mask_val)
loss_val_res = criterion_bound(outputs_val[1], mask_val_2) + 0.5 * (criterion_bound(outputs_val[2], mask_val_2) + criterion_bound(outputs_val[3], mask_val_2))
val_loss_seg += loss_val_seg.item()
val_loss_res += loss_val_res.item()
torch.cuda.empty_cache()
score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val, score_list_val)
score_list_val = torch.cat(score_gather_list_val, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss_seg, val_epoch_loss_res = print_val_loss_ConResNet(val_loss_seg, val_loss_res, num_batches, print_num, print_num_half)
val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)
best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, mask_list_val, val_eval_list, path_trained_models, path_seg_results, path_mask_results, args.network, cfg['FORMAT'])
torch.cuda.empty_cache()
if args.vis:
visualization_ConResNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_seg, train_epoch_loss_res, train_m_jc, val_epoch_loss_seg, val_epoch_loss_res, val_m_jc)
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
print('=' * print_num)
================================================
FILE: train_sup_XNet.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_iitnn
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
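# Sketch (hypothetical helper, not called by this script): the run directory
# name that __main__ builds inline for path_trained_models and
# path_seg_results, expressed as '-'-joined fields in the same order.
def run_name(network, lr, num_epochs, step_size, gamma, batch_size,
             unsup_weight, warm_up_duration, sup_mark, input1, input2):
    """Return e.g. 'xnet-l=0.5-e=200-s=50-g=0.5-b=2-cw=5-w=20-100-L-H'."""
    fields = [
        str(network),
        'l=' + str(lr),
        'e=' + str(num_epochs),
        's=' + str(step_size),
        'g=' + str(gamma),
        'b=' + str(batch_size),
        'cw=' + str(unsup_weight),
        'w=' + str(warm_up_duration),
        str(sup_mark),
        str(input1),
        str(input2),
    ]
    return '-'.join(fields)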
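def init_seeds(seed):
    # Seed every RNG the training loop touches.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    # Only takes effect for Python interpreters launched after this point;
    # the running interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Trade cuDNN autotuning speed for deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True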
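# Sketch (hypothetical helper, not called by this script): the cfg keys that
# __main__ looks up for each input branch's normalization statistics. The
# raw 'image' input uses MEAN/STD; any other input name X (e.g. the wavelet
# 'L' or 'H' images) uses MEAN_X/STD_X.
def normalization_keys(input_name):
    """Return the (mean_key, std_key) pair used to index the dataset cfg."""
    if input_name == 'image':
        return 'MEAN', 'STD'
    return 'MEAN_' + input_name, 'STD_' + input_name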
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup_xnet')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup_xnet')
parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')
parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS')
parser.add_argument('--input1', default='L')
parser.add_argument('--input2', default='H')
parser.add_argument('--sup_mark', default='100')
parser.add_argument('-b', '--batch_size', default=2, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.5, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-u', '--unsup_weight', default=5, type=float)
parser.add_argument('--loss', default='dice', type=str)
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay exponent: weight_decay = 5 * 10**wd')
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='xnet', type=str)
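    # torch.distributed.launch passes --local_rank (--local-rank in newer PyTorch); torchrun exports LOCAL_RANK instead.
    parser.add_argument('--local_rank', '--local-rank', default=int(os.environ.get('LOCAL_RANK', -1)), type=int)
    parser.add_argument('--rank_index', default=0, type=int, help='rank that logs and saves results: 0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, type=lambda s: s.lower() in ('true', '1', 'yes'), help='enable visdom visualization (true/false)')
    parser.add_argument('--visdom_port', default=16672, type=int, help='16672')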
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
    ngpus_per_node = dist.get_world_size()  # all_gather needs one buffer per process, not per visible GPU
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
if args.vis and rank == args.rank_index:
visdom_env = str('Sup-XNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2))
visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)
# Dataset
if args.input1 == 'image':
input1_mean = 'MEAN'
input1_std = 'STD'
else:
input1_mean = 'MEAN_' + args.input1
input1_std = 'STD_' + args.input1
if args.input2 == 'image':
input2_mean = 'MEAN'
input2_std = 'STD'
else:
input2_mean = 'MEAN_' + args.input2
input2_std = 'STD_' + args.input2
data_transforms = data_transform_2d()
data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])
dataset_train = imagefloder_iitnn(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
input2=args.input2,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize_1,
data_normalize_2=data_normalize_2,
sup=True,
num_images=None,
)
dataset_val = imagefloder_iitnn(
data_dir=args.path_dataset + '/val',
input1=args.input1,
input2=args.input2,
data_transform_1=data_transforms['val'],
data_normalize_1=data_normalize_1,
data_normalize_2=data_normalize_2,
sup=True,
num_images=None,
)
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}
# Model
model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model = model.cuda()
model = DistributedDataParallel(model, device_ids=[args.local_rank])
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)
# Train & Val
since = time.time()
count_iter = 0
best_model = model
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train'].sampler.set_epoch(epoch)
model.train()
train_loss_sup_1 = 0.0
train_loss_sup_2 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
val_loss_sup_2 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
for i, data in enumerate(dataloaders['train']):
inputs_train_1 = Variable(data['image'].cuda())
inputs_train_2 = Variable(data['image_2'].cuda())
mask_train = Variable(data['mask'].cuda())
optimizer.zero_grad()
outputs_train1, outputs_train2 = model(inputs_train_1, inputs_train_2)
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = outputs_train1
score_list_train2 = outputs_train2
mask_list_train = mask_train
# else:
elif 0 < i <= num_batches['train_sup'] / 4:
score_list_train1 = torch.cat((score_list_train1, outputs_train1), dim=0)
score_list_train2 = torch.cat((score_list_train2, outputs_train2), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)
max_train1 = torch.max(outputs_train1, dim=1)[1]
max_train2 = torch.max(outputs_train2, dim=1)[1]
max_train1 = max_train1.long()
max_train2 = max_train2.long()
loss_train_sup1 = criterion(outputs_train1, mask_train)
loss_train_sup2 = criterion(outputs_train2, mask_train)
loss_train_unsup = criterion(outputs_train1, max_train2) + criterion(outputs_train2, max_train1)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train = loss_train_sup1 + loss_train_sup2 + loss_train_unsup
loss_train.backward()
optimizer.step()
train_loss_sup_1 += loss_train_sup1.item()
train_loss_sup_2 += loss_train_sup2.item()
train_loss_unsup += loss_train_unsup.item()
train_loss += loss_train.item()
scheduler_warmup.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train2, score_list_train2)
score_list_train2 = torch.cat(score_gather_list_train2, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)
torch.cuda.empty_cache()
with torch.no_grad():
model.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val = Variable(data['image'].cuda())
inputs_val_wavelet = Variable(data['image_2'].cuda())
mask_val = Variable(data['mask'].cuda())
name_val = data['ID']
optimizer.zero_grad()
outputs_val1, outputs_val2 = model(inputs_val, inputs_val_wavelet)
torch.cuda.empty_cache()
if i == 0:
score_list_val1 = outputs_val1
score_list_val2 = outputs_val2
mask_list_val = mask_val
name_list_val = name_val
else:
score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
name_list_val = np.append(name_list_val, name_val, axis=0)
loss_val_sup1 = criterion(outputs_val1, mask_val)
loss_val_sup2 = criterion(outputs_val2, mask_val)
val_loss_sup_1 += loss_val_sup1.item()
val_loss_sup_2 += loss_val_sup2.item()
torch.cuda.empty_cache()
score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val1, score_list_val1)
score_list_val1 = torch.cat(score_gather_list_val1, dim=0)
score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val2, score_list_val2)
score_list_val2 = torch.cat(score_gather_list_val2, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
name_gather_list_val = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
name_list_val = np.concatenate(name_gather_list_val, axis=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)
best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model, model, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])
torch.cuda.empty_cache()
if args.vis:
draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train1, outputs_train2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2)
visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)
visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5])
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
print('=' * print_num)
================================================
FILE: train_sup_XNet3d.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_3d, print_best
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_iit
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
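def init_seeds(seed):
    # Seed every RNG the training loop touches.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    # Only takes effect for Python interpreters launched after this point;
    # the running interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Trade cuDNN autotuning speed for deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True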
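# Sketch (hypothetical helper, not called by this script): the linear ramp
# applied to the cross pseudo-supervision term in the training loop,
#   unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
# so the weight grows from max_weight / num_epochs in the first epoch to
# exactly max_weight in the last one.
def linear_unsup_weight(max_weight, epoch, num_epochs):
    """Return the unsupervised-loss weight for 0-based `epoch`."""
    return max_weight * (epoch + 1) / num_epochs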
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup_xnet')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup_xnet')
parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')
parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')
parser.add_argument('--input1', default='L')
parser.add_argument('--input2', default='H')
parser.add_argument('--sup_mark', default='100', help='100')
parser.add_argument('-b', '--batch_size', default=1, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.5, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-u', '--unsup_weight', default=5, type=float)
parser.add_argument('--loss', default='dice', type=str)
    parser.add_argument('--patch_size', default=(96, 96, 80), type=int, nargs=3)
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay exponent: weight_decay = 5 * 10**wd')
parser.add_argument('--queue_length', default=48, type=int)
parser.add_argument('--samples_per_volume_train', default=4, type=int)
parser.add_argument('--samples_per_volume_val', default=8, type=int)
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='xnet3d', type=str)
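    # torch.distributed.launch passes --local_rank (--local-rank in newer PyTorch); torchrun exports LOCAL_RANK instead.
    parser.add_argument('--local_rank', '--local-rank', default=int(os.environ.get('LOCAL_RANK', -1)), type=int)
    parser.add_argument('--rank_index', default=0, type=int, help='rank that logs and saves results: 0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, type=lambda s: s.lower() in ('true', '1', 'yes'), help='enable visdom visualization (true/false)')
    parser.add_argument('--visdom_port', default=16672, type=int, help='16672')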
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
    ngpus_per_node = dist.get_world_size()  # all_gather needs one buffer per process, not per visible GPU
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
    path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
    path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_mask_results = path_seg_results + '/mask'
if not os.path.exists(path_mask_results) and rank == args.rank_index:
os.mkdir(path_mask_results)
path_seg_results_1 = path_seg_results + '/pred'
if not os.path.exists(path_seg_results_1) and rank == args.rank_index:
os.mkdir(path_seg_results_1)
if args.vis and rank == args.rank_index:
        visdom_env = str('Sup-XNet-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2))
visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)
# Dataset
data_transform = data_transform_3d(cfg['NORMALIZE'])
dataset_train_sup = dataset_iit(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
input2=args.input2,
transform_1=data_transform['train'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_train,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=True,
shuffle_patches=True,
sup=True,
num_images=None
)
dataset_val = dataset_iit(
data_dir=args.path_dataset + '/val',
input1=args.input1,
input2=args.input2,
transform_1=data_transform['val'],
queue_length=args.queue_length,
samples_per_volume=args.samples_per_volume_val,
patch_size=args.patch_size,
num_workers=8,
shuffle_subjects=False,
shuffle_patches=False,
sup=True,
num_images=None
)
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)
dataloaders = dict()
dataloaders['train'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler)
dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}
# Model
model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model = model.cuda()
model = DistributedDataParallel(model, device_ids=[args.local_rank])
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)
# Train & Val
since = time.time()
count_iter = 0
best_model = model
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter-1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train'].sampler.set_epoch(epoch)
model.train()
train_loss_sup_1 = 0.0
train_loss_sup_2 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
val_loss_sup_2 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
for i, data in enumerate(dataloaders['train']):
inputs_train_1 = Variable(data['image'][tio.DATA].cuda())
inputs_train_2 = Variable(data['image2'][tio.DATA].cuda())
mask_train = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
optimizer.zero_grad()
outputs_train_1, outputs_train_2 = model(inputs_train_1, inputs_train_2)
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train_1 = outputs_train_1
score_list_train_2 = outputs_train_2
mask_list_train = mask_train
# else:
elif 0 < i <= num_batches['train_sup'] / 32:
score_list_train_1 = torch.cat((score_list_train_1, outputs_train_1), dim=0)
score_list_train_2 = torch.cat((score_list_train_2, outputs_train_2), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)
max_train1 = torch.max(outputs_train_1, dim=1)[1]
max_train2 = torch.max(outputs_train_2, dim=1)[1]
max_train1 = max_train1.long()
max_train2 = max_train2.long()
loss_train_sup1 = criterion(outputs_train_1, mask_train)
loss_train_sup2 = criterion(outputs_train_2, mask_train)
loss_train_unsup = criterion(outputs_train_1, max_train2) + criterion(outputs_train_2, max_train1)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train = loss_train_sup1 + loss_train_sup2 + loss_train_unsup
loss_train.backward()
optimizer.step()
train_loss_sup_1 += loss_train_sup1.item()
train_loss_sup_2 += loss_train_sup2.item()
train_loss_unsup += loss_train_unsup.item()
train_loss += loss_train.item()
scheduler_warmup.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train_1 = [torch.zeros_like(score_list_train_1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train_1, score_list_train_1)
score_list_train_1 = torch.cat(score_gather_list_train_1, dim=0)
score_gather_list_train_2 = [torch.zeros_like(score_list_train_2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train_2, score_list_train_2)
score_list_train_2 = torch.cat(score_gather_list_train_2, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
train_eval_list_1, train_eval_list_2, train_m_jc_1, train_m_jc_2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train_1, score_list_train_2, mask_list_train, print_num_half)
torch.cuda.empty_cache()
with torch.no_grad():
model.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val_1 = Variable(data['image'][tio.DATA].cuda())
inputs_val_2 = Variable(data['image2'][tio.DATA].cuda())
mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
optimizer.zero_grad()
outputs_val_1, outputs_val_2 = model(inputs_val_1, inputs_val_2)
torch.cuda.empty_cache()
if i == 0:
score_list_val_1 = outputs_val_1
score_list_val_2 = outputs_val_2
mask_list_val = mask_val
else:
score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)
score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
loss_val_sup_1 = criterion(outputs_val_1, mask_val)
loss_val_sup_2 = criterion(outputs_val_2, mask_val)
val_loss_sup_1 += loss_val_sup_1.item()
val_loss_sup_2 += loss_val_sup_2.item()
torch.cuda.empty_cache()
score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)
score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)
score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)
best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model, model, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])
torch.cuda.empty_cache()
if args.vis:
visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_m_jc_1, train_m_jc_2, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
print('=' * print_num)
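# Sketch (hypothetical helper, not part of the original script): the consistency
# weight in the training loop above grows linearly with the epoch,
#   unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs,
# so it starts at args.unsup_weight / num_epochs and reaches args.unsup_weight
# in the final epoch. Written as a function, the ramp can be checked in isolation.
def linear_unsup_weight(epoch, max_weight, num_epochs):
    """Weight for a 0-based epoch index; equals max_weight at epoch == num_epochs - 1."""
    return max_weight * (epoch + 1) / num_epochs
# With the defaults used elsewhere in this repo (unsup_weight=5, 200 epochs) the
# weight goes from 0.025 in epoch 0 to 5.0 in epoch 199.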
================================================
FILE: train_sup_XNet_sb.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_iitnn
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
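# Sketch (hypothetical helper, not part of the original repository): the run name
# below is rebuilt three times in this script (checkpoint directory,
# segmentation-result directory, visdom environment). One helper keeps the three
# copies from drifting apart. Field order and separators are copied from the
# concatenations in this file; `args` is assumed to be the argparse.Namespace
# parsed in __main__.
def build_run_name(args, prefix=''):
    """Join the hyper-parameters that identify a run into one '-'-separated string."""
    fields = [
        str(args.network),
        'l=' + str(args.lr),
        'e=' + str(args.num_epochs),
        's=' + str(args.step_size),
        'g=' + str(args.gamma),
        'b=' + str(args.batch_size),
        'cw=' + str(args.unsup_weight),
        'w=' + str(args.warm_up_duration),
        str(args.sup_mark),
        str(args.input1),
        str(args.input2),
    ]
    return prefix + '-'.join(fields)
# Usage: the visdom environment would be
#   build_run_name(args, prefix='Sup-XNet-' + os.path.split(args.path_dataset)[1] + '-')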
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup_xnet')
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup_xnet')
parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')
parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS')
parser.add_argument('--input1', default='L')
parser.add_argument('--input2', default='H')
parser.add_argument('--sup_mark', default='100')
parser.add_argument('-b', '--batch_size', default=2, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.5, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('-u', '--unsup_weight', default=5, type=float)
parser.add_argument('--loss', default='dice', type=str)
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='xnet_sb', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int, help='16672')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = dist.get_world_size()  # all_gather needs one output slot per process, not per visible GPU
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2
print_num_half = int(print_num / 2 - 1)
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
if args.vis and rank == args.rank_index:
visdom_env = str('Sup-XNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2))
visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)
# Dataset
if args.input1 == 'image':
input1_mean = 'MEAN'
input1_std = 'STD'
else:
input1_mean = 'MEAN_' + args.input1
input1_std = 'STD_' + args.input1
if args.input2 == 'image':
input2_mean = 'MEAN'
input2_std = 'STD'
else:
input2_mean = 'MEAN_' + args.input2
input2_std = 'STD_' + args.input2
data_transforms = data_transform_2d()
data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])
dataset_train = imagefloder_iitnn(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
input1=args.input1,
input2=args.input2,
data_transform_1=data_transforms['train'],
data_normalize_1=data_normalize_1,
data_normalize_2=data_normalize_2,
sup=True,
num_images=None,
)
dataset_val = imagefloder_iitnn(
data_dir=args.path_dataset + '/val',
input1=args.input1,
input2=args.input2,
data_transform_1=data_transforms['val'],
data_normalize_1=data_normalize_1,
data_normalize_2=data_normalize_2,
sup=True,
num_images=None,
)
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}
# Model
model1 = get_network(args.network, 3, cfg['NUM_CLASSES'])
model2 = get_network(args.network, 1, cfg['NUM_CLASSES'])
model1 = model1.cuda()
model2 = model2.cuda()
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])
model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])
dist.barrier()
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)
exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)
optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)
exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2)
# Train & Val
since = time.time()
count_iter = 0
best_model = model1
best_result = 'Result1'
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter - 1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train'].sampler.set_epoch(epoch)
model1.train()
model2.train()
train_loss_sup_1 = 0.0
train_loss_sup_2 = 0.0
train_loss_unsup = 0.0
train_loss = 0.0
val_loss_sup_1 = 0.0
val_loss_sup_2 = 0.0
unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs
dist.barrier()
for i, data in enumerate(dataloaders['train']):
inputs_train_1 = Variable(data['image'].cuda())
inputs_train_2 = Variable(data['image_2'].cuda())
mask_train = Variable(data['mask'].cuda())
optimizer1.zero_grad()
optimizer2.zero_grad()
outputs_train1 = model1(inputs_train_1)
outputs_train2 = model2(inputs_train_2)
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train1 = outputs_train1
score_list_train2 = outputs_train2
mask_list_train = mask_train
# else:
elif 0 < i <= num_batches['train_sup'] / 4:
score_list_train1 = torch.cat((score_list_train1, outputs_train1), dim=0)
score_list_train2 = torch.cat((score_list_train2, outputs_train2), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)
max_train1 = torch.max(outputs_train1, dim=1)[1]
max_train2 = torch.max(outputs_train2, dim=1)[1]
max_train1 = max_train1.long()
max_train2 = max_train2.long()
loss_train_sup1 = criterion(outputs_train1, mask_train)
loss_train_sup2 = criterion(outputs_train2, mask_train)
loss_train_unsup = criterion(outputs_train1, max_train2) + criterion(outputs_train2, max_train1)
loss_train_unsup = loss_train_unsup * unsup_weight
loss_train = loss_train_sup1 + loss_train_sup2 + loss_train_unsup
loss_train.backward()
optimizer1.step()
optimizer2.step()
train_loss_sup_1 += loss_train_sup1.item()
train_loss_sup_2 += loss_train_sup2.item()
train_loss_unsup += loss_train_unsup.item()
train_loss += loss_train.item()
scheduler_warmup1.step()
scheduler_warmup2.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
score_list_train1 = torch.cat(score_gather_list_train1, dim=0)
score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train2, score_list_train2)
score_list_train2 = torch.cat(score_gather_list_train2, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)
torch.cuda.empty_cache()
with torch.no_grad():
model1.eval()
model2.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val_1 = Variable(data['image'].cuda())
inputs_val_2 = Variable(data['image_2'].cuda())
mask_val = Variable(data['mask'].cuda())
name_val = data['ID']
optimizer1.zero_grad()
optimizer2.zero_grad()
outputs_val1 = model1(inputs_val_1)
outputs_val2 = model2(inputs_val_2)
torch.cuda.empty_cache()
if i == 0:
score_list_val1 = outputs_val1
score_list_val2 = outputs_val2
mask_list_val = mask_val
name_list_val = name_val
else:
score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
name_list_val = np.append(name_list_val, name_val, axis=0)
loss_val_sup1 = criterion(outputs_val1, mask_val)
loss_val_sup2 = criterion(outputs_val2, mask_val)
val_loss_sup_1 += loss_val_sup1.item()
val_loss_sup_2 += loss_val_sup2.item()
torch.cuda.empty_cache()
score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val1, score_list_val1)
score_list_val1 = torch.cat(score_gather_list_val1, dim=0)
score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val2, score_list_val2)
score_list_val2 = torch.cat(score_gather_list_val2, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
name_gather_list_val = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
name_list_val = np.concatenate(name_gather_list_val, axis=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)
best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])
torch.cuda.empty_cache()
if args.vis:
draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train1, outputs_train2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2)
visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)
visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5])
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
print('=' * print_num)
================================================
FILE: train_sup_alnet.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_sup, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup
from config.visdom_config.visual_visdom import visdom_initialization_sup, visualization_sup, visual_image_sup
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d, data_transform_aerial_lanenet
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_aerial_lanenet
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
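# Sketch (hypothetical helper, for reasoning about the schedule only): the
# per-epoch learning rate this script's optimizer would see, ASSUMING the
# repository's GradualWarmupScheduler follows the common linear warm-up rule for
# multiplier=1.0 (lr = base_lr * epoch / warmup_epochs) and then hands over to
# StepLR (multiply by gamma every step_size epochs, counted after warm-up).
# Check config/warmup_config/warmup.py before relying on exact values.
def warmup_step_lr(epoch, base_lr, warmup_epochs, step_size, gamma):
    """Learning rate for a 0-based epoch index under the assumed schedule."""
    if epoch < warmup_epochs:
        # linear ramp from 0 towards base_lr
        return base_lr * epoch / warmup_epochs
    # StepLR phase: decay counted from the end of warm-up
    return base_lr * gamma ** ((epoch - warmup_epochs) // step_size)
# With this script's defaults (lr=0.5, warm-up 20, step 50, gamma 0.5) the
# assumed schedule peaks at 0.5 from epoch 20 and ends at 0.0625 in epoch 199.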
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/GeYang_shared/XNet/checkpoints/sup')
parser.add_argument('--path_seg_results', default='/mnt/data1/GeYang_shared/XNet/seg_pred/sup')
parser.add_argument('--path_dataset', default='/mnt/data1/GeYang_shared/XNet/dataset/CREMI')
parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')
parser.add_argument('--sup_mark', default='100', help='20, 100')
parser.add_argument('-b', '--batch_size', default=32, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.5, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('--loss', default='dice', type=str)
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='alnet', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, type=int, help='16672')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = dist.get_world_size()  # all_gather needs one output slot per process, not per visible GPU
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7
print_num_minus = print_num - 2
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
if args.vis and rank == args.rank_index:
visdom_env = str('Sup-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark))
visdom = visdom_initialization_sup(env=visdom_env, port=args.visdom_port)
# Dataset
data_transforms = data_transform_2d()
data_normalize = data_normalize_2d(cfg['MEAN'], cfg['STD'])
data_normalize_l1 = data_transform_aerial_lanenet(64, 64)
data_normalize_l2 = data_transform_aerial_lanenet(32, 32)
data_normalize_l3 = data_transform_aerial_lanenet(16, 16)
data_normalize_l4 = data_transform_aerial_lanenet(8, 8)
dataset_train = imagefloder_aerial_lanenet(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
data_transform=data_transforms['train'],
data_normalize=data_normalize,
data_normalize_l1=data_normalize_l1,
data_normalize_l2=data_normalize_l2,
data_normalize_l3=data_normalize_l3,
data_normalize_l4=data_normalize_l4
)
dataset_val = imagefloder_aerial_lanenet(
data_dir=args.path_dataset + '/val',
data_transform=data_transforms['val'],
data_normalize=data_normalize,
data_normalize_l1=data_normalize_l1,
data_normalize_l2=data_normalize_l2,
data_normalize_l3=data_normalize_l3,
data_normalize_l4=data_normalize_l4
)
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}
# Model
model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
model = model.cuda()
model = DistributedDataParallel(model, device_ids=[args.local_rank])
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)
# Train & Val
since = time.time()
count_iter = 0
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter-1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train'].sampler.set_epoch(epoch)
model.train()
train_loss = 0.0
val_loss = 0.0
dist.barrier()
for i, data in enumerate(dataloaders['train']):
inputs_train = Variable(data['image'].cuda())
inputs_train_l1 = Variable(data['image_l1'].cuda())
inputs_train_l2 = Variable(data['image_l2'].cuda())
inputs_train_l3 = Variable(data['image_l3'].cuda())
inputs_train_l4 = Variable(data['image_l4'].cuda())
mask_train = Variable(data['mask'].cuda())
optimizer.zero_grad()
outputs_train = model(inputs_train, inputs_train_l1, inputs_train_l2, inputs_train_l3, inputs_train_l4)
torch.cuda.empty_cache()
loss_train = criterion(outputs_train, mask_train)
loss_train.backward()
optimizer.step()
train_loss += loss_train.item()
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train = outputs_train
mask_list_train = mask_train
else:
# elif 0 < i <= num_batches['train_sup'] / 16:
score_list_train = torch.cat((score_list_train, outputs_train), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)
scheduler_warmup.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train, score_list_train)
score_list_train = torch.cat(score_gather_list_train, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss = print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus)
train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)
torch.cuda.empty_cache()
with torch.no_grad():
model.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val = Variable(data['image'].cuda())
inputs_val_l1 = Variable(data['image_l1'].cuda())
inputs_val_l2 = Variable(data['image_l2'].cuda())
inputs_val_l3 = Variable(data['image_l3'].cuda())
inputs_val_l4 = Variable(data['image_l4'].cuda())
mask_val = Variable(data['mask'].cuda())
name_val = data['ID']
optimizer.zero_grad()
outputs_val = model(inputs_val, inputs_val_l1, inputs_val_l2, inputs_val_l3, inputs_val_l4)
torch.cuda.empty_cache()
loss_val = criterion(outputs_val, mask_val)
val_loss += loss_val.item()
if i == 0:
score_list_val = outputs_val
mask_list_val = mask_val
name_list_val = name_val
else:
score_list_val = torch.cat((score_list_val, outputs_val), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
name_list_val = np.append(name_list_val, name_val, axis=0)
torch.cuda.empty_cache()
score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val, score_list_val)
score_list_val = torch.cat(score_gather_list_val, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
name_gather_list_val = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
name_list_val = np.concatenate(name_gather_list_val, axis=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss = print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus)
val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)
best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, name_list_val, val_eval_list, path_trained_models, path_seg_results, cfg['PALETTE'], args.network)
torch.cuda.empty_cache()
if args.vis:
draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train, outputs_val, train_eval_list, val_eval_list)
visualization_sup(visdom, epoch+1, train_epoch_loss, train_m_jc, val_epoch_loss, val_m_jc)
visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
print('=' * print_num)
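# Sketch (illustration only, not part of the original script): why the
# gathered validation lists built above with all_gather + torch.cat can
# contain duplicate samples and are not in dataset order. This mimics, in
# pure Python, how torch.utils.data.DistributedSampler shards indices with
# shuffle=False and drop_last=False: every rank gets ceil(N / world_size)
# indices, padding by wrapping around to the start of the dataset. Gathering
# then concatenates rank 0's block, then rank 1's, and so on. The all_gather
# output list needs one slot per process (dist.get_world_size()), which equals
# torch.cuda.device_count() only when each visible GPU runs one process.
# shard_indices and gathered_order are hypothetical helper names.
import math


def shard_indices(num_samples, world_size, rank):
    """Indices one rank iterates over under DistributedSampler(shuffle=False)."""
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    indices = list(range(num_samples))
    # Pad by repeating indices from the start until the total divides evenly.
    while len(indices) < total:
        indices += indices[: total - len(indices)]
    # Each rank takes every world_size-th index, starting at its own rank.
    return indices[rank:total:world_size]


def gathered_order(num_samples, world_size):
    """Sample order after all_gather followed by torch.cat(..., dim=0)."""
    order = []
    for rank in range(world_size):
        order.extend(shard_indices(num_samples, world_size, rank))
    return order

# Usage: gathered_order(5, 2) gives [0, 2, 4, 1, 3, 0], so with 5 validation
# images on 2 processes, image 0 is scored twice in the gathered metrics.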
================================================
FILE: train_sup_wds.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_sup, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup
from config.visdom_config.visual_visdom import visdom_initialization_sup, visualization_sup, visual_image_sup
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_wds
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
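# Sketch (assumption, not the repo's own preprocessing): a single-level 2D
# Haar transform that splits an image into four sub-bands of the kind this
# script feeds to the 'wds' network as image_LL, image_LH, image_HL and
# image_HH. How imagefloder_wds actually computes or loads those bands is not
# shown here, and the LH/HL naming below (LH = left-minus-right detail,
# HL = top-minus-bottom detail) is one common convention that may differ from
# the repo's. haar_dwt2 and haar_idwt2 are hypothetical helper names.
import numpy as np


def haar_dwt2(img):
    """Return (LL, LH, HL, HH) for a 2D array with even height and width."""
    x = np.asarray(img, dtype=np.float64)
    if x.ndim != 2 or x.shape[0] % 2 or x.shape[1] % 2:
        raise ValueError('expected a 2D array with even height and width')
    a = x[0::2, 0::2]  # top-left pixel of each 2x2 block
    b = x[0::2, 1::2]  # top-right
    c = x[1::2, 0::2]  # bottom-left
    d = x[1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2  # scaled local average (approximation band)
    lh = (a - b + c - d) / 2  # left-minus-right difference
    hl = (a + b - c - d) / 2  # top-minus-bottom difference
    hh = (a - b - c + d) / 2  # diagonal difference
    return ll, lh, hl, hh


def haar_idwt2(ll, lh, hl, hh):
    """Exactly invert haar_dwt2 (the orthonormal Haar basis is its own inverse)."""
    h, w = ll.shape
    out = np.empty((2 * h, 2 * w), dtype=np.float64)
    out[0::2, 0::2] = (ll + lh + hl + hh) / 2
    out[0::2, 1::2] = (ll - lh + hl - hh) / 2
    out[1::2, 0::2] = (ll + lh - hl - hh) / 2
    out[1::2, 1::2] = (ll - lh - hl + hh) / 2
    return out

# Usage: for a constant image the three detail bands are all zero and only
# LL carries information; haar_idwt2(*haar_dwt2(img)) recovers img.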
def init_seeds(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
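# Sketch (hypothetical helper, not in the original script): argparse does not
# convert strings to booleans, so with `default=True` and no `type`, passing
# `-v False` on the command line yields the non-empty string 'False', which is
# truthy, and visualization stays on. A converter such as str2bool could be
# passed as `type=str2bool` to the `-v/--vis` argument. The accepted spellings
# are an assumption; adjust them to taste.
import argparse


def str2bool(value):
    """Parse common true/false spellings; raise an argparse error otherwise."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

# Usage: parser.add_argument('-v', '--vis', default=True, type=str2bool)
# makes `-v False` turn visualization off; a non-string default such as True
# is used as-is, since argparse only runs `type` on string defaults.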
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path_trained_models', default='/mnt/data1/GeYang_shared/XNet/checkpoints/sup')
parser.add_argument('--path_seg_results', default='/mnt/data1/GeYang_shared/XNet/seg_pred/sup')
parser.add_argument('--path_dataset', default='/mnt/data1/GeYang_shared/XNet/dataset/CREMI')
parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')
parser.add_argument('--sup_mark', default='100')
parser.add_argument('-b', '--batch_size', default=32, type=int)
parser.add_argument('-e', '--num_epochs', default=200, type=int)
parser.add_argument('-s', '--step_size', default=50, type=int)
parser.add_argument('-l', '--lr', default=0.5, type=float)
parser.add_argument('-g', '--gamma', default=0.5, type=float)
parser.add_argument('--loss', default='dice', type=str)
parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='wds', type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
parser.add_argument('--visdom_port', default=16672, help='16672')
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
ngpus_per_node = dist.get_world_size()  # all_gather needs one output slot per process in the group
init_seeds(rank + 1)
dataset_name = args.dataset_name
cfg = dataset_cfg(dataset_name)
print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7
print_num_minus = print_num - 2
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)
if not os.path.exists(path_trained_models) and rank == args.rank_index:
os.mkdir(path_trained_models)
path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)
if not os.path.exists(path_seg_results) and rank == args.rank_index:
os.mkdir(path_seg_results)
if args.vis and rank == args.rank_index:
visdom_env = str('Sup-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark))
visdom = visdom_initialization_sup(env=visdom_env, port=args.visdom_port)
# Dataset
data_transforms = data_transform_2d()
data_normalize_LL = data_normalize_2d(cfg['MEAN_LL'], cfg['STD_LL'])
data_normalize_LH = data_normalize_2d(cfg['MEAN_LH'], cfg['STD_LH'])
data_normalize_HL = data_normalize_2d(cfg['MEAN_HL'], cfg['STD_HL'])
data_normalize_HH = data_normalize_2d(cfg['MEAN_HH'], cfg['STD_HH'])
dataset_train = imagefloder_wds(
data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
data_transform_1=data_transforms['train'],
data_normalize_LL=data_normalize_LL,
data_normalize_LH=data_normalize_LH,
data_normalize_HL=data_normalize_HL,
data_normalize_HH=data_normalize_HH
)
dataset_val = imagefloder_wds(
data_dir=args.path_dataset + '/val',
data_transform_1=data_transforms['val'],
data_normalize_LL=data_normalize_LL,
data_normalize_LH=data_normalize_LH,
data_normalize_HL=data_normalize_HL,
data_normalize_HH=data_normalize_HH
)
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
dataloaders = dict()
dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)
num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}
# Model
model = get_network(args.network, 1, cfg['NUM_CLASSES'])
model = model.cuda()
model = DistributedDataParallel(model, device_ids=[args.local_rank])
# Training Strategy
criterion = segmentation_loss(args.loss, False).cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)
# Train & Val
since = time.time()
count_iter = 0
best_val_eval_list = [0 for i in range(4)]
for epoch in range(args.num_epochs):
count_iter += 1
if (count_iter-1) % args.display_iter == 0:
begin_time = time.time()
dataloaders['train'].sampler.set_epoch(epoch)
model.train()
train_loss = 0.0
val_loss = 0.0
dist.barrier()
for i, data in enumerate(dataloaders['train']):
inputs_train_LL = Variable(data['image_LL'].cuda())
inputs_train_LH = Variable(data['image_LH'].cuda())
inputs_train_HL = Variable(data['image_HL'].cuda())
inputs_train_HH = Variable(data['image_HH'].cuda())
mask_train = Variable(data['mask'].cuda())
optimizer.zero_grad()
outputs_train = model(inputs_train_LL, inputs_train_LH, inputs_train_HL, inputs_train_HH)
torch.cuda.empty_cache()
loss_train = criterion(outputs_train, mask_train)
loss_train.backward()
optimizer.step()
train_loss += loss_train.item()
if count_iter % args.display_iter == 0:
if i == 0:
score_list_train = outputs_train.detach()
mask_list_train = mask_train
else:
# elif 0 < i <= num_batches['train_sup'] / 16:
score_list_train = torch.cat((score_list_train, outputs_train.detach()), dim=0)
mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)
scheduler_warmup.step()
torch.cuda.empty_cache()
if count_iter % args.display_iter == 0:
score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_train, score_list_train)
score_list_train = torch.cat(score_gather_list_train, dim=0)
mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
mask_list_train = torch.cat(mask_gather_list_train, dim=0)
if rank == args.rank_index:
torch.cuda.empty_cache()
print('=' * print_num)
print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
train_epoch_loss = print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus)
train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)
torch.cuda.empty_cache()
with torch.no_grad():
model.eval()
for i, data in enumerate(dataloaders['val']):
# if 0 <= i <= num_batches['val']:
inputs_val_LL = Variable(data['image_LL'].cuda())
inputs_val_LH = Variable(data['image_LH'].cuda())
inputs_val_HL = Variable(data['image_HL'].cuda())
inputs_val_HH = Variable(data['image_HH'].cuda())
mask_val = Variable(data['mask'].cuda())
name_val = data['ID']
optimizer.zero_grad()
outputs_val = model(inputs_val_LL, inputs_val_LH, inputs_val_HL, inputs_val_HH)
torch.cuda.empty_cache()
loss_val = criterion(outputs_val, mask_val)
val_loss += loss_val.item()
if i == 0:
score_list_val = outputs_val
mask_list_val = mask_val
name_list_val = name_val
else:
score_list_val = torch.cat((score_list_val, outputs_val), dim=0)
mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
name_list_val = np.append(name_list_val, name_val, axis=0)
torch.cuda.empty_cache()
score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(score_gather_list_val, score_list_val)
score_list_val = torch.cat(score_gather_list_val, dim=0)
mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
mask_list_val = torch.cat(mask_gather_list_val, dim=0)
name_gather_list_val = [None for _ in range(ngpus_per_node)]
torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
name_list_val = np.concatenate(name_gather_list_val, axis=0)
torch.cuda.empty_cache()
if rank == args.rank_index:
val_epoch_loss = print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus)
val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)
best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, name_list_val, val_eval_list, path_trained_models, path_seg_results, cfg['PALETTE'], args.network)
torch.cuda.empty_cache()
if args.vis:
draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train, outputs_val, train_eval_list, val_eval_list)
visualization_sup(visdom, epoch+1, train_epoch_loss, train_m_jc, val_epoch_loss, val_m_jc)
visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])
print('-' * print_num)
print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
torch.cuda.empty_cache()
torch.cuda.empty_cache()
if rank == args.rank_index:
time_elapsed = time.time() - since
m, s = divmod(time_elapsed, 60)
h, m = divmod(m, 60)
print('=' * print_num)
print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
print('-' * print_num)
print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
print('=' * print_num)
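# Sketch (approximation, not the exact scheduler): the per-epoch learning rate
# implied by this script's defaults (lr=0.5, warm_up_duration=20, step_size=50,
# gamma=0.5), assuming GradualWarmupScheduler with multiplier=1.0 ramps the lr
# linearly from 0 to the base lr over warm_up_duration epochs and then hands
# over to StepLR, which multiplies the lr by gamma every step_size epochs
# counted from the end of warm-up. Exact off-by-one behaviour depends on
# config/warmup_config/warmup.py, which is not reproduced here; lr_at_epoch is
# a hypothetical helper.


def lr_at_epoch(epoch, base_lr=0.5, warm_up=20, step_size=50, gamma=0.5):
    """Approximate learning rate used during 0-based epoch `epoch`."""
    if epoch < warm_up:
        # Linear warm-up: multiplier=1.0 means the ramp ends at base_lr.
        return base_lr * epoch / warm_up
    # Step decay, counting epochs from the end of warm-up.
    return base_lr * gamma ** ((epoch - warm_up) // step_size)

# Usage: with the defaults the lr reaches 0.5 at epoch 20, halves to 0.25 at
# epoch 70, and has been halved three times (0.0625) by the last epoch, 199.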